{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3027c499-28fd-403d-88df-f354af5433e4",
   "metadata": {},
   "source": [
    "# Self Attention - Another implementation\n",
    "\n",
    "A step-by-step walk through single-head scaled dot-product self-attention on a tiny toy sequence:\n",
    "\n",
    "1. embed the token ids,\n",
    "2. project the embeddings to queries, keys and values with one linear layer,\n",
    "3. compute $\\mathrm{softmax}\\left(\\frac{Q K^\\top}{\\sqrt{d_k}}\\right) V$.\n",
    "\n",
    "The layers are randomly initialised, so the random seed is fixed below; Restart Kernel → Run All then reproduces the same numbers every time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b4980e-64f1-4cf3-8061-522a4e96e758",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d58632-b0ca-4d0d-aca4-9b7d83b1d721",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG so the randomly initialised embedding / linear weights are reproducible.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "vocab_size = 4   # token ids 0..3\n",
    "padding_idx = 0  # id 0 is reserved for padding\n",
    "embed_dim = 2    # width of the token embeddings and of each of Q, K, V"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-embed-tokens",
   "metadata": {},
   "source": [
    "## Embed the tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efcec0f1-2452-46fa-badd-afe28405f357",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sequence of three token ids: shape (batch=1, seq_len=3).\n",
    "sequence = torch.tensor([[1, 2, 3]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ffce5d1-18f3-412a-8bbb-3c7bf2e96b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1cf363-5617-4a3a-96d2-d2e9db468c74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape: (batch, seq_len, embed_dim)\n",
    "embedded_tokens = embed(sequence)\n",
    "embedded_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-qkv-projection",
   "metadata": {},
   "source": [
    "## Project to queries, keys and values\n",
    "\n",
    "A single linear layer produces Q, K and V stacked along the last dimension; `chunk` then splits them into three equal parts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1324414-9cc2-41c2-8853-4cd4dcaab03d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The projection consumes the token embeddings, so its input width must match them.\n",
    "input_dim = embed_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106ca1fa-cfed-4636-9469-2c27e47362de",
   "metadata": {},
   "outputs": [],
   "source": [
    "QKV_transformation = nn.Linear(input_dim, 3 * embed_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27c43c9d-b01a-4a44-9a88-c4b95968b793",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape: (batch, seq_len, 3 * embed_dim)\n",
    "QKV = QKV_transformation(embedded_tokens)\n",
    "QKV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d11fecd-6210-42b5-a79c-96fffd59b354",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q, K, V = QKV.chunk(3, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e44ffb-e8dd-43a3-ac45-43622b56cd43",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e82d031e-3544-4baa-a80f-b5c08d639493",
   "metadata": {},
   "outputs": [],
   "source": [
    "K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510bab32-cdc3-46a1-94aa-3c53cb9eb223",
   "metadata": {},
   "outputs": [],
   "source": [
    "V"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-scaled-dot-product",
   "metadata": {},
   "source": [
    "## Scaled dot-product attention\n",
    "\n",
    "Every step below operates on the whole batch at once, so the same cells work for any number of sequences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52042eae-22b4-422b-9314-03f5b3cfc2c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose only the last two dims so each sequence gets its own (seq_len, seq_len) score matrix.\n",
    "attn_logits = torch.matmul(Q, K.transpose(-2, -1))\n",
    "attn_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "495b3712-da39-4096-9798-481a69e7bf7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "d_k = Q.size(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4376eb9c-ff96-4a04-9e50-28e0d0eb5dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the logits computed above instead of repeating the matmul.\n",
    "scaled_attention_logits = attn_logits / math.sqrt(d_k)\n",
    "scaled_attention_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c5ce1c7-442a-4d5a-9f96-14a28264526d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise over the key dimension: each query's weights sum to 1.\n",
    "attention_weights = F.softmax(scaled_attention_logits, dim=-1)\n",
    "attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27fa947c-217c-412e-827e-dab5d0f29c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "row_sums = attention_weights.sum(dim=-1)\n",
    "assert torch.allclose(row_sums, torch.ones_like(row_sums))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f1d84e-6747-4b57-9748-dfe9dc9d7223",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted sum of the values; shape (batch, seq_len, embed_dim).\n",
    "attention_output = torch.matmul(attention_weights, V)\n",
    "attention_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90dd2960-3856-436b-98ec-4accc379979e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-check: the batched result matches the explicit per-sequence computation for sequence 0.\n",
    "reference_output = F.softmax(Q[0] @ K[0].T / math.sqrt(d_k), dim=-1) @ V[0]\n",
    "assert torch.allclose(attention_output[0], reference_output, atol=1e-6)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "p11",
   "language": "python",
   "name": "p11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}