-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhf.py
More file actions
112 lines (94 loc) · 4.19 KB
/
hf.py
File metadata and controls
112 lines (94 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from torch.utils.data import Dataset
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
import torch.autograd.profiler as profiler
import time
from torch.utils.data import DataLoader
"""
train_dataset = load_dataset('squad', split='train')
valid_dataset = load_dataset('squad', split='validation')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
train_dataset = train_dataset.flatten()
train_dataset = train_dataset.map(lambda e: tokenizer(e['context'], truncation=True, padding='max_length'), batched=True)
#dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
"""
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def get_correct_alignement(context, answer):
    """Return (start, end) character indices of the gold answer span in *context*.

    Some original SQuAD examples have ``answer_start`` off by one or two
    characters; try the annotated offset first, then the two left-shifted
    variants, and return the first one whose slice matches the gold text.

    Args:
        context: Full context string the answer was extracted from.
        answer: SQuAD-style dict with ``text`` and ``answer_start`` lists;
            only the first entry of each is used.

    Returns:
        Tuple ``(start_idx, end_idx)`` such that
        ``context[start_idx:end_idx] == answer["text"][0]``.

    Raises:
        ValueError: If the gold text is not found at any tried offset.
    """
    gold_text = answer["text"][0]
    start_idx = answer["answer_start"][0]
    end_idx = start_idx + len(gold_text)
    # The annotations are sometimes off by one or two characters to the right,
    # so probe shifts of 0, -1 and -2 in that order (same order as before).
    for shift in (0, -1, -2):
        shifted_start, shifted_end = start_idx + shift, end_idx + shift
        if context[shifted_start:shifted_end] == gold_text:
            return shifted_start, shifted_end
    raise ValueError(
        f"Gold answer {gold_text!r} not found near character {start_idx} in context."
    )
# Tokenize our training dataset
def convert_to_features_context(example_batch):
    """Tokenize a batch of contexts and attach answer-span token positions.

    Tokenizes ``example_batch["context"]`` to fixed max length, then uses the
    fast tokenizer's char_to_token alignment to convert each gold answer's
    character span (via get_correct_alignement) into start/end token indices.

    Args:
        example_batch: Batch dict with ``context`` and ``answers`` lists.

    Returns:
        The tokenizer encoding, extended with ``start_positions`` and
        ``end_positions`` lists (one entry per example).
    """
    encoded = tokenizer(
        example_batch["context"], truncation=True, padding="max_length"
    )
    starts, ends = [], []
    rows = zip(example_batch["context"], example_batch["answers"])
    for row, (ctx, ans) in enumerate(rows):
        char_start, char_end = get_correct_alignement(ctx, ans)
        starts.append(encoded.char_to_token(row, char_start))
        # char_end is exclusive, so align on the last character of the span.
        ends.append(encoded.char_to_token(row, char_end - 1))
    encoded.update({"start_positions": starts, "end_positions": ends})
    return encoded
def convert_to_features_question(example_batch):
    """Tokenize a batch of questions to fixed max-length input ids.

    Args:
        example_batch: Batch dict with a ``question`` list.

    Returns:
        The tokenizer encoding for the batch of questions.
    """
    return tokenizer(
        example_batch["question"], truncation=True, padding="max_length"
    )
class SQUAD(Dataset):
    """SQuAD training split wrapped as a PyTorch ``Dataset``.

    Each item is a 4-tuple:
    ``(context_input_ids, question_input_ids, start_position, end_position)``.

    Fixes vs. the previous version:
    - ``__getitem__`` indexed every field with ``[idx]`` EXCEPT the context
      input ids, which returned the entire column; now indexed consistently.
    - Removed four no-op ``.flatten()`` calls whose return values were
      discarded (``datasets.Dataset.flatten`` returns a new dataset, it is
      not in-place), and debug timing ``print``s/profiler wrapping from the
      hot ``__getitem__`` path.
    """

    def __init__(self):
        # Load the raw training split; tokenized views below are derived from it.
        self.dataset = load_dataset("squad", split="train")
        # Context tokens plus aligned answer-span token positions.
        self.encoded_context = self.dataset.map(
            convert_to_features_context, batched=True
        )
        # Flatten nested columns before tokenizing questions (same order as
        # the original: flatten happens between the two map calls).
        self.dataset = self.dataset.flatten()
        self.encoded_question = self.dataset.map(
            convert_to_features_question, batched=True
        )
        # Expose torch.Tensor columns so __getitem__ yields tensors directly.
        context_columns = ["input_ids", "start_positions", "end_positions"]
        self.encoded_context.set_format(type="torch", columns=context_columns)
        self.encoded_question.set_format(type="torch", columns=["input_ids"])
        self.length = len(self.encoded_context["input_ids"])

    def __len__(self):
        """Number of examples in the training split."""
        return self.length

    def __getitem__(self, idx):
        """Return one example as (context ids, question ids, start, end)."""
        return (
            self.encoded_context["input_ids"][idx],
            self.encoded_question["input_ids"][idx],
            self.encoded_context["start_positions"][idx],
            self.encoded_context["end_positions"][idx],
        )