Self-Attention: Another Implementation
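This notebook walks through scaled dot-product self-attention on a toy three-token sequence, computing Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V one tensor operation at a time.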

[1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
[2]:
sequence = torch.tensor([[1, 2, 3]])  # one "sentence" of three token IDs (batch size 1)
[3]:
embed_dim = 2  # dimensionality of the token embeddings
[4]:
embed = nn.Embedding(4, embed_dim, padding_idx=0)  # vocabulary of 4 token IDs; index 0 is the padding index
[5]:
embedded_tokens = embed(sequence)  # shape: (batch, seq_len, embed_dim) = (1, 3, 2)
[6]:
embedded_tokens
[6]:
tensor([[[-1.2221, -0.5735],
         [-1.1988, -0.3782],
         [-1.4229,  1.7899]]], grad_fn=<EmbeddingBackward0>)
[7]:
input_dim = 2  # dimensionality of the incoming embeddings
embed_dim = 2  # dimensionality of each of Q, K and V
[8]:
QKV_transformation = nn.Linear(input_dim, 3 * embed_dim)  # one projection produces Q, K and V stacked along the last dimension
[9]:
QKV = QKV_transformation(embedded_tokens)  # shape: (1, 3, 3 * embed_dim)
[10]:
QKV
[10]:
tensor([[[ 0.2566,  0.5266,  1.0491, -0.3373, -0.3685, -0.5847],
         [ 0.1248,  0.4105,  1.0927, -0.2389, -0.2934, -0.4745],
         [-1.1946, -0.8623,  1.7932,  1.0638,  0.6026,  0.7792]]],
       grad_fn=<ViewBackward0>)
[11]:
Q, K, V = QKV.chunk(3, dim=-1)  # split the last dimension into query, key and value tensors
[12]:
Q
[12]:
tensor([[[ 0.2566,  0.5266],
         [ 0.1248,  0.4105],
         [-1.1946, -0.8623]]], grad_fn=<SplitBackward0>)
[13]:
K
[13]:
tensor([[[ 1.0491, -0.3373],
         [ 1.0927, -0.2389],
         [ 1.7932,  1.0638]]], grad_fn=<SplitBackward0>)
[14]:
V
[14]:
tensor([[[-0.3685, -0.5847],
         [-0.2934, -0.4745],
         [ 0.6026,  0.7792]]], grad_fn=<SplitBackward0>)
[15]:
attn_logits = torch.matmul(Q[0], K[0].T)  # drop the batch dimension with [0]; raw dot-product scores, shape (3, 3)
[16]:
attn_logits
[16]:
tensor([[ 0.0916,  0.1546,  1.0204],
        [-0.0075,  0.0384,  0.6606],
        [-0.9624, -1.0994, -3.0595]], grad_fn=<MmBackward0>)
[17]:
d_k = Q.size(-1)  # dimensionality of the keys/queries, used for scaling
[18]:
scaled_attention_logits = attn_logits / math.sqrt(d_k)  # scale by sqrt(d_k) to keep the softmax well-behaved
[19]:
scaled_attention_logits
[19]:
tensor([[ 0.0648,  0.1093,  0.7215],
        [-0.0053,  0.0271,  0.4671],
        [-0.6805, -0.7774, -2.1634]], grad_fn=<DivBackward0>)
[20]:
attention_weights = F.softmax(scaled_attention_logits, dim=-1)  # softmax over the key positions: each row sums to 1
[21]:
attention_weights
[21]:
tensor([[0.2516, 0.2631, 0.4853],
        [0.2750, 0.2840, 0.4410],
        [0.4685, 0.4252, 0.1063]], grad_fn=<SoftmaxBackward0>)
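Each row of the attention weights is a probability distribution over the three key positions, so the rows should sum to 1 (a quick check, left unexecuted here):
[ ]:
attention_weights.sum(dim=-1)  # each row sums to 1, up to floating-point error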
[22]:
attention_weights @ V[0]  # context vectors: attention-weighted sum of the value rows
[22]:
tensor([[ 0.1225,  0.1061],
        [ 0.0810,  0.0481],
        [-0.2333, -0.3928]], grad_fn=<MmBackward0>)
[23]:
torch.matmul(attention_weights, V[0])  # the same computation written with torch.matmul
[23]:
tensor([[ 0.1225,  0.1061],
        [ 0.0810,  0.0481],
        [-0.2333, -0.3928]], grad_fn=<MmBackward0>)
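As a sanity check, the same context vectors can be produced in one call with F.scaled_dot_product_attention (assuming PyTorch 2.0 or later, which provides it); it computes softmax(Q K^T / sqrt(d_k)) V directly and should match the result above up to floating-point tolerance:
[ ]:
# Assumes PyTorch >= 2.0. Q, K, V keep their batch dimension of 1 here,
# so the output has shape (1, 3, embed_dim).
F.scaled_dot_product_attention(Q, K, V)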
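Putting the steps together, here is a minimal single-head self-attention module mirroring the cells above. It is a sketch: the class name and its internals are illustrative rather than taken from a library, and a fresh instance has its own randomly initialised weights, so its outputs will differ from the values printed earlier.
[ ]:
class SingleHeadSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention, as derived step by step above."""

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # One linear layer produces Q, K and V stacked along the last dimension.
        self.qkv = nn.Linear(input_dim, 3 * embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        Q, K, V = self.qkv(x).chunk(3, dim=-1)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, V)
[ ]:
attn = SingleHeadSelfAttention(input_dim, embed_dim)
attn(embedded_tokens).shape  # torch.Size([1, 3, embed_dim])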