Self-Attention - Another Implementation
[1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
[2]:
sequence = torch.tensor([[1, 2, 3]])  # one batch of three token ids
[3]:
embed_dim = 2  # size of each token embedding
[4]:
# vocabulary of 4 tokens; the third argument is padding_idx=0
embed = torch.nn.Embedding(4, embed_dim, padding_idx=0)
[5]:
embedded_tokens = embed(sequence)  # shape (1, 3, 2): batch, tokens, embed_dim
[6]:
embedded_tokens
[6]:
tensor([[[-1.2221, -0.5735],
         [-1.1988, -0.3782],
         [-1.4229,  1.7899]]], grad_fn=<EmbeddingBackward0>)
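These numbers come from PyTorch's random initialization, so they will differ on every run. To make the walkthrough reproducible, seed the RNG before constructing the layers (42 is an arbitrary choice):
[ ]:
# seed before creating nn.Embedding / nn.Linear so their randomly
# initialized weights are the same on every run
torch.manual_seed(42)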
[7]:
# input and output sizes of the Q/K/V projection
input_dim = 2
embed_dim = 2
[8]:
# one linear layer produces queries, keys and values at once
QKV_transformation = nn.Linear(input_dim, 3 * embed_dim)
[9]:
QKV = QKV_transformation(embedded_tokens)  # shape (1, 3, 6)
[10]:
QKV
[10]:
tensor([[[ 0.2566,  0.5266,  1.0491, -0.3373, -0.3685, -0.5847],
         [ 0.1248,  0.4105,  1.0927, -0.2389, -0.2934, -0.4745],
         [-1.1946, -0.8623,  1.7932,  1.0638,  0.6026,  0.7792]]],
       grad_fn=<ViewBackward0>)
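Each token now carries a 6-dimensional vector: cell [11] will read the first two entries as its query, the next two as its key, and the last two as its value. A quick shape check:
[ ]:
QKV.shape  # torch.Size([1, 3, 6]) = (batch, tokens, 3 * embed_dim)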
[11]:
# split the last dimension into three (1, 3, 2) tensors
Q, K, V = QKV.chunk(3, dim=-1)
[12]:
Q
[12]:
tensor([[[ 0.2566,  0.5266],
         [ 0.1248,  0.4105],
         [-1.1946, -0.8623]]], grad_fn=<SplitBackward0>)
[13]:
K
[13]:
tensor([[[ 1.0491, -0.3373],
         [ 1.0927, -0.2389],
         [ 1.7932,  1.0638]]], grad_fn=<SplitBackward0>)
[14]:
V
[14]:
tensor([[[-0.3685, -0.5847],
         [-0.2934, -0.4745],
         [ 0.6026,  0.7792]]], grad_fn=<SplitBackward0>)
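The next cells index with Q[0] and K[0] to drop the batch dimension so that the 2-D transpose .T applies. A minimal batched sketch transposes only the last two dimensions instead, so it works for any batch size:
[ ]:
# batched equivalent of cell [15]: transpose only the last two dims of K
attn_logits_batched = torch.matmul(Q, K.transpose(-2, -1))  # (1, 3, 3)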
[15]:
# dot product of every query with every key; [0] drops the batch
# dimension so the 2-D transpose .T applies
attn_logits = torch.matmul(Q[0], K[0].T)
[16]:
attn_logits
[16]:
tensor([[ 0.0916,  0.1546,  1.0204],
        [-0.0075,  0.0384,  0.6606],
        [-0.9624, -1.0994, -3.0595]], grad_fn=<MmBackward0>)
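Entry (i, j) of this matrix is the dot product between query i and key j; the top-left entry, for example, can be recomputed directly:
[ ]:
# query 0 dotted with key 0 gives attn_logits[0, 0] (about 0.0916)
torch.dot(Q[0, 0], K[0, 0])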
[17]:
d_k = Q[0].size()[-1]  # dimensionality of each query/key vector (here 2)
[18]:
# scale the logits by sqrt(d_k)
scaled_attention_logits = torch.matmul(Q[0], K[0].T) / math.sqrt(d_k)
[19]:
scaled_attention_logits
[19]:
tensor([[ 0.0648,  0.1093,  0.7215],
        [-0.0053,  0.0271,  0.4671],
        [-0.6805, -0.7774, -2.1634]], grad_fn=<DivBackward0>)
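Dividing by sqrt(d_k) keeps the variance of the logits roughly independent of the key dimension, so the softmax in the next cell does not saturate. Together these cells implement the standard scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$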
[20]:
# normalize each row into attention weights
attention_weights = F.softmax(scaled_attention_logits, dim=-1)
[21]:
attention_weights
[21]:
tensor([[0.2516, 0.2631, 0.4853],
        [0.2750, 0.2840, 0.4410],
        [0.4685, 0.4252, 0.1063]], grad_fn=<SoftmaxBackward0>)
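Each row is now a probability distribution over the three tokens; a quick sanity check (exact up to floating-point rounding):
[ ]:
# softmax normalizes along dim=-1, so every row sums to 1
attention_weights.sum(dim=-1)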
[22]:
attention_weights @ V[0]  # each output row is a weighted sum of the value vectors
[22]:
tensor([[ 0.1225,  0.1061],
        [ 0.0810,  0.0481],
        [-0.2333, -0.3928]], grad_fn=<MmBackward0>)
[23]:
torch.matmul(attention_weights, V[0])  # same computation via torch.matmul
[23]:
tensor([[ 0.1225,  0.1061],
        [ 0.0810,  0.0481],
        [-0.2333, -0.3928]], grad_fn=<MmBackward0>)
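All of the steps above compose into one function. The sketch below stays batched (it avoids the [0] indexing) and reuses the projection from cell [8], so apart from the extra batch dimension it reproduces the output of cells [22] and [23]:
[ ]:
def self_attention(x, qkv_proj):
    # project each token to concatenated query/key/value and split
    Q, K, V = qkv_proj(x).chunk(3, dim=-1)             # each (B, T, d)
    d_k = Q.size(-1)
    # scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    logits = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, T, T)
    weights = F.softmax(logits, dim=-1)
    return weights @ V                                 # (B, T, d)

self_attention(embedded_tokens, QKV_transformation)

PyTorch 2.x also ships a fused kernel for the same computation; with no mask and no dropout it evaluates exactly the formula above:
[ ]:
# fused equivalent, assuming PyTorch >= 2.0
F.scaled_dot_product_attention(Q, K, V)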