{ "cells": [ { "cell_type": "markdown", "id": "5e16155e-d194-4b24-a9f3-3d1d6d4693ef", "metadata": {}, "source": [ "# Self Attention - Basics" ] }, { "cell_type": "code", "execution_count": 1, "id": "4787ffc8-ac76-460b-9ce5-e46e2d4c5755", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import math\n", "import torch.nn.functional as F\n", "import torch\n", "import torch.nn as nn" ] }, { "cell_type": "markdown", "id": "8b94ec84-15e1-4776-b9f5-52589fc5b644", "metadata": {}, "source": [ "![title](assets/self_attention.png)" ] }, { "cell_type": "markdown", "id": "93019565", "metadata": {}, "source": [ "This is how attention is described in the [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper.
\n", "\n", "\"An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.\"\n", "\n", "-- Attention Is All You Need, Vaswani et al." ] }, { "cell_type": "markdown", "id": "d119281c-42f6-4361-8141-e7b7442eb51d", "metadata": {}, "source": [ "#### Let's understand what happens in the self-attention mechanism" ] }, { "cell_type": "markdown", "id": "8d20a6a2-d6f6-4a35-a037-51306db0df08", "metadata": {}, "source": [ "The output of applying attention to a sequence of tokens is given by," ] }, { "cell_type": "markdown", "id": "cd1a78ad", "metadata": {}, "source": [ "$Attention(Q,K,V) = softmax( \\frac{QK^{T}} {\\sqrt{d_{k}}} )V$" ] }, { "cell_type": "markdown", "id": "fa789f0d", "metadata": {}, "source": [ "Let's suppose that we have a sentence of three words, for example 'sky is blue'. Let's also assume that these three words are represented using three numbers, 1, 2 and 3. These are called tokens. Therefore, our sequence of tokens is given by the tensor $[[1,2,3]]$. Tokenization is the process of breaking a sequence of characters into different parts. There are many tokenization methods. What I have done here is a very simple tokenization, chosen for the sake of easily explaining the concepts of the attention mechanism." 
] }, { "cell_type": "code", "execution_count": 2, "id": "cf8ce180", "metadata": {}, "outputs": [], "source": [ "sequence = torch.tensor([[1,2,3]])" ] }, { "cell_type": "code", "execution_count": 3, "id": "98d136b9-b6db-42ad-96c2-6f84d126b63d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2, 3]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sequence" ] }, { "cell_type": "markdown", "id": "55ddbac1-815c-4c83-89cf-24acc5987ca5", "metadata": {}, "source": [ "In this case, our `sequence length` is 3." ] }, { "cell_type": "markdown", "id": "21b6ca82-e4c7-427b-8404-3c07e0505499", "metadata": {}, "source": [ "Tokenization alone is not enough to create an information-rich representation of a character sequence. We will have to come up with a vector representation for each token. These are called embeddings." ] }, { "cell_type": "markdown", "id": "a6550cb4-55d0-431d-99e0-9de98804f477", "metadata": {}, "source": [ "Let's assume that the embedding vectors of 1, 2 and 3 are as follows." 
] }, { "cell_type": "code", "execution_count": 4, "id": "7dddf9ce-6d6e-49f6-b9ba-4f18469ede60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{1: [-1.072, -0.5001], 2: [-0.012, -0.4311], 3: [-0.005, -0.5321]}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embeddings = {1: [-1.0720, -0.5001], 2: [-0.0120, -0.4311], 3: [-0.0050, -0.5321]}\n", "embeddings" ] }, { "cell_type": "markdown", "id": "7ad64433-e455-4d5f-b778-b65e16bcd062", "metadata": {}, "source": [ "Therefore, our embedded tokens are," ] }, { "cell_type": "code", "execution_count": 5, "id": "4c0e9092", "metadata": {}, "outputs": [], "source": [ "embedded_tokens = torch.tensor([[-1.0720, -0.5001], [-0.0120, -0.4311], [-0.0050, -0.5321] ])" ] }, { "cell_type": "code", "execution_count": 6, "id": "5c1fd4b6-d734-4769-bf90-9219170b6747", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.0720, -0.5001],\n", " [-0.0120, -0.4311],\n", " [-0.0050, -0.5321]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embedded_tokens" ] }, { "cell_type": "code", "execution_count": null, "id": "6765fcc5-68d0-40af-940f-ca4df8e1d7cf", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "e7023e59-e47e-4011-b294-fbe1b438967b", "metadata": {}, "source": [ "As the next step, let's create the Q, K, V matrices. This is done by transforming our embedding matrix using three weight matrices: WQ, WK and WV." ] }, { "cell_type": "markdown", "id": "08aff325-6ca7-4d62-972a-7184c34cfce3", "metadata": {}, "source": [ "Let's consider the following WQ, WK and WV matrices." 
] }, { "cell_type": "code", "execution_count": 7, "id": "cea2dcd5-ab1a-47de-93f1-94e0892401d1", "metadata": {}, "outputs": [], "source": [ "WQ = torch.tensor([[-0.0271, -0.3840],\n", " [-0.3940, -0.6610]]).requires_grad_(True)\n", "\n", "WK = torch.tensor([[-0.4109, 0.5777],\n", " [-0.1162, -0.1661]]).requires_grad_(True)\n", "\n", "WV = torch.tensor([[-0.2045, 0.1210],\n", " [-0.1712, -0.4462]]).requires_grad_(True)" ] }, { "cell_type": "code", "execution_count": null, "id": "da77a2a9-af6c-4510-8b14-06938ce4a320", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4af20efa-8a2d-4305-b4eb-e2bc3c7b9653", "metadata": {}, "source": [ "Now, let's transform our embedded tokens using the transformation matrices WQ, WK and WV to obtain Q, K, and V tensors respectively." ] }, { "cell_type": "code", "execution_count": 8, "id": "9c9e25da-d436-40da-b0d6-644991d6e671", "metadata": {}, "outputs": [], "source": [ "Q = embedded_tokens @ WQ\n", "K = embedded_tokens @ WK\n", "V = embedded_tokens @ WV" ] }, { "cell_type": "code", "execution_count": 9, "id": "32fc41a5-2b0c-4a22-b9e6-95f7f4e33a66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.2261, 0.7422],\n", " [0.1702, 0.2896],\n", " [0.2098, 0.3536]], grad_fn=)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q" ] }, { "cell_type": "code", "execution_count": 10, "id": "ffaf9dd2-1980-417f-82cf-dda918a5f6de", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.4986, -0.5362],\n", " [ 0.0550, 0.0647],\n", " [ 0.0639, 0.0855]], grad_fn=)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "K" ] }, { "cell_type": "code", "execution_count": 11, "id": "bad769b3-08c1-4137-b09b-32ba3cf25317", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3048, 0.0934],\n", " [0.0763, 0.1909],\n", " [0.0921, 0.2368]], grad_fn=)" ] }, "execution_count": 11, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "V" ] }, { "cell_type": "markdown", "id": "7250ac3a-7667-4374-9054-7f29d8e571c1", "metadata": {}, "source": [ "What do Q,K and V represent? Well, their roles will become more apparent later. But for now, they can be considered as different feature matrices of our sequence, 'sky is blue'. This is because all we have done to obtain Q, K and V is to matrix multiply the token embeddings using three randomly initialized matrices." ] }, { "cell_type": "code", "execution_count": null, "id": "966905fe-f1aa-4d9d-8844-ee53a912a7f1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "9e8f6e05-6f46-45b5-ab53-a66a8fedfda4", "metadata": {}, "source": [ "Below, I have written the Q, K, V matrices using their components to highlight what they contain." ] }, { "cell_type": "markdown", "id": "e84d6987-8e82-448e-b23f-50abddd49283", "metadata": {}, "source": [ "$Q$ can be considered as consisting of three feature vectors. One for each of the tokens in our sequence. That is, $\\vec{Q_{1}}$ is a feature vector for token $1$, or for the word $sky$. $\\vec{Q_{2}}$ is a feature vector for token $2$, or for the word $is$. $\\vec{Q_{3}}$ is a feature vector for token $3$, or for the word $blue$." ] }, { "cell_type": "markdown", "id": "120eb74c-7034-4854-93a1-37c56bcfabc3", "metadata": {}, "source": [ "$$ Q = \\begin{bmatrix} Q_{1,1} & Q_{1,2} \\\\ Q_{2,1} & Q_{2,2} \\\\ Q_{3,1} & Q_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "97773379-37de-477f-b75a-973c03617aed", "metadata": {}, "source": [ "$$ Q = \\begin{bmatrix} \\vec{Q_{1}} \\\\ \\vec{Q_{2}} \\\\ \\vec{Q_{3}} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "fcb7ea30-3e2f-4ac0-b289-5d5b8369a915", "metadata": {}, "source": [ "Similarly for the $K$ matrix. $\\vec{K_{1}}$ is a feature vector for token $1$, or for the word $sky$. $\\vec{K_{2}}$ is a feature vector for token $2$, or for the word $is$. 
$\\vec{K_{3}}$ is a feature vector for token $3$, or for the word $blue$." ] }, { "cell_type": "markdown", "id": "3be5011c-5842-4b69-929c-c5478b1a2035", "metadata": {}, "source": [ "$$ K = \\begin{bmatrix} K_{1,1} & K_{1,2} \\\\ K_{2,1} & K_{2,2} \\\\ K_{3,1} & K_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "16658adc-c4d0-4c0d-b130-f1008f9865d7", "metadata": {}, "source": [ "$$ K = \\begin{bmatrix} \\vec{K_{1}} \\\\ \\vec{K_{2}} \\\\ \\vec{K_{3}} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "72e047c5-0ccc-4242-b25d-970a11f4c22c", "metadata": {}, "source": [ "$$ K^{T} = \\begin{bmatrix} K_{1,1} & K_{2,1} & K_{3,1} \\\\ K_{1,2} & K_{2,2} & K_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "de20eb10-10a8-4dc1-93a0-8620d2fe46b4", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "f2aa0726-5d59-45ac-850e-bdbac03635bf", "metadata": {}, "source": [ "Similarly, $V$ is another feature matrix." ] }, { "cell_type": "markdown", "id": "56bdba19-c34b-481e-9cf3-04c9292f07c2", "metadata": {}, "source": [ "$$ V = \\begin{bmatrix} V_{1,1} & V_{1,2} \\\\ V_{2,1} & V_{2,2} \\\\ V_{3,1} & V_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "28c20358-96dc-48e3-acf6-3b8b7f39f6b2", "metadata": {}, "source": [ "$$ V = \\begin{bmatrix} \\vec{V_{1}} \\\\ \\vec{V_{2}} \\\\ \\vec{V_{3}} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "a7f2f239-ca78-4b0c-9aec-3ec0f65370c9", "metadata": {}, "source": [ "Attention logits in the attention mechanism are given as $QK^{T}$." 
] }, { "cell_type": "markdown", "id": "269fc364-8c22-4217-8957-702600226885", "metadata": {}, "source": [ "$$ QK^{T} = \\begin{bmatrix} Q_{1,1}\\times K_{1,1} + Q_{1,2}\\times K_{1,2} &\n", "Q_{1,1}\\times K_{2,1} + Q_{1,2}\\times K_{2,2} & \n", "Q_{1,1}\\times K_{3,1} + Q_{1,2}\\times K_{3,2} \\\\ \n", "Q_{2,1}\\times K_{1,1} + Q_{2,2}\\times K_{1,2} & \n", "Q_{2,1}\\times K_{2,1} + Q_{2,2}\\times K_{2,2} & \n", "Q_{2,1}\\times K_{3,1} + Q_{2,2}\\times K_{3,2} \\\\\n", "Q_{3,1}\\times K_{1,1} + Q_{3,2}\\times K_{1,2} & \n", "Q_{3,1}\\times K_{2,1} + Q_{3,2}\\times K_{2,2} & \n", "Q_{3,1}\\times K_{3,1} + Q_{3,2}\\times K_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "00c42922-0432-412b-93ac-184cc31ecfdb", "metadata": {}, "source": [ "Equivalently," ] }, { "cell_type": "markdown", "id": "4ae4f5b9-a138-4438-b819-b25680a8f313", "metadata": {}, "source": [ "$$ QK^{T} = \\begin{bmatrix} \\vec{Q_{1}}\\cdot \\vec{K_{1}} & \n", "\\vec{Q_{1}}\\cdot \\vec{K_{2}} & \n", "\\vec{Q_{1}}\\cdot \\vec{K_{3}} \\\\ \n", "\\vec{Q_{2}}\\cdot \\vec{K_{1}} & \n", "\\vec{Q_{2}}\\cdot \\vec{K_{2}} & \n", "\\vec{Q_{2}}\\cdot \\vec{K_{3}} \\\\\n", "\\vec{Q_{3}}\\cdot \\vec{K_{1}} & \n", "\\vec{Q_{3}}\\cdot \\vec{K_{2}} & \n", "\\vec{Q_{3}}\\cdot \\vec{K_{3}} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "424af669-dfd9-4186-849d-22cf5f3c97a7", "metadata": {}, "source": [ "What does this matrix give us? Notice that the $1^{st}$ row of $QK^{T}$ consists of the dot products of the $\\vec{Q_{1}}$ feature vector with all the $K$ feature vectors. The value of the dot product indicates how similar two vectors are. Therefore, the $1^{st}$ row shows us how similar $\\vec{Q_{1}}$ is to each of the $K$ feature vectors. The $2^{nd}$ row shows us how similar $\\vec{Q_{2}}$ is to each of the $K$ feature vectors. The $3^{rd}$ row shows us how similar $\\vec{Q_{3}}$ is to each of the $K$ feature vectors. 
" ] }, { "cell_type": "markdown", "id": "c8fd5141-5b1d-407b-b82d-28536b1407e1", "metadata": {}, "source": [ "Let's obtain the attention logits using the actual values." ] }, { "cell_type": "code", "execution_count": 12, "id": "03f96b33-4adc-48c0-983f-c47f3a67ac9b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.2853, 0.0604, 0.0779],\n", " [-0.0704, 0.0281, 0.0356],\n", " [-0.0850, 0.0344, 0.0436]], grad_fn=)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attn_logits = torch.matmul(Q, K.T)\n", "attn_logits" ] }, { "attachments": {}, "cell_type": "markdown", "id": "788f8c16-6efc-4ac7-b214-b1396619a02e", "metadata": {}, "source": [ "Again, each element in the `attn_logits` matrix is a measure of how similar two $\\vec{Q}$ and $\\vec{K}$ vectors are.\n", "\n", "Each row of the `attn_logits` matrix shows how similar a query to all the keys. For example, row 1 is => (similarity between Q1 and K1, similarity between Q1 and K2, similarity between Q1 and K3)." 
] }, { "cell_type": "code", "execution_count": null, "id": "67822f3c-4f04-4475-b118-af876228d941", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "fe25f7f6-abc5-4dee-8bec-089995d38f16", "metadata": {}, "source": [ "The next step is scaling the `attn_logits` matrix by dividing it by the scaling factor, $\\sqrt{d_{k}}$, where $d_{k}$ is the size of each query and key vector.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "b4693965", "metadata": {}, "outputs": [], "source": [ "d_k = Q.size()[-1]" ] }, { "cell_type": "markdown", "id": "96dec1b8-212c-440c-b8cc-29fc565370ec", "metadata": {}, "source": [ "$$ QK^{T}/\\sqrt{d_{k}} = \\begin{bmatrix} \\frac{\\vec{Q_{1}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}} & \n", "\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{2}}}{ \\sqrt{d_{k}} } & \n", "\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{3}}}{ \\sqrt{d_{k}} } \\\\ \n", "\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{1}}}{ \\sqrt{d_{k}} } & \n", "\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{2}}}{ \\sqrt{d_{k}} } & \n", "\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{3}}}{ \\sqrt{d_{k}} } \\\\\n", "\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{1}}}{ \\sqrt{d_{k}} } & \n", "\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{2}}}{ \\sqrt{d_{k}} } & \n", "\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{3}}}{ \\sqrt{d_{k}} } \\end{bmatrix}$$" ] }, { "cell_type": "code", "execution_count": 14, "id": "28f7d892-414f-4d6a-b189-c380fb49c5ca", "metadata": {}, "outputs": [], "source": [ "scaled_attention_logits = torch.matmul(Q, K.T) / math.sqrt(d_k)" ] }, { "cell_type": "code", "execution_count": 15, "id": "61cc140c-e29b-4449-9ffb-1eda8a23c0f5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.2017, 0.0427, 0.0551],\n", " [-0.0498, 0.0199, 0.0252],\n", " [-0.0601, 0.0243, 0.0309]], grad_fn=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaled_attention_logits" ] }, { "cell_type": "markdown", "id": "0b433e76-5f77-4695-a90f-031509b19c78", "metadata": {}, "source": [ "Next, we apply the softmax operation to each row. 
This normalizes the values of each row so that the sum of these values is 1." ] }, { "cell_type": "markdown", "id": "be3e6192-6540-48a8-aec8-a0c60d88e86e", "metadata": {}, "source": [ "$$ softmax(QK^{T}/\\sqrt{d_{k}}) = \\begin{bmatrix} \\frac{\\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{1}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } \\\\ \n", "\\frac{\\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{2}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } \\\\\n", 
"\\frac{\\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } & \n", "\\frac{\\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}})}{ \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{1}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{2}}}{\\sqrt{d_{k}}}) + \\exp(\\frac{\\vec{Q_{3}}\\cdot \\vec{K_{3}}}{\\sqrt{d_{k}}}) } \\end{bmatrix}$$" ] }, { "cell_type": "code", "execution_count": 16, "id": "0fd5a91f-5c31-4661-8717-1cdd3e1671cc", "metadata": {}, "outputs": [], "source": [ "attention_weights = F.softmax(scaled_attention_logits, dim=-1)" ] }, { "cell_type": "code", "execution_count": 17, "id": "61cb7e78-2659-4051-84c9-8c5423792360", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.2801, 0.3577, 0.3622],\n", " [0.3175, 0.3404, 0.3422],\n", " [0.3141, 0.3418, 0.3441]], grad_fn=)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attention_weights" ] }, { "cell_type": "markdown", "id": "63c54673-add7-4173-877f-c25dccc60ffa", "metadata": {}, "source": [ "Note that the values in each row add up to 1." 
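, "\n", "\n", "As a worked example, here is the softmax of row 1 computed by hand from the rounded scaled logits printed above, $(-0.2017, 0.0427, 0.0551)$. The numbers are rounded to three decimals, so this is only an approximate check.\n", "\n", "$$ e^{-0.2017} \\approx 0.817, \\quad e^{0.0427} \\approx 1.044, \\quad e^{0.0551} \\approx 1.057, \\quad 0.817 + 1.044 + 1.057 = 2.918 $$\n", "\n", "$$ \\left( \\frac{0.817}{2.918}, \\frac{1.044}{2.918}, \\frac{1.057}{2.918} \\right) \\approx (0.280, 0.358, 0.362) $$"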
] }, { "cell_type": "code", "execution_count": 18, "id": "e20f93e6-680c-4a83-bb39-6adf5dcd6976", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1., grad_fn=)\n", "tensor(1., grad_fn=)\n", "tensor(1.0000, grad_fn=)\n" ] } ], "source": [ "print(attention_weights[0].sum())\n", "print(attention_weights[1].sum())\n", "print(attention_weights[2].sum())" ] }, { "cell_type": "markdown", "id": "f73de8bd-09cf-4edd-b893-ee44a51ca1d0", "metadata": {}, "source": [ "Therefore, the elements of a given row can be considered as normalized weights." ] }, { "cell_type": "code", "execution_count": null, "id": "c6731b50-f544-483d-bfd8-c6be7d4f08fa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "d34ac2d5-0e68-436f-b2bf-758291998ae1", "metadata": {}, "source": [ "The output values are obtained by multiplying the attention weights with $V$." ] }, { "cell_type": "code", "execution_count": 19, "id": "9af3135b-4a41-47cf-b1e7-970dc33d9907", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.1460, 0.1802],\n", " [0.1543, 0.1757],\n", " [0.1535, 0.1761]], grad_fn=)\n" ] } ], "source": [ "O = attention_weights @ V\n", "\n", "print(O)" ] }, { "cell_type": "markdown", "id": "b56f6bec-ed0c-4713-8f9d-9c8739ea8944", "metadata": {}, "source": [ "$$ O = \\begin{bmatrix} O_{1,1} & O_{1,2} \\\\ O_{2,1} & O_{2,2} \\\\ O_{3,1} & O_{3,2} \\end{bmatrix}$$" ] }, { "cell_type": "markdown", "id": "604c6f37-c00f-40af-b602-386399b09a65", "metadata": {}, "source": [ "Let's try to understand the meaning of the output values." ] }, { "cell_type": "markdown", "id": "362e10b0-4f34-4148-ab0e-96a427dd6502", "metadata": {}, "source": [ "Below is how the elements of the output matrix $O$, starting with $O_{1,1}$, are obtained." 
] }, { "cell_type": "markdown", "id": "2f9778db-39ea-4994-84ca-317e7f000aa4", "metadata": {}, "source": [ "$O_{1,1}$ = $\\vec{Q_{1}}$,$\\vec{K_{1}}$ similarity * $V_{1,1}$ + \n", "$\\vec{Q_{1}}$,$\\vec{K_{2}}$ similarity * $V_{2,1}$ +\n", "$\\vec{Q_{1}}$,$\\vec{K_{3}}$ similarity * $V_{3,1}$" ] }, { "cell_type": "markdown", "id": "297333d2-43a1-4613-9207-8e7a80719c42", "metadata": {}, "source": [ "$O_{1,2}$ = $\\vec{Q_{1}}$,$\\vec{K_{1}}$ similarity * $V_{1,2}$ + \n", "$\\vec{Q_{1}}$,$\\vec{K_{2}}$ similarity * $V_{2,2}$ +\n", "$\\vec{Q_{1}}$,$\\vec{K_{3}}$ similarity * $V_{3,2}$" ] }, { "cell_type": "markdown", "id": "7dae8bfd-dd4a-4f14-88cb-8c83e0c153e0", "metadata": {}, "source": [ "$O_{2,1}$ = $\\vec{Q_{2}}$,$\\vec{K_{1}}$ similarity * $V_{1,1}$ + \n", "$\\vec{Q_{2}}$,$\\vec{K_{2}}$ similarity * $V_{2,1}$ +\n", "$\\vec{Q_{2}}$,$\\vec{K_{3}}$ similarity * $V_{3,1}$" ] }, { "cell_type": "markdown", "id": "bdf7e8ed-45a4-4523-9dd5-7b6b30b54a11", "metadata": {}, "source": [ "$O_{2,2}$ = $\\vec{Q_{2}}$,$\\vec{K_{1}}$ similarity * $V_{1,2}$ + \n", "$\\vec{Q_{2}}$,$\\vec{K_{2}}$ similarity * $V_{2,2}$ +\n", "$\\vec{Q_{2}}$,$\\vec{K_{3}}$ similarity * $V_{3,2}$" ] }, { "cell_type": "markdown", "id": "964070a8-7aca-4d30-b20b-b227ef6be988", "metadata": {}, "source": [ "Think of this as obtaining weighted feature values. The weights are the scaled, softmax-normalized dot products $\\vec{Q_{i}}\\cdot\\vec{K_{j}}$ (written as 'similarity' above). $\\vec{V_{j}}$ are the input features. $\\vec{O_{i}}$ are the output features." ] }, { "cell_type": "code", "execution_count": null, "id": "05a2bffa-4fac-4e73-9e41-9fab137604ee", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "e042a9fc-6519-4d97-8790-ca98558b2c46", "metadata": {}, "source": [ "Let's try to demystify this further." 
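, "\n", "\n", "As a rough numeric check, take the 'similarity' terms to be the first row of the `attention_weights` matrix printed earlier and plug in the rounded values of $V$. This reproduces the first row of $O$ up to rounding.\n", "\n", "$$ O_{1,1} \\approx 0.2801 \\times 0.3048 + 0.3577 \\times 0.0763 + 0.3622 \\times 0.0921 \\approx 0.1460 $$\n", "\n", "$$ O_{1,2} \\approx 0.2801 \\times 0.0934 + 0.3577 \\times 0.1909 + 0.3622 \\times 0.2368 \\approx 0.1802 $$"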
] }, { "cell_type": "markdown", "id": "8d6754b7-7967-45fa-b777-6697632659a1", "metadata": {}, "source": [ "$O_{1,1}$ = ($\\vec{Q_{1}}$,$\\vec{K_{1}}$ similarity $\\times$ $1^{st}$ feature of token 1) + \n", "($\\vec{Q_{1}}$,$\\vec{K_{2}}$ similarity $\\times$ $1^{st}$ feature of token 2) +\n", "($\\vec{Q_{1}}$,$\\vec{K_{3}}$ similarity $\\times$ $1^{st}$ feature of token 3)" ] }, { "cell_type": "markdown", "id": "f39b9865-e22d-4cf9-8d06-5e0095cad7ee", "metadata": {}, "source": [ "$O_{2,1}$ = ($\\vec{Q_{2}}$,$\\vec{K_{1}}$ similarity $\\times$ $1^{st}$ feature of token 1) + \n", "($\\vec{Q_{2}}$,$\\vec{K_{2}}$ similarity $\\times$ $1^{st}$ feature of token 2) +\n", "($\\vec{Q_{2}}$,$\\vec{K_{3}}$ similarity $\\times$ $1^{st}$ feature of token 3)" ] }, { "cell_type": "markdown", "id": "5d04f91d-e508-4699-864f-f2c261a828af", "metadata": {}, "source": [ "Notice that the index of the token is equal to the index of the key in each \"$\\vec{Q}$, $\\vec{K}$ similarity $\\times$ feature of token\" term above." ] }, { "cell_type": "markdown", "id": "79a383df-2275-418f-b56a-18303295143c", "metadata": {}, "source": [ "Because the weights in each row add up to 1, $\\vec{O_{1}}$ is a weighted average of $\\vec{V_{1}}$, $\\vec{V_{2}}$ and $\\vec{V_{3}}$. It is dominated by the value vectors whose keys, $\\vec{K_{1}}$, $\\vec{K_{2}}$ or $\\vec{K_{3}}$, have a large dot product with $\\vec{Q_{1}}$.\n", "\n", "Similarly, $\\vec{O_{2}}$ is dominated by the value vectors whose keys have a large dot product with $\\vec{Q_{2}}$." ] }, { "cell_type": "code", "execution_count": null, "id": "db3996c9-a64b-4678-a956-a7d8d49bb179", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "d44828b4-5973-4da4-aaff-7997671f1c09", "metadata": {}, "source": [ "$\\vec{Q_{i}}$ can be thought of as containing information about what each token is looking for. $\\vec{K_{j}}$ is about what information each token has. 
We calculate dot products between $\\vec{Q_{i}}$ and $\\vec{K_{j}}$ to figure out how closely related they are, and use them (after scaling and softmax) as weights to multiply our feature vectors $\\vec{V_{j}}$, which gives the updated feature vectors $\\vec{O_{i}}$." ] }, { "cell_type": "code", "execution_count": null, "id": "fcb15fb4-4b54-46d7-8af1-ece8701be405", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "a1b610e7-15b9-49a2-a225-d0a5837a0c31", "metadata": {}, "source": [ "Using the words of our sequence as the indices," ] }, { "cell_type": "markdown", "id": "68acfa7b-900d-4dd8-967b-0d99e236992d", "metadata": {}, "source": [ "$O_{sky,1}$ = $\\vec{Q_{sky}}$,$\\vec{K_{sky}}$ similarity * $V_{sky,1}$ + \n", "$\\vec{Q_{sky}}$,$\\vec{K_{is}}$ similarity * $V_{is,1}$ +\n", "$\\vec{Q_{sky}}$,$\\vec{K_{blue}}$ similarity * $V_{blue,1}$" ] }, { "cell_type": "markdown", "id": "6c30fb18-1cf0-4955-8eb4-e37fec3c5863", "metadata": {}, "source": [ "$O_{sky,2}$ = $\\vec{Q_{sky}}$,$\\vec{K_{sky}}$ similarity * $V_{sky,2}$ + \n", "$\\vec{Q_{sky}}$,$\\vec{K_{is}}$ similarity * $V_{is,2}$ +\n", "$\\vec{Q_{sky}}$,$\\vec{K_{blue}}$ similarity * $V_{blue,2}$" ] }, { "cell_type": "markdown", "id": "d0cd3337-c3ce-4075-86e3-e7dbd51c1125", "metadata": {}, "source": [ "$O_{is,1}$ = $\\vec{Q_{is}}$,$\\vec{K_{sky}}$ similarity * $V_{sky,1}$ + \n", "$\\vec{Q_{is}}$,$\\vec{K_{is}}$ similarity * $V_{is,1}$ +\n", "$\\vec{Q_{is}}$,$\\vec{K_{blue}}$ similarity * $V_{blue,1}$" ] }, { "cell_type": "markdown", "id": "1a9cc69e-7c51-4d29-a230-5c96a768f67d", "metadata": {}, "source": [ "$O_{is,2}$ = $\\vec{Q_{is}}$,$\\vec{K_{sky}}$ similarity * $V_{sky,2}$ + \n", "$\\vec{Q_{is}}$,$\\vec{K_{is}}$ similarity * $V_{is,2}$ +\n", "$\\vec{Q_{is}}$,$\\vec{K_{blue}}$ similarity * $V_{blue,2}$" ] }, { "cell_type": "code", "execution_count": null, "id": "7f6286ad-6e13-4dba-a376-127a10418b37", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "38c84e9d-a114-4440-a0df-7b8a01e9ff20", "metadata": {}, "source": [ "---\n", "\n", "[![Ask 
questions](https://img.shields.io/static/v1.svg?logo=star&label=❔&message=Ask%20Questions&color=9cf)](https://github.com/gihanpanapitiya/deep_learning/issues) If you have any questions, please add an issue on GitHub. \n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "1dbdc51a-3f33-42b8-820e-d7a36caedb7f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }