-
Notifications
You must be signed in to change notification settings - Fork 2
/
sampler.py
83 lines (62 loc) · 2.6 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Utilities to iterate over data by batch of examples.
build_dataset_iter and IterOnDevice are built very similarly to functions
with the same name in OpenNMT-py (https://github.com/OpenNMT/OpenNMT-py)
"""
from torch.utils.data._utils.collate import default_collate
from torch.utils.data import DataLoader
import torch
def build_dataset_iter(dataset, batch_size, vocab_sizes, device=None, is_eval=False):
    """
    Given a dataset and a batch_size, creates a DataLoader that can yield
    batches of examples. These batches are on the correct device.
    :param dataset: the dataset of examples. Should implement dataset[idx]
                    that returns a dict of tensors.
    :param batch_size: number of examples per batch
    :param vocab_sizes: padding is done using vocab_size in this project
    :param device: torch.device object on which to move the batches
    :param is_eval: Shuffle training examples but keep eval in correct order
    :return: A dataloader. Usage: `for batch in dataloader: ...`
    """
    if dataset is None:
        return None
    device = torch.device('cpu') if device is None else device
    collate_fn = build_collate_fn(*vocab_sizes)
    # Pinning host memory only speeds up (async) host-to-GPU copies; it is
    # pointless when batches stay on CPU, so enable it for CUDA targets only.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=not is_eval,
                        collate_fn=collate_fn, pin_memory=(device.type == 'cuda'),
                        drop_last=False)
    return IterOnDevice(loader, device)
class IterOnDevice:
    """Wrap a DataLoader so that every yielded batch (a dict of tensors)
    is moved to a given torch.device before being handed to the caller.

    Unknown attribute lookups are delegated to the wrapped dataloader, so
    this object can be used as a drop-in replacement for it.
    """

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __hasattr__(self, attr):
        # NOTE(review): Python's hasattr() never calls a ``__hasattr__``
        # method — it relies on getattr/__getattr__ — so this is effectively
        # dead code. Kept only so the public interface is unchanged.
        return hasattr(self.dataloader, attr)

    def __getattr__(self, attr):
        # Called only when normal attribute lookup fails: delegate to the
        # wrapped dataloader. Guard against infinite recursion when
        # 'dataloader' itself is missing (e.g. during unpickling, before
        # __init__ has run).
        if attr == 'dataloader':
            raise AttributeError(attr)
        return getattr(self.dataloader, attr)

    def __len__(self):
        return len(self.dataloader)

    def to_device(self, batch):
        """Return a copy of the batch dict with every tensor on self.device."""
        return {
            tname: tensor.to(self.device)
            for tname, tensor in batch.items()
        }

    def __iter__(self):
        for batch in self.dataloader:
            yield self.to_device(batch)
def build_collate_fn(word_pad, ent_dist_pad, num_dist_pad):
    """
    a collate_fn is used to merge a number of examples into one batch.
    """
    # Each sequence field is padded (in place) past its true length with its
    # own padding value before the examples are stacked together.
    pad_spec = (
        ('sents', word_pad),
        ('entdists', ent_dist_pad),
        ('numdists', num_dist_pad),
    )

    def collate_fn(examples):
        for ex in examples:
            length = ex['lens'].item()
            for field, pad_value in pad_spec:
                ex[field][length:].fill_(pad_value)
        merged = default_collate(examples)
        # Trim every 2-D (batch x seq) tensor down to the longest real length
        # in this batch; scalar/1-D entries (e.g. 'lens') pass through as-is.
        longest = merged['lens'].max()
        trimmed = {}
        for field, tensor in merged.items():
            trimmed[field] = tensor[:, :longest] if tensor.dim() == 2 else tensor
        return trimmed

    return collate_fn