# %%
import torch
import torch.nn.functional as F
import math
import numpy as np
# %%
def scaled_dot_product_attention(Q, K, V, dk=4):
    ## matmul Q with K transposed: (n_q, d) x (d, n_k) -> (n_q, n_k)
    QK = torch.matmul(Q, K.transpose(-2, -1))
    ## scale QK by the sqrt of dk
    matmul_scaled = QK / math.sqrt(dk)
    ## softmax over the key dimension gives the attention weights
    attention_weights = F.softmax(matmul_scaled, dim=-1)
    ## matmul attention_weights by V
    output = torch.matmul(attention_weights, V)
    return output, attention_weights
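# %%
# Sanity check (added, not part of the original file): a minimal sketch,
# assuming PyTorch >= 2.1 (for the `scale` argument), comparing the function
# above against the built-in F.scaled_dot_product_attention. `scale` is set
# to 1/sqrt(4) so it matches the dk=4 default used here.
_q = torch.rand(1, 2, 4)
_k = torch.rand(1, 5, 4)
_v = torch.rand(1, 5, 4)
_ours, _ = scaled_dot_product_attention(_q, _k, _v, dk=4)
_ref = F.scaled_dot_product_attention(_q, _k, _v, scale=1 / math.sqrt(4))
assert torch.allclose(_ours, _ref, atol=1e-6)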
# %%
def print_attention(Q, K, V, n_digits=3):
    temp_out, temp_attn = scaled_dot_product_attention(Q, K, V)
    temp_out, temp_attn = temp_out.numpy(), temp_attn.numpy()
    print('Attention weights are:')
    print(np.round(temp_attn, n_digits))
    print()
    print('Output is:')
    print(np.around(temp_out, n_digits))
# %%
temp_k = torch.Tensor([[10, 0, 0],
                       [0, 10, 0],
                       [0, 0, 10],
                       [0, 0, 10]])  # (4, 3)
temp_v = torch.Tensor([[1, 0, 1],
                       [10, 0, 2],
                       [100, 5, 0],
                       [1000, 6, 0]])  # (4, 3)
# %%
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = torch.Tensor([[0, 10, 0]]) # (1, 3)
print_attention(temp_q, temp_k, temp_v)
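# %%
# Illustration (added): the raw scaled scores behind the previous cell.
# Q @ K^T / sqrt(dk) = [0, 50, 0, 0], so softmax puts essentially all of
# the weight on the second key and the output is the second value row.
scores = torch.matmul(temp_q, temp_k.transpose(-2, -1)) / math.sqrt(4)
print(scores)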
# %%
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = torch.Tensor([[0, 0, 10]]) # (1, 3)
print_attention(temp_q, temp_k, temp_v)
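# %%
# Illustration (added): the two identical keys each receive weight ~0.5,
# so the output is the elementwise mean of the third and fourth value rows.
print((temp_v[2] + temp_v[3]) / 2)  # expect [550., 5.5, 0.]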
# %%
# This query aligns equally with the first and second key,
# so their values get averaged.
temp_q = torch.Tensor([[10, 10, 0]]) # (1, 3)
print_attention(temp_q, temp_k, temp_v)
# %%
temp_q = torch.Tensor([[0, 10, 0], [0, 0, 10], [10, 10, 0]]) # (3, 3)
print_attention(temp_q, temp_k, temp_v)
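# %%
# Sanity check (added): each row of the attention weights is a distribution
# over the four keys, so every row sums to 1, and the output has one row per
# query with the value dimension.
out, attn = scaled_dot_product_attention(temp_q, temp_k, temp_v)
print(attn.sum(dim=-1))  # expect tensor([1., 1., 1.])
print(out.shape)         # expect torch.Size([3, 3])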