-
Notifications
You must be signed in to change notification settings - Fork 2
/
data.py
146 lines (108 loc) · 4.62 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from torch.utils.data import Dataset as PytorchDataset
from utils import logger
import numpy as np
import torch
import h5py
def prep_data(train_filename, eval_filename=None, is_test=False, is_just_eval=False):
    """Load the training and held-out splits and normalise their distances.

    The training file provides the ``tr`` split plus one held-out split:
    ``test`` when ``is_test`` is true, ``val`` otherwise. An optional second
    file provides an extra evaluation set, read from that file's ``val``
    split.

    Entity and number distances are shifted in place so that the smallest
    value seen in the train / held-out splits becomes 0. The extra evaluation
    set is first clamped to the range observed on those splits and then
    shifted by the same offsets.

    Args:
        train_filename: Path of the HDF5 file holding the training data.
        eval_filename: Optional path of a second HDF5 file. ``None`` or ``''``
            means no extra evaluation set is loaded.
        is_test: Use the ``test`` split of ``train_filename`` as the held-out
            split instead of ``val``.
        is_just_eval: Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        A 4-tuple ``(datasets, min_dists, paddings, nlabels)``:

        * ``datasets`` -- ``[train, held_out, extra_eval]``; the last entry
          is ``None`` when no ``eval_filename`` was given.
        * ``min_dists`` -- ``[min_entdist, min_numdist]``, the offsets that
          were subtracted from the distances.
        * ``paddings`` -- ``[word_pad, ent_dist_pad, num_dist_pad]``, each
          one more than the corresponding maximum in the (shifted) train set.
        * ``nlabels`` -- largest train label plus one, as a Python number.
    """
    kwargs = {'do_train': True, 'do_test' if is_test else 'do_val': True}
    datasets = load_datasets(train_filename, **kwargs)

    train = datasets.pop('tr')
    # Exactly one split is left after popping 'tr': 'val' or 'test'.
    val = datasets[next(iter(datasets))]

    test = None
    if eval_filename not in {None, ''}:
        test = load_datasets(eval_filename, do_val=True).pop('val')

    # See original code:
    # https://github.com/harvardnlp/data2text/blob/master/extractor.lua#L81
    min_entdist = min(train['entdists'].min(), val['entdists'].min())
    min_numdist = min(train['numdists'].min(), val['numdists'].min())
    # The maxima are taken on the train split only, before any shifting.
    max_entdist = train['entdists'].max()
    max_numdist = train['numdists'].max()

    train.shift_dists(min_entdist=min_entdist, min_numdist=min_numdist)
    val.shift_dists(min_entdist=min_entdist, min_numdist=min_numdist)

    if test is not None:
        # Clamp first (in the unshifted space) so that the extra evaluation
        # set never holds a distance outside the range computed above.
        test.clamp_dists(min_entdist=min_entdist, min_numdist=min_numdist,
                         max_entdist=max_entdist, max_numdist=max_numdist)
        test.shift_dists(min_entdist=min_entdist, min_numdist=min_numdist)

    # Sizes / padding indices: one past the largest value in the train split
    # (distances are already shifted at this point).
    nlabels = train['labels'].max().item() + 1
    ent_dist_pad = train['entdists'].max() + 1
    num_dist_pad = train['numdists'].max() + 1
    word_pad = train['sents'].max() + 1

    datasets = [train, val, test]
    min_dists = [min_entdist, min_numdist]
    paddings = [word_pad, ent_dist_pad, num_dist_pad]

    return datasets, min_dists, paddings, nlabels
def load_datasets(filename, do_train=False, do_val=False, do_test=False):
    """Load the requested splits from ``filename``.

    Args:
        filename: Path handed over to ``make_datasets``.
        do_train: Request the ``'tr'`` split.
        do_val: Request the ``'val'`` split.
        do_test: Request the ``'test'`` split.

    Returns:
        Whatever ``make_datasets`` returns for the requested split names.

    Raises:
        RuntimeError: If none of the three flags is truthy.
    """
    requested = {
        'tr': do_train,
        'val': do_val,
        'test': do_test,
    }
    sets = {name for name, wanted in requested.items() if wanted}

    if not sets:
        raise RuntimeError('This function was asked to load no data!')

    return make_datasets(filename, sets)
def make_datasets(h5_filename, sets=None):
    """Read an HDF5 file and wrap the requested splits in dataset objects.

    Every top-level entry of the file whose name starts with ``tr``, ``val``
    or ``test`` is converted to a tensor and stored, for that split, under its
    name with the prefix removed. The only other entry accepted is
    ``boxrestartidxs``, which is attached to the ``val`` split.

    Args:
        h5_filename: Path of the HDF5 file to read.
        sets: Iterable of split names taken from ``'tr'``, ``'val'`` and
            ``'test'``. Defaults to ``['tr']``.

    Returns:
        A dict mapping each requested split name to the object built by the
        class registered for that split in ``__datasets``.

    Raises:
        AssertionError: If ``sets`` holds an unknown split name, or the file
            has an entry that is neither split-prefixed nor
            ``boxrestartidxs``.
    """
    sets = sets if sets is not None else ['tr']
    assert all(dname in {'tr', 'val', 'test'} for dname in sets)

    logger.info(f'Reading file: {h5_filename}')

    datasets = {dname: dict() for dname in {'tr', 'val', 'test'}}

    # The context manager closes the file even when reading or converting an
    # entry raises, so the handle is never leaked.
    with h5py.File(h5_filename, mode="r") as file:
        for dname, dvals in file.items():
            for prefix, dataset in datasets.items():
                if dname.startswith(prefix):
                    dataset[dname[len(prefix):]] = torch.tensor(np.array(dvals))
                    break
            else:
                # No split prefix matched this entry.
                assert dname == 'boxrestartidxs'
                datasets['val']['boxrestartidxs'] = torch.tensor(np.array(dvals))

    return {
        dname: __datasets[dname](**datasets[dname])
        for dname in sets
    }
class Dataset(PytorchDataset):
    """In-memory dataset made of five aligned fields.

    The five constructor arguments must share the same size along their first
    dimension; that size is the number of examples. They are presumably
    ``torch.Tensor`` objects (the code relies on ``.size(0)`` and ``.add_``).

    Indexing accepts two kinds of keys:

    * an ``int`` or a ``slice`` returns a dict holding the selected rows of
      every field;
    * a string returns the attribute of that name (e.g. ``ds['entdists']``
      is the whole ``entdists`` field).
    """

    def __init__(self, entdists, labels, lens, numdists, sents):
        """Store the fields after checking that they are aligned.

        Raises:
            AssertionError: If the fields do not all have the same first
                dimension as ``entdists``.
        """
        self._len = entdists.size(0)
        assert all(tensor.size(0) == len(self)
                   for tensor in [entdists, labels, lens, numdists, sents])

        self.entdists = entdists
        self.labels = labels
        self.lens = lens
        self.numdists = numdists
        self.sents = sents

    def shift_dists(self, min_entdist, min_numdist):
        """Subtract the given offsets from both distance fields, in place."""
        self.entdists.add_(-min_entdist)
        self.numdists.add_(-min_numdist)

    def __getitem__(self, item):
        """Return the rows selected by ``item`` or the field named ``item``.

        Raises:
            AttributeError: If ``item`` is a string for which the instance
                has no attribute, or whose attribute is ``None``.
        """
        if isinstance(item, (int, slice)):
            return {
                'sents': self.sents[item],
                'entdists': self.entdists[item],
                'numdists': self.numdists[item],
                'lens': self.lens[item],
                'labels': self.labels[item],
            }

        if (ret := getattr(self, item, None)) is not None:
            return ret

        # Name the missing key: the class object alone tells the caller nothing.
        raise AttributeError(
            f'{self.__class__.__name__} has no field {item!r}')

    def __len__(self):
        """Return the number of examples."""
        return self._len

    def __repr__(self):
        # Use the class name verbatim: str.title() would lower-case the inner
        # capitals of subclass names ('EvaluationDataset' -> 'Evaluationdataset').
        return f'{self.__class__.__name__}(n_examples={len(self)})'
class EvaluationDataset(Dataset):
    """Dataset variant used for the evaluation splits.

    The last column of ``labels`` is split off and kept as ``labelnums``
    (presumably a per-example label count -- confirm against the data
    producer); ``labels`` keeps the remaining columns. ``boxrestartidxs`` is
    stored unchanged.
    """

    def __init__(self, entdists, labels, lens, numdists, sents, boxrestartidxs=None):
        super().__init__(entdists=entdists, labels=labels, lens=lens,
                         numdists=numdists, sents=sents)

        packed_labels = self.labels
        self.boxrestartidxs = boxrestartidxs
        # The trailing column goes to `labelnums`, everything before it stays
        # in `labels`.
        self.labelnums = packed_labels[:, -1]
        self.labels = packed_labels[:, :-1]

    def clamp_dists(self, min_entdist, max_entdist, min_numdist, max_numdist):
        """Clamp both distance fields, in place, to the given bounds."""
        self.entdists.clamp_(min=min_entdist, max=max_entdist)
        self.numdists.clamp_(min=min_numdist, max=max_numdist)

    def __getitem__(self, item):
        """Same as the parent lookup, plus ``labelnums`` for row lookups."""
        fetched = super().__getitem__(item)
        if not isinstance(item, (int, slice)):
            # Lookup by name: the parent already returned the field itself.
            return fetched
        fetched['labelnums'] = self.labelnums[item]
        return fetched
# Split name -> class used to wrap the data loaded for that split.
__datasets = dict(
    tr=Dataset,
    val=EvaluationDataset,
    test=EvaluationDataset,
)