First, we need to install the blurr module,
which provides the Transformers integration.
# Install the Python package 'ohmeow-blurr' with pip through reticulate.
reticulate::py_install('ohmeow-blurr', pip = TRUE)
Load the libraries and read the sample dataset from the link:
library(fastai)
library(magrittr)
library(zeallot)
# Read the CNN/DailyMail sample CSV straight from the blurr repository.
cnndm_df = data.table::fread('https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/cnndm_sample.csv')
Get the pretrained BART model from transformers:
# Handle to the transformers module; the name is rebound from the loader
# function to the object it returns.
transformers = transformers()

# Model class passed to get_hf_objects() below.
BartForConditionalGeneration = transformers$BartForConditionalGeneration

pretrained_model_name = "facebook/bart-large-cnn"

# Destructure the four returned objects with zeallot's %<-% operator.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, model_cls = BartForConditionalGeneration)
Create a DataBlock and inspect one batch:
# Before-batch transform for summarization.
# NOTE(review): max_length presumably holds the (input, target) token limits;
# confirm against the blurr documentation.
before_batch_tfm = HF_SummarizationBeforeBatchTransform(hf_arch,
                                                        hf_tokenizer,
                                                        max_length = c(256, 130))

# Text-to-text block for the inputs; noop() for the targets.
blocks = list(HF_Text2TextBlock(before_batch_tfms = before_batch_tfm,
                                input_return_type = HF_SummarizationInput),
              noop())

# Inputs come from the 'article' column, targets from 'highlights'.
dblock = DataBlock(blocks = blocks,
                   get_x = ColReader('article'),
                   get_y = ColReader('highlights'),
                   splitter = RandomSplitter())

dls = dblock %>% dataloaders(cnndm_df, bs = 2)

# Print one batch to check the tensors.
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 36, 16256, 43, 480, 2193, 7, 62, 7, 158,
135, 9, 70, 684, 4707, 6, 1625, 16, 4984, 25,
65, 9, 5, 144, 39865, 32273, 3806, 15, 5, 5518,
4, 20, 9544, 3455, 9, 2147, 464, 8, 1050, 29779,
26284, 15, 1632, 11534, 32, 6, 959, 6, 5608, 5,
8066, 9, 5, 247, 18, 4066, 7892, 4, 178, 89,
16, 10, 372, 432, 7, 2217, 4, 96, 5, 315,
3076, 9356, 4928, 36, 4154, 9662, 43, 623, 12978, 23853,
2521, 18, 889, 9, 10721, 625, 32273, 749, 1625, 5546,
365, 212, 4, 20, 889, 3372, 10, 333, 9, 601,
749, 14, 19655, 5, 1647, 9, 5, 3875, 18, 4707,
8, 32, 3891, 1687, 2778, 39865, 32273, 4, 1740, 63,
23491, 28919, 11, 5, 7599, 3939, 7, 63, 10602, 41166,
1634, 11, 12718, 1115, 281, 8, 5, 854, 3964, 21785,
14113, 8, 63, 38905, 8, 21950, 19947, 11, 5, 1926,
6, 1625, 10902, 41, 5784, 4066, 3143, 9, 39456, 8,
856, 25729, 4, 993, 195, 5243, 66, 9, 262, 1360,
38950, 1848, 4707, 303, 11, 1625, 480, 5, 144, 11,
143, 247, 480, 64, 129, 28, 13590, 624, 63, 7562,
4, 85, 16, 184, 7, 39179, 3505, 9, 27772, 6,
28482, 4707, 9, 7723, 6, 112, 6, 6115, 17576, 9,
7723, 8, 973, 6, 151, 1380, 14868, 9, 3451, 4,
221, 2839, 8367, 102, 6, 10, 786, 12, 7699, 1651,
14, 1364, 7, 3720, 8360, 8, 5068, 709, 11, 1625,
6, 34, 3919, 411, 4707, 61, 24, 161, 7648, 2072,
5, 1272, 2713, 30, 5, 2],
[ 0, 36, 16256, 43, 480, 6871, 24, 1655, 22049, 50,
41571, 1896, 54, 24509, 13, 244, 5, 363, 5, 601,
12, 180, 12, 279, 1896, 21, 738, 1462, 116, 280,
115, 6723, 15, 61, 985, 5, 3940, 2046, 4, 1868,
22049, 18, 8, 1896, 18, 8826, 2327, 117, 28946, 273,
11, 2559, 461, 4961, 25, 7, 1060, 28604, 2236, 16,
1317, 11347, 148, 10, 7158, 486, 31, 14, 902, 973,
6, 1125, 6, 363, 11, 23058, 6, 1261, 35, 4028,
26, 24, 21, 69, 979, 4, 280, 34920, 480, 19,
5767, 3809, 1243, 21605, 13875, 24, 21, 69, 979, 6,
41571, 6, 54, 16670, 66, 6, 150, 15151, 2459, 22049,
26, 24, 21, 69, 979, 6, 1655, 6, 54, 21,
16600, 71, 145, 4487, 30, 5, 6066, 480, 21, 1353,
7, 273, 18, 461, 7069, 6, 8, 1353, 7, 5,
200, 12, 5743, 1900, 403, 25451, 11, 1353, 1261, 4,
22049, 34, 4407, 45, 2181, 8, 1695, 37, 738, 5,
7044, 11, 1403, 12, 21409, 4, 20, 7158, 486, 702,
2330, 11, 461, 15, 273, 6, 39, 3969, 2026, 6,
124, 62, 49, 19395, 14, 24, 21, 1896, 6, 8,
45, 49, 3653, 6, 54, 21, 5, 29593, 368, 4,
4500, 4945, 628, 273, 1390, 6, 15151, 2459, 22049, 26,
79, 21, 686, 1655, 21, 5, 65, 16600, 4, 2612,
116, 41039, 10105, 37, 18, 127, 979, 39732, 264, 7173,
41039, 1250, 9, 5, 1065, 48149, 77, 553, 549, 79,
56, 655, 137, 1317, 69, 1655, 22049, 7923, 24120, 50,
8930, 66, 13, 244, 4, 2]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
[[1]]$decoder_input_ids
tensor([[ 0, 1625, 4452, 7, 62, 7, 158, 135, 9, 70,
684, 4707, 15, 3875, 479, 50118, 243, 16, 184, 7,
39179, 3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973,
6, 151, 3505, 9, 3451, 479, 50118, 33837, 709, 8,
2147, 464, 16, 9405, 10, 380, 8793, 15, 63, 28809,
479, 50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117,
9, 145, 5, 247, 18, 632, 7648, 479, 2, 1,
1, 1, 1, 1, 1, 1],
[ 0, 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6,
45, 5, 7158, 486, 6, 40, 3094, 5, 403, 479,
50118, 5341, 35, 83, 2470, 13, 1896, 18, 284, 161,
37, 4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118,
534, 15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236,
16, 14, 9, 69, 979, 479, 50118, 30541, 6, 41571,
1896, 18, 18631, 1843, 26, 14, 24, 69, 979, 18,
2236, 15, 5, 7158, 486, 479]], device='cuda:0')
[[1]]$labels
tensor([[ 1625, 4452, 7, 62, 7, 158, 135, 9, 70, 684,
4707, 15, 3875, 479, 50118, 243, 16, 184, 7, 39179,
3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973, 6,
151, 3505, 9, 3451, 479, 50118, 33837, 709, 8, 2147,
464, 16, 9405, 10, 380, 8793, 15, 63, 28809, 479,
50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117, 9,
145, 5, 247, 18, 632, 7648, 479, 2, -100, -100,
-100, -100, -100, -100, -100, -100],
[ 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6, 45,
5, 7158, 486, 6, 40, 3094, 5, 403, 479, 50118,
5341, 35, 83, 2470, 13, 1896, 18, 284, 161, 37,
4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118, 534,
15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236, 16,
14, 9, 69, 979, 479, 50118, 30541, 6, 41571, 1896,
18, 18631, 1843, 26, 14, 24, 69, 979, 18, 2236,
15, 5, 7158, 486, 479, 2]], device='cuda:0')
[[2]]
tensor([[ 1625, 4452, 7, 62, 7, 158, 135, 9, 70, 684,
4707, 15, 3875, 479, 50118, 243, 16, 184, 7, 39179,
3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973, 6,
151, 3505, 9, 3451, 479, 50118, 33837, 709, 8, 2147,
464, 16, 9405, 10, 380, 8793, 15, 63, 28809, 479,
50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117, 9,
145, 5, 247, 18, 632, 7648, 479, 2, -100, -100,
-100, -100, -100, -100, -100, -100],
[ 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6, 45,
5, 7158, 486, 6, 40, 3094, 5, 403, 479, 50118,
5341, 35, 83, 2470, 13, 1896, 18, 284, 161, 37,
4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118, 534,
15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236, 16,
14, 9, 69, 979, 479, 50118, 30541, 6, 41571, 1896,
18, 18631, 1843, 26, 14, 24, 69, 979, 18, 2236,
15, 5, 7158, 486, 479, 2]], device='cuda:0')
Define a Learner object and train the model:
# Start from the summarization generation settings stored in the model config,
# then override the length bounds.
text_gen_kwargs = hf_config$task_specific_params['summarization'][[1]]
text_gen_kwargs['max_length'] = 130L; text_gen_kwargs['min_length'] = 30L

# Print the resulting settings.
text_gen_kwargs

model = HF_BaseModelWrapper(hf_model)
model_cb = HF_SummarizationModelCallback(text_gen_kwargs = text_gen_kwargs)

learn = Learner(dls,
                model,
                opt_func = partial(Adam),
                loss_func = CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs = model_cb,
                splitter = partial(summarization_splitter, arch = hf_arch)) #.to_native_fp16() #.to_fp16()

# Create the optimizer and freeze the learner before training.
learn$create_opt()
learn$freeze()

# Train for one cycle.
learn %>% fit_one_cycle(1, lr_max = 4e-5)
Generate summaries for a new text:
# Article to summarize. The literal is kept exactly as given, line breaks
# included, because it is the model input.
test_article = c("About 10 men armed with pistols and small machine guns raided a casino in Switzerland
and made off into France with several hundred thousand Swiss francs in the early hours
of Sunday morning, police said. The men, dressed in black clothes and black ski masks,
split into two groups during the raid on the Grand Casino Basel, Chief Inspector Peter
Gill told CNN. One group tried to break into the casino's vault on the lower level
but could not get in, but they did rob the cashier of the money that was not secured,
he said. The second group of armed robbers entered the upper level where the roulette
and blackjack tables are located and robbed the cashier there, he said. As the thieves
were leaving the casino, a woman driving by and unaware of what was occurring unknowingly
blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat
her, and took off for the French border. The other gunmen followed into France, which
is only about 100 meters (yards) from the casino, Gill said. There were about 600 people
in the casino at the time of the robbery. There were no serious injuries, although one
guest on the Casino floor was kicked in the head by one of the robbers when he moved,
the police officer said. Swiss authorities are working closely with French authorities,
Gill said. The robbers spoke French and drove vehicles with French lRicense plates.
CNN's Andreena Narayan contributed to this report.")

# Request three candidate summaries and print them.
outputs = learn$blurr_summarize(test_article, num_return_sequences = 3L)
cat(outputs)
Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, police chief says .
Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, an officer says.
Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .