forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·364 lines (306 loc) · 12 KB
/
utils.py
File metadata and controls
executable file
·364 lines (306 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# TODO: reenable pyre after fixing the issues
# pyre-ignore-all-errors
import csv
import inspect
import os
import random
import shutil
from typing import Dict, List, Optional
import numpy as np
import torch
import transformers
from executorch.backends.qualcomm.export_utils import * # noqa: F401,F403
def replace_module_with_custom_class(
    model: torch.nn.Module,
    target_class: torch.nn.Module,
    custom_class: torch.nn.Module,
    strict: bool = False,
    extra_custom_kwargs: Optional[Dict] = None,
):
    """
    Recursively replace every descendant of `model` that is an instance of
    `target_class` with a freshly constructed `custom_class` instance.

    The replacement's constructor arguments are recovered from the original
    module: for each named ``__init__`` parameter, the instance attribute of
    the same name is used if present, otherwise the parameter's default.
    Variadic ``*args`` / ``**kwargs`` parameters are not reconstructed.
    `extra_custom_kwargs` then overrides or extends those arguments. The
    original module's ``state_dict`` is loaded into the replacement, which is
    put in eval mode.

    Only the *children* of `model` are examined, so `model` itself is never
    replaced even if it is an instance of `target_class`. Modules that are
    replaced are not searched further. `model` is modified in place and the
    function returns ``None``.

    Args:
        model (torch.nn.Module): The root module to search within.
        target_class (type): The class to be replaced.
        custom_class (type): The class to replace with.
        strict (bool): Forwarded to ``load_state_dict``; whether the original
            module's ``state_dict`` keys must match the replacement exactly.
        extra_custom_kwargs (dict, optional): Extra keyword arguments that
            override or extend the constructor arguments recovered from the
            original module.

    Example (illustrative, not a runnable doctest):
        >>> class MyDecoder(Decoder):
        ...     def __init__(self, ...):
        ...         super().__init__(...)
        ...         freqs_cos, freqs_sin = precompute_freqs_cis(...)
        ...         self.register_buffer("freqs_cos", freqs_cos)
        ...         self.register_buffer("freqs_sin", freqs_sin)
        ...
        ...     def forward(self, x):
        ...         ...
        >>> model = Transformer()  # a module containing Decoder submodules
        >>> replace_module_with_custom_class(model, Decoder, MyDecoder)
    """
    # *args / **kwargs cannot be rebuilt from instance attributes. An attribute
    # that merely shares the name (e.g. ``self.args``) would otherwise be
    # forwarded to `custom_class` as an unexpected keyword argument.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    def extract_init_args_from_instance(instance):
        """Recover keyword arguments for ``instance.__init__`` from `instance`."""
        init_signature = inspect.signature(instance.__init__)
        extracted_args = {}
        for param in init_signature.parameters.values():
            if param.name == "self" or param.kind in variadic_kinds:
                continue
            if hasattr(instance, param.name):
                # Prefer the value the instance actually stored.
                extracted_args[param.name] = getattr(instance, param.name)
            elif param.default is not inspect.Parameter.empty:
                extracted_args[param.name] = param.default
            # Otherwise the argument is omitted; custom_class must get it from
            # extra_custom_kwargs or supply its own default.
        return extracted_args

    if extra_custom_kwargs is None:
        extra_custom_kwargs = {}

    for name, child in model.named_children():
        if isinstance(child, target_class):
            state_dict = child.state_dict()
            original_args = extract_init_args_from_instance(child)
            # extra_custom_kwargs wins on key collisions.
            new_module = custom_class(**{**original_args, **extra_custom_kwargs})
            new_module.load_state_dict(state_dict, strict=strict)
            new_module.eval()
            setattr(model, name, new_module)
        else:
            replace_module_with_custom_class(
                child, target_class, custom_class, strict, extra_custom_kwargs
            )
def make_output_dir(path: str):
    """
    Create an empty directory at `path`, discarding whatever was there.

    An existing real directory is removed recursively. An existing regular
    file or symlink (including a dangling one) is unlinked; a symlinked
    directory's target is left intact. Missing parent directories are created.

    Args:
        path: Directory to (re)create.

    Raises:
        FileExistsError: if a directory tree could not be fully removed
            (``rmtree`` errors are ignored, as before).
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path, ignore_errors=True)
    elif os.path.lexists(path):
        # rmtree refuses files and symlinks; with ignore_errors=True it used to
        # leave them in place, so makedirs then failed with FileExistsError.
        # lexists() also catches dangling symlinks, which exists() misses.
        os.remove(path)
    os.makedirs(path)
def topk_accuracy(predictions, targets, k):
    """
    Return top-k accuracy, in percent, over a list of score arrays.

    Args:
        predictions: Sequence of NumPy score arrays (converted with
            ``torch.from_numpy``); ``torch.topk`` ranks along the last axis.
        targets: Indexable sequence of ground-truth class ids aligned with
            `predictions`.
        k: How many top-scoring classes count as a hit.

    Returns:
        A 0-dim torch tensor: the average per-array hit rate times 100.
        An empty `predictions` raises ``ZeroDivisionError``.
    """

    def _hits(scores, label, top):
        best = torch.topk(scores, k=top, sorted=True).indices
        matches = (torch.reshape(label, [-1, 1]) == best) * 1.0
        # mean over all (row, rank) cells times `top` == hits per row.
        return torch.mean(matches) * top

    total = sum(
        _hits(torch.from_numpy(scores), targets[position], k)
        for position, scores in enumerate(predictions)
    )
    return total * 100.0 / len(predictions)
def segmentation_metrics(predictions, targets, classes):
    """
    Compute semantic-segmentation metrics from an accumulated confusion matrix.

    Args:
        predictions: Iterable of predicted label maps (array-likes of class
            ids), paired element-wise with `targets`.
        targets: Iterable of ground-truth label maps. Pixels whose label is
            negative or ``>= len(classes)`` (e.g. an "ignore" value such as
            -1 or 255) are excluded from every metric.
        classes: Sequence of class names; its length is the number of classes.

    Returns:
        tuple ``(pa, mpa, miou, cls_iou)``: pixel accuracy, mean per-class
        pixel accuracy, mean IoU, and a ``{class_name: iou}`` dict. A small
        epsilon keeps classes that never occur from dividing by zero (they
        contribute 0).
    """
    num_classes = len(classes)

    def histogram(golden, predict):
        # Mask negative labels as well: otherwise num_classes * golden + predict
        # can go negative, which makes np.bincount raise.
        mask = (golden >= 0) & (golden < num_classes)
        # bincount accepts only non-negative integers, so cast both operands
        # (prediction maps may arrive as floats).
        flat_index = num_classes * golden[mask].astype(int) + predict[mask].astype(
            int
        )
        return np.bincount(flat_index, minlength=num_classes**2).reshape(
            num_classes, num_classes
        )

    # Rows index the ground-truth class, columns the predicted class.
    confusion = np.zeros((num_classes, num_classes))
    for g, p in zip(targets, predictions):
        confusion += histogram(np.asarray(g).ravel(), np.asarray(p).ravel())

    eps = 1e-6
    true_positive = np.diag(confusion)
    pa = true_positive.sum() / (confusion.sum() + eps)
    mpa = np.mean(true_positive / (confusion.sum(axis=1) + eps))
    iou = true_positive / (
        confusion.sum(axis=1) + confusion.sum(axis=0) - true_positive + eps
    )
    miou = np.mean(iou)
    cls_iou = dict(zip(classes, iou))
    return (pa, mpa, miou, cls_iou)
def class_agnostic_mIoU(predictions, targets):
    """
    Return the mean intersection-over-union of paired binary masks.

    Each prediction/target pair must support ``&`` and ``|`` (e.g. boolean
    NumPy arrays). The 1e-10 in the denominator avoids division by zero, so a
    pair where both masks are empty scores 0. An empty `predictions` raises
    ``ZeroDivisionError``.
    """
    per_pair = [
        np.count_nonzero(mask_p & mask_t) / (np.count_nonzero(mask_p | mask_t) + 1e-10)
        for mask_p, mask_t in zip(predictions, targets)
    ]
    return sum(per_pair) / len(predictions)
def evaluate_squad(predicted_texts: List[str], target_texts: List[str]):
    """
    Score generated answers against references with the ``squad`` metric
    from the ``evaluate`` package.

    Each pair gets its position as id, surrounding whitespace is stripped,
    and ``answer_start`` is a dummy 0. Pairs are formed with ``zip``, so
    surplus items in the longer list are ignored.

    Returns:
        The metric's result dict with ``f1`` and ``exact_match`` divided by
        100 (percent -> fraction).
    """
    import evaluate

    squad_metric = evaluate.load("squad")
    paired = list(enumerate(zip(predicted_texts, target_texts)))
    predictions = [
        {"id": str(idx), "prediction_text": guess.strip()}
        for idx, (guess, _) in paired
    ]
    references = [
        {
            "id": str(idx),
            # answer_start is not used for scoring here, so a dummy suffices.
            "answers": {"text": [truth.strip()], "answer_start": [0]},
        }
        for idx, (_, truth) in paired
    ]
    results = squad_metric.compute(predictions=predictions, references=references)
    for key in ("f1", "exact_match"):
        results[key] /= 100
    return results
def get_imagenet_dataset(
    dataset_path, data_size, image_shape, crop_size=None, shuffle=True
):
    """
    Load up to `data_size` preprocessed samples from an ``ImageFolder`` tree.

    Images are resized to `image_shape` and center-cropped to `crop_size`
    (default ``image_shape[0]``). They are then converted to tensors and
    normalized with the standard ImageNet mean/std constants.

    Returns:
        (inputs, targets): `inputs` is a list of 1-tuples ``(image_batch,)``
        and `targets` the matching label batches, as yielded by a DataLoader
        with its default batch size (1).
    """
    from torchvision import datasets, transforms

    pipeline = transforms.Compose(
        [
            transforms.Resize(image_shape),
            transforms.CenterCrop(crop_size or image_shape[0]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(dataset_path, transform=pipeline),
        shuffle=shuffle,
    )

    images, labels = [], []
    for count, batch in enumerate(loader):
        if count >= data_size:
            break
        image, label = batch
        images.append((image,))
        labels.append(label)
    return images, labels
def get_masked_language_model_dataset(dataset_path, tokenizer, data_size, shuffle=True):
    """
    Build masked-language-model evaluation samples from a text file.

    2000 lines are drawn at random (with replacement) from `dataset_path`.
    Lines of length <= 1 are dropped, and the rest are tokenized and passed
    through ``transformers.DataCollatorForLanguageModeling``. Samples in
    which no label differs from -100 (nothing was masked) are skipped.

    Args:
        dataset_path: Path to a UTF-8 text file with one sample per line.
        tokenizer: Tokenizer given to the collator and used on each line.
        data_size: Maximum number of samples to return.
        shuffle: Whether the DataLoader shuffles the collated candidates.

    Returns:
        (inputs, targets): `inputs` is a list of ``(input_ids, attention_mask)``
        tuples (int32 ids, float32 mask, each with a leading batch dim of 1).
        `targets` holds the matching label vectors. Fewer than `data_size`
        samples are returned only when the candidate pool runs out.
    """

    def get_data_loader():
        class MaskedSentencesDataset(torch.utils.data.Dataset):
            def __init__(self, dataset_path, tokenizer, data_size) -> None:
                self.data_size = data_size
                self.dataset = self._get_val_dataset(dataset_path, data_size, tokenizer)

            def _get_val_dataset(self, dataset_path, data_size, tokenizer):
                data_collator = transformers.DataCollatorForLanguageModeling(
                    tokenizer=tokenizer
                )
                # Explicit encoding, matching the CSV loader in this module.
                with open(dataset_path, "r", encoding="utf-8") as f:
                    texts = f.read().split("\n")
                texts = [
                    text for text in random.choices(texts, k=2000) if len(text) > 1
                ]
                return data_collator([tokenizer(text) for text in texts])

            def __getitem__(self, idx):
                return (
                    self.dataset["input_ids"][idx].to(torch.int32),
                    self.dataset["attention_mask"][idx].to(torch.float32),
                    self.dataset["labels"][idx],
                )

            def __len__(self):
                # Report every collated candidate, not `data_size`. The caller
                # skips unmasked samples and stops once it has `data_size`
                # usable ones. Returning `data_size` meant skipped samples were
                # never replaced, and indexing overran when fewer candidates
                # survived the length filter.
                return len(self.dataset["input_ids"])

        dataset = MaskedSentencesDataset(dataset_path, tokenizer, data_size)
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
        )

    inputs, targets = [], []
    data_loader = get_data_loader()
    for data in data_loader:
        if len(inputs) >= data_size:
            break
        input_ids = data[0]
        attention_mask = data[1]
        # The DataLoader adds a batch dim of 1; take the single label row.
        target = data[2][0]
        # Skip samples where nothing was masked (every label is -100).
        if not bool((target != -100).any()):
            continue
        inputs.append((input_ids, attention_mask))
        targets.append(target)
    return inputs, targets
def get_seq2seq_dataset_from_squad_csv(  # noqa: C901
    dataset_path,
    tokenizer,
    data_size,
    max_hidden_seq_length=384,
    shuffle=True,
):
    """
    Build encoder-decoder evaluation samples from a SQuAD-style CSV file.

    The CSV is read with ``csv.DictReader`` and needs ``question``, ``context``
    and ``answer`` columns. Rows with an empty or missing field are skipped.
    Each kept row becomes the prompt ``"question: <q> context: <c>"`` with the
    answer as target.

    Args:
        dataset_path: Path to the UTF-8 CSV file.
        tokenizer: Tokenizer used for prompts and answers; also handed to
            ``transformers.DataCollatorForSeq2Seq``.
        data_size: Maximum number of samples to return.
        max_hidden_seq_length: Prompts are truncated/padded to this length.
        shuffle: Shuffle the CSV rows before selection and the DataLoader order.

    Returns:
        (inputs, targets): each input is
        ``(input_ids, additive_attention_mask, decoder_input_ids)``. The mask is
        0.0 at attended positions and -255.0 at padding; `decoder_input_ids`
        holds the single token id 0. Each target is the answer's token ids,
        truncated/padded to 64 tokens.

    Raises:
        KeyError: if a required column is absent and the CSV has data rows.
    """

    def get_data_loader(max_hidden_seq_length):
        class SquadSeq2SeqDataset(torch.utils.data.Dataset):
            def __init__(
                self,
                dataset_path,
                tokenizer,
                data_size,
                max_hidden_seq_length,
            ):
                self.max_hidden_seq_length = max_hidden_seq_length
                self.tokenizer = tokenizer
                self.samples = self._load_and_process(dataset_path, data_size)

            def _load_and_process(self, path, max_samples):
                # newline="" is required by the csv module so that newlines
                # embedded in quoted fields (e.g. multi-line contexts) parse
                # correctly.
                with open(path, "r", encoding="utf-8", newline="") as f:
                    reader = csv.DictReader(f)
                    rows = list(reader)
                if shuffle:
                    random.shuffle(rows)
                samples = []
                for row in rows:
                    # DictReader fills fields missing from a short row with
                    # None; treat them as empty instead of crashing on .strip().
                    question = (row["question"] or "").strip()
                    context = (row["context"] or "").strip()
                    answer = (row["answer"] or "").strip()
                    if not question or not context or not answer:
                        continue
                    input_text = f"question: {question} context: {context}"
                    samples.append((input_text, answer))
                    if len(samples) >= max_samples:
                        break
                return samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                input_text, target_text = self.samples[idx]
                model_input = self.tokenizer(
                    input_text,
                    truncation=True,
                    padding="max_length",
                    max_length=self.max_hidden_seq_length,
                    return_tensors="pt",
                )
                label = self.tokenizer(
                    target_text,
                    truncation=True,
                    padding="max_length",
                    max_length=64,
                    return_tensors="pt",
                )
                return {
                    "input_ids": model_input["input_ids"].squeeze(0),
                    "attention_mask": model_input["attention_mask"]
                    .reshape(1, 1, -1)
                    .to(torch.float32),
                    "decoder_input_ids": torch.tensor([0], dtype=torch.long),
                    "labels": label["input_ids"].squeeze(0),
                }

        dataset = SquadSeq2SeqDataset(
            dataset_path, tokenizer, data_size, max_hidden_seq_length
        )
        collator = transformers.DataCollatorForSeq2Seq(tokenizer)
        return torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=shuffle, collate_fn=collator
        )

    inputs, targets = [], []
    data_loader = get_data_loader(max_hidden_seq_length)
    for batch in data_loader:
        if len(inputs) >= data_size:
            break
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        decoder_input_ids = batch["decoder_input_ids"]
        labels = batch["labels"][0]
        # Skip samples with no label position other than -100 (ignored).
        if (labels != -100).sum().item() == 0:
            continue
        inputs.append(
            (
                input_ids.to(torch.long),
                # Convert the 1/0 mask into an additive mask: 0 keeps a
                # position, -255 suppresses padding.
                torch.where(attention_mask == 0.0, -255.0, 0.0),
                decoder_input_ids,
            )
        )
        targets.append(labels)
    return inputs, targets