This vignette describes how to implement the attention mechanism - which forms the basis of transformers - in the R language.
The code is translated from the Python original by Stefania Cristina (University of Malta) in her post The Attention Mechanism from Scratch.
We begin by generating encoder representations of four different words.
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
Next, we stack the word embeddings into a single array (in this case a matrix), which we call words.
# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)
Let’s see what this looks like.
print(words)
#> [,1] [,2] [,3]
#> [1,] 1 0 0
#> [2,] 0 1 0
#> [3,] 1 1 0
#> [4,] 0 0 1
Next, we initialize the weight matrices with random integers. Taking the floor of uniform draws between 0 and 3 gives the values 0, 1, or 2.
# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
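To see what `floor(runif(...))` produces, note that `runif()` never returns the endpoints `min` or `max`, so flooring draws between 0 and 3 can only give 0, 1, or 2. The sketch below checks this on a larger sample and shows `sample()` as one alternative way to draw such integers directly. The names `draws` and `W_alt` are illustrative and not part of the vignette, and `sample()` consumes random numbers differently, so it would not reproduce the matrices above for the same seed.

```r
set.seed(0)

# floor() of uniform draws between 0 and 3 can only yield 0, 1 or 2
draws = floor(runif(1000, min=0, max=3))
print(sort(unique(draws)))

# hypothetical alternative: draw the integers directly with sample()
W_alt = matrix(sample(0:2, size=9, replace=TRUE), nrow=3, ncol=3)
print(dim(W_alt))
```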
Next, we generate the Queries (Q), Keys (K), and Values (V) by multiplying the word matrix with each weight matrix. The %*% operator performs matrix multiplication. You can view the R help page using help('%*%') (or consult the online manual An Introduction to R).
# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
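To make the role of `%*%` concrete, the following sketch uses a small hypothetical weight matrix `W` (not the random ones above). It illustrates two points: multiplying a 4 x 3 matrix by a 3 x 3 matrix again gives a 4 x 3 matrix, and a one-hot row such as `c(1,0,0)` simply picks out the corresponding row of `W`.

```r
# the four word representations used in this vignette
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))

# hypothetical weight matrix, filled row by row for readability
W = matrix(1:9, nrow=3, ncol=3, byrow=TRUE)

P = words %*% W
print(dim(P))   # one row per word, one column per projected feature

# a one-hot word selects a row of W; word 3 is the sum of rows 1 and 2
print(P[1, ])
print(P[3, ])
```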
Following this, we score the Query (Q) vectors against the Key (K) vectors, which are transposed for the multiplication using t() (see help('t') for more info).
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
#> [,1] [,2] [,3] [,4]
#> [1,] 6 4 10 5
#> [2,] 4 6 10 6
#> [3,] 10 10 20 11
#> [4,] 3 1 4 2
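Each entry of the score matrix is the dot product of one query vector with one key vector. The sketch below illustrates this with small hypothetical `Q` and `K` matrices (two tokens, three features each, not the values computed above), and also shows that base R's `tcrossprod()` gives the same product without an explicit `t()`.

```r
# hypothetical query and key matrices (2 tokens, 3 features each)
Q = rbind(c(1,0,2), c(0,1,1))
K = rbind(c(2,1,0), c(1,1,1))

scores = Q %*% t(K)

# entry [i,j] is the dot product of query i with key j
manual = sum(Q[1, ] * K[2, ])
print(scores[1, 2] == manual)

# tcrossprod(Q, K) computes Q %*% t(K) directly
print(all.equal(scores, tcrossprod(Q, K)))
```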
We now generate the weights matrix.
weights = ComputeWeights(scores)
Let’s have a look at the weights matrix.
print(weights)
#> [,1] [,2] [,3] [,4]
#> [1,] 0.10679806 0.03928881 0.7891368 0.06477630
#> [2,] 0.03770440 0.10249120 0.7573132 0.10249120
#> [3,] 0.00657627 0.00657627 0.9760050 0.01084244
#> [4,] 0.27600434 0.10153632 0.4550542 0.16740510
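`ComputeWeights()` is used above without its definition being shown. The printed weights are consistent with a row-wise softmax applied to the scores after dividing them by the square root of the number of columns of `scores` (here `sqrt(4) = 2`). The sketch below is a stand-in under that assumption, not the function's actual source; `row_softmax` and `compute_weights_sketch` are hypothetical names.

```r
# scores as printed above
scores = matrix(c( 6,  4, 10,  5,
                   4,  6, 10,  6,
                  10, 10, 20, 11,
                   3,  1,  4,  2), nrow=4, byrow=TRUE)

# row-wise softmax; subtracting the row maximum avoids overflow in exp()
row_softmax = function(x) {
  e = exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# assumed scaling: divide by the square root of the number of columns
compute_weights_sketch = function(scores) {
  row_softmax(scores / sqrt(ncol(scores)))
}

weights = compute_weights_sketch(scores)
print(round(weights, 4))
print(rowSums(weights))   # each row sums to 1
```

Because every row sums to 1, each row of the weights matrix can be read as a distribution over the four words.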
Finally, we compute the attention as a weighted sum of the value vectors (which are combined in the matrix V).
# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
Now we can view the results using:
print(attention)
#> [,1] [,2] [,3]
#> [1,] 2.749848 1.856646 0.06477630
#> [2,] 2.654822 1.692526 0.10249120
#> [3,] 2.969429 1.976005 0.01084244
#> [4,] 2.353518 1.629522 0.16740510
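The steps of this vignette can be collected into a single function. The sketch below does so using only base R; `simple_attention` and `row_softmax` are hypothetical names, and the scaling by `sqrt(ncol(scores))` is an assumption chosen because it matches the weights printed above, not a confirmed description of `ComputeWeights()`.

```r
# row-wise softmax with the row maximum subtracted for numerical stability
row_softmax = function(x) {
  e = exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# hypothetical wrapper around the steps shown in this vignette
simple_attention = function(words, W_Q, W_K, W_V) {
  Q = words %*% W_Q          # queries
  K = words %*% W_K          # keys
  V = words %*% W_V          # values
  scores = Q %*% t(K)        # query-key dot products
  # assumed scaling before the softmax
  weights = row_softmax(scores / sqrt(ncol(scores)))
  weights %*% V              # weighted sum of the value vectors
}

# same inputs as in the steps above
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

attention = simple_attention(words, W_Q, W_K, W_V)
print(attention)
```

With R's default random number generator and the same seed, this should give the same attention matrix as the step-by-step version.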
After working through this, have a look at the Complete Self-Attention from Scratch vignette.