First, we need to install the blurr module, which provides the
Hugging Face Transformers integration.
# Install the blurr Python package (published on PyPI as "ohmeow-blurr")
# into the Python environment used by reticulate, via pip.
reticulate::py_install("ohmeow-blurr", pip = TRUE)
Get the dataset from the link:
library(fastai)
library(magrittr)
library(zeallot)
= data.table::fread('https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/squad_sample.csv') squad_df
Then load a pretrained BERT model for question answering from the
transformers library:
# Hugging Face checkpoint: BERT-large (uncased, whole-word masking)
# already fine-tuned on SQuAD.
pretrained_model_name <- "bert-large-uncased-whole-word-masking-finetuned-squad"
hf_model_cls <- transformers$BertForQuestionAnswering

# get_hf_objects() returns four objects, destructured here with zeallot's %<-%.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, model_cls = hf_model_cls)

# Bind the architecture and tokenizer into blurr's SQuAD preprocessing
# function, so it can be applied row by row below.
preprocess <- partial(
  pre_process_squad(),
  hf_arch = hf_arch,
  hf_tokenizer = hf_tokenizer
)
Prepare the dataset for fastai:
# Run blurr's SQuAD preprocessing over every row, then convert back to a
# data.table so the data.table filtering syntax below works.
squad_df <- data.table::as.data.table(squad_df %>% py_apply(preprocess))

# Maximum tokenized sequence length to keep.
max_seq_len <- 128

# Preview the preprocessed data.
tibble::tibble(squad_df)

# Flatten the columns at positions 8 and 10-12 into plain atomic vectors.
# NOTE(review): selected by position; presumably these are the columns added
# by preprocessing (e.g. tokenized_input_len, tok_answer_start,
# tok_answer_end) -- confirm and prefer selecting by name.
squad_df[, c(8, 10:12)] <- lapply(
  squad_df[, c(8, 10:12)],
  function(x) unlist(as.vector(x))
)

# Keep answerable questions whose tokenized input fits in max_seq_len.
squad_df <- squad_df[is_impossible == FALSE & tokenized_input_len < max_seq_len]

# Class labels for the start/end targets: one class per token position.
# NOTE(review): possible off-by-one. Token positions look 0-based ([CLS] sits
# at position 0 in the sample batch), but this vocab starts at 1, so
# Categorize presumably encodes position p as class index p - 1. The encoded
# starts in the sample batch appear to be one below the answer's actual token.
# A 0-based vocab, 0:(max_seq_len - 1), looks intended. However, the index
# decoding in bert_answer would have to change with it, so confirm first.
vocab <- seq_len(max_seq_len)
Create the dataloaders. But first, create the getters (how we will pick our columns):
# Truncate only the context segment. get_x puts the context second when the
# tokenizer pads on the right and first otherwise. `padding_side` is a single
# string, so a plain if/else is used instead of the vectorised ifelse().
trunc_strat <- if (hf_tokenizer$padding_side == "right") {
  "only_second"
} else {
  "only_first"
}

# Batch-level tokenization for QA inputs. The special-tokens mask is requested
# explicitly (it shows up as `special_tokens_mask` in the batch).
before_batch_tfm <- HF_QABeforeBatchTransform(
  hf_arch, hf_tokenizer,
  max_length = max_seq_len,
  truncation = trunc_strat,
  tok_kwargs = list(return_special_tokens_mask = TRUE)
)

# One input block (question/context text) plus two categorical targets,
# matching the two get_y column readers (answer start, answer end).
blocks <- list(
  HF_TextBlock(
    before_batch_tfms = before_batch_tfm,
    input_return_type = HF_QuestionAnswerInput
  ),
  CategoryBlock(vocab = vocab),
  CategoryBlock(vocab = vocab)
)
# Input getter: return the (question, context) text pair for one row `x`.
# The order follows the tokenizer's padding side, so the context is always the
# segment that `trunc_strat` allows to be truncated.
# Reads the global `hf_tokenizer`.
get_x <- function(x) {
  if (hf_tokenizer$padding_side == "right") {
    list(x[["question"]], x[["context"]])
  } else {
    list(x[["context"]], x[["question"]])
  }
}
# Assemble the DataBlock: the first block is the model input (n_inp = 1), and
# the two CategoryBlocks are filled from the answer start/end token columns.
dblock <- DataBlock(
  blocks = blocks,
  get_x = get_x,
  get_y = list(ColReader("tok_answer_start"), ColReader("tok_answer_end")),
  splitter = RandomSplitter(),
  n_inp = 1
)

dls <- dblock %>% dataloaders(squad_df, bs = 4)

# Inspect one batch (printed below).
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 101, 20773, 2207, 1996, 2299, 1000, 1000, 4195, 1000, 1000,
2006, 2029, 3784, 2189, 2326, 1029, 102, 2006, 2337, 1020,
1010, 2355, 1010, 2028, 2154, 2077, 2014, 2836, 2012, 1996,
3565, 4605, 1010, 20773, 2207, 1037, 2047, 2309, 7580, 2006,
2189, 11058, 2326, 15065, 2170, 1000, 1000, 4195, 1000, 1000,
1012, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 101, 2019, 2590, 4895, 6465, 2165, 2173, 2043, 1029, 102,
1999, 2325, 20773, 2772, 2019, 2330, 3661, 2029, 1996, 2028,
3049, 2018, 2042, 9334, 16442, 2005, 1025, 1996, 3661, 2001,
8280, 2000, 10413, 21442, 11705, 1998, 25930, 8820, 13471, 2050,
21469, 10631, 3490, 1011, 16950, 2863, 1010, 14328, 2068, 2000,
3579, 2006, 2308, 2004, 2027, 3710, 2004, 1996, 2132, 1997,
1996, 1043, 2581, 1999, 2762, 1998, 1996, 8740, 1999, 2148,
3088, 4414, 1010, 2029, 2097, 2707, 2000, 2275, 1996, 18402,
1999, 2458, 4804, 2077, 1037, 2364, 4895, 6465, 1999, 2244,
2325, 2008, 2097, 5323, 2047, 2458, 3289, 2005, 1996, 4245,
1012, 102],
[ 101, 20773, 2247, 2007, 6108, 1062, 2777, 2007, 3183, 1005,
1055, 2155, 2044, 2037, 2331, 1029, 102, 2206, 1996, 2331,
1997, 15528, 3897, 1010, 20773, 1998, 6108, 1011, 1062, 1010,
2426, 2060, 3862, 4481, 1010, 2777, 2007, 2010, 2155, 1012,
2044, 1996, 10219, 1997, 13337, 1997, 3897, 1005, 1055, 2331,
1010, 20773, 1998, 6108, 1011, 1062, 6955, 5190, 1997, 6363,
2000, 15358, 2068, 2041, 1012, 102, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 101, 2129, 2038, 2023, 3192, 2904, 1999, 3522, 2086, 1029,
102, 2044, 7064, 16864, 1999, 2384, 1010, 20773, 1998, 20539,
2631, 1996, 12084, 3192, 2000, 3073, 17459, 3847, 2005, 5694,
1999, 1996, 5395, 2181, 1010, 2000, 2029, 20773, 5201, 2019,
3988, 1002, 5539, 1010, 2199, 1012, 1996, 3192, 2038, 2144,
4423, 2000, 2147, 2007, 2060, 15430, 1999, 1996, 2103, 1010,
1998, 2036, 3024, 4335, 2206, 7064, 25209, 2093, 2086, 2101,
1012, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]], device='cuda:0')
[[1]]$token_type_ids
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0]], device='cuda:0')
[[1]]$special_tokens_mask
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0]], device='cuda:0')
[[1]]$cls_index
tensor([[0],
[0],
[0],
[0]], device='cuda:0')
[[1]]$p_mask
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1]], device='cuda:0')
[[2]]
TensorCategory([42, 88, 20, 49], device='cuda:0')
[[3]]
TensorCategory([43, 90, 22, 55], device='cuda:0')
Wrap the model and fit:
# Wrap the Hugging Face model so fastai's Learner can drive it.
model <- HF_BaseModelWrapper(hf_model)

learn <- Learner(
  dls,
  model,
  opt_func = partial(Adam, decouple_wd = TRUE),
  cbs = HF_QstAndAnsModelCallback(),
  splitter = hf_splitter()
)

# Two targets (answer start and end), so use a loss over multiple targets.
learn$loss_func <- MultiTargetLoss()
learn$create_opt() # -> will create your layer groups based on your "splitter" function
learn$freeze()

learn %>% fit_one_cycle(4, lr_max = 1e-3)
Let's create a small data frame and predict with the trained
model:
# A one-row inference example with the same columns the getters read.
# stringsAsFactors = FALSE keeps the text as character on R < 4.0 too.
inf_df <- data.frame(
  question = "When was Star Wars made?",
  context = "George Lucas created Star Wars in 1977. He directed and produced it.",
  stringsAsFactors = FALSE
)
# Predict the answer span for a question/context data frame and return the
# answer tokens.
#
# inf_df:       data.frame with `question` and `context` columns; intended for
#               one row (only the first row's token ids are decoded).
# learner:      trained Learner (defaults to the global `learn`).
# data_loaders: DataLoaders used to tokenize `inf_df` (defaults to `dls`).
# tokenizer:    Hugging Face tokenizer (defaults to `hf_tokenizer`).
#
# Returns a character vector with every token from the predicted start index
# through the predicted end index (special tokens kept). Returns character(0)
# when the predicted end precedes the start.
bert_answer <- function(inf_df, learner = learn, data_loaders = dls,
                        tokenizer = hf_tokenizer) {
  # Tokenize the input exactly as the training data was, to recover token ids.
  batch <- data_loaders$test_dl(inf_df)$one_batch()
  input_ids <- batch[[1]][["input_ids"]]$tolist()[[1]]
  tokens <- tokenizer$convert_ids_to_tokens(input_ids, skip_special_tokens = FALSE)

  res <- predict(learner, inf_df)
  # as_array turns each torch tensor into an R array. res[[3]] is expected to
  # hold two scalars: the predicted start and end indices.
  # NOTE(review): layout inferred from the original usage -- confirm.
  idx <- vapply(res[[3]], function(t) as.integer(as_array(t)), integer(1))
  stopifnot(length(idx) == 2L)

  # +1 because tensor indices start from 0 but R vectors start from 1.
  span_start <- idx[[1]] + 1L
  span_end <- idx[[2]] + 1L
  if (span_end < span_start) {
    return(character(0))
  }
  # Take the whole span, not just its two endpoint tokens.
  tokens[seq(span_start, span_end)]
}
Result:
# Print the predicted answer tokens, separated by spaces.
inf_df %>%
  bert_answer() %>%
  cat()
# in 1977