forked from ucbrise/actnn
commit 673d93e (root commit): initial commit; move to the new repo
Showing 64 changed files with 10,287 additions and 0 deletions.
.gitignore (new file, 30 lines)
```
# Results
results
*.log
*.npz
image_classification/results.tar.gz
image_classification/results
image_classification/results.tsv

# IDE & OS
.idea
.DS_Store

# Documents
*.png
*.jpg
*.pptx

# Python
*.pyc
__pycache__

# VIM
*.swp

# Build
actnn/build
actnn/dist
actnn/actnn.egg-info
actnn/actnn/cpp_extension/*.so
```
README.md (new file, 57 lines)
# ActNN: Activation Compressed Training

## Install
- Requirements
```
torch>=1.7.1
torchvision>=0.8.2
```

- Build
```bash
cd actnn
pip install -v -e .
```

## Usage
[mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) is an example of using ActNN with models from torchvision.

### Basic Usage
- Step 1: Convert the model to use ActNN's layers.
```python
import actnn
model = actnn.QModule(model)
```

- Step 2: Configure the optimization level.
ActNN provides several optimization levels that control the trade-off between memory saving and computational overhead.
Set the level with:
```python
# available choices are ["L0", "L1", "L2", "L3", "L4", "L5"]
actnn.set_optimization_level("L3")
```
See [set_optimization_level](actnn/actnn/conf.py) for more details.
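Putting the two steps together, here is a minimal sketch of one training step. It is not taken from the repo: the model choice, hyperparameters, and dummy batch are illustrative assumptions, and a CUDA device is assumed.

```python
# Hedged end-to-end sketch: one training step with ActNN on a torchvision model.
import torch
import torchvision
import actnn

model = actnn.QModule(torchvision.models.resnet50())  # Step 1: swap in ActNN layers
actnn.set_optimization_level("L3")                    # Step 2: pick a trade-off
model = model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(8, 3, 224, 224, device="cuda")   # dummy batch
labels = torch.randint(0, 1000, (8,), device="cuda")

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()   # activations were stored compressed during the forward pass
optimizer.step()
```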
### Advanced Features
- (Optional) Change the data loader.
If you want to use per-sample gradient information for adaptive quantization,
you have to update the dataloader to return sample indices;
see `train_loader` in [mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) for an example, and the sketch after this section.
In addition, you have to update the configuration:
```python
from actnn import config, QScheme
config.use_gradient = True
QScheme.num_samples = 1300000  # the size of the training set
```
You can find sample code in the above script.
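For illustration only, a thin wrapper like the one below is one way to make a dataset return sample indices; the class name and tuple layout are assumptions, and the authoritative version is `train_loader` in the script above.

```python
# Hypothetical wrapper (illustration only), not code from this repo.
import torch

class IndexedDataset(torch.utils.data.Dataset):
    """Wraps a (data, label) dataset so each item also carries its index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        data, label = self.dataset[i]
        return data, label, i  # the index lets ActNN track per-sample gradients
```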
## Image Classification
See [image_classification](image_classification/)

## Semantic Segmentation
Will be added later.

## Benchmark Memory Usage and Training Speed
See [mem_speed_benchmark](mem_speed_benchmark/)
actnn/actnn/__init__.py (new file, 10 lines)
```python
from . import dataloader
from . import ops
from .conf import config, set_optimization_level
from .dataloader import DataLoader
from .layers import QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d, \
    QBatchNorm1d, QBatchNorm2d, QBatchNorm3d, QLinear, QReLU, QSyncBatchNorm, QMaxPool2d
from .module import QModule
from .qscheme import QScheme
from .qbnscheme import QBNScheme
from .utils import get_memory_usage, compute_tensor_bytes, exp_recorder
```
actnn/actnn/_utils/__init__.py (new file, 45 lines)
r"""Utility classes & functions for data loading. Code in this folder is mostly | ||
used by ../dataloder.py. | ||
A lot of multiprocessing is used in data loading, which only supports running | ||
functions defined in global environment (py2 can't serialize static methods). | ||
Therefore, for code tidiness we put these functions into different files in this | ||
folder. | ||
""" | ||
|
||
import sys | ||
import atexit | ||
|
||
# old private location of the ExceptionWrapper that some users rely on: | ||
from torch._utils import ExceptionWrapper | ||
|
||
|
||
IS_WINDOWS = sys.platform == "win32" | ||
|
||
|
||
MP_STATUS_CHECK_INTERVAL = 5.0 | ||
r"""Interval (in seconds) to check status of processes to avoid hanging in | ||
multiprocessing data loading. This is mainly used in getting data from | ||
another process, in which case we need to periodically check whether the | ||
sender is alive to prevent hanging.""" | ||
|
||
|
||
python_exit_status = False | ||
r"""Whether Python is shutting down. This flag is guaranteed to be set before | ||
the Python core library resources are freed, but Python may already be exiting | ||
for some time when this is set. | ||
Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar | ||
hook in Python 3.7 multiprocessing library: | ||
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 | ||
""" | ||
|
||
|
||
def _set_python_exit_flag(): | ||
global python_exit_status | ||
python_exit_status = True | ||
|
||
atexit.register(_set_python_exit_flag) | ||
|
||
|
||
from . import worker, signal_handling, pin_memory, collate, fetch |
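The shutdown flag above follows a generic atexit pattern; here is a standalone sketch of that pattern (illustration only, not code from this repo):

```python
# Standalone sketch of the atexit-flag pattern used above (illustration only).
import atexit

shutting_down = False

def _mark_shutdown():
    # Runs in the main thread during normal interpreter shutdown.
    global shutting_down
    shutting_down = True

atexit.register(_mark_shutdown)

def periodic_cleanup():
    if shutting_down:
        return  # interpreter is exiting: skip work that needs live resources
    # ... normal cleanup path ...
```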
actnn/actnn/_utils/collate.py (new file, 86 lines)
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to | ||
collate samples fetched from dataset into Tensor(s). | ||
These **needs** to be in global scope since Py2 doesn't support serializing | ||
static methods. | ||
""" | ||
|
||
import torch | ||
import re | ||
from torch._six import container_abcs, string_classes, int_classes | ||
|
||
np_str_obj_array_pattern = re.compile(r'[SaUO]') | ||
|
||
|
||
def default_convert(data): | ||
r"""Converts each NumPy array data field into a tensor""" | ||
elem_type = type(data) | ||
if isinstance(data, torch.Tensor): | ||
return data | ||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ | ||
and elem_type.__name__ != 'string_': | ||
# array of string classes and object | ||
if elem_type.__name__ == 'ndarray' \ | ||
and np_str_obj_array_pattern.search(data.dtype.str) is not None: | ||
return data | ||
return torch.as_tensor(data) | ||
elif isinstance(data, container_abcs.Mapping): | ||
return {key: default_convert(data[key]) for key in data} | ||
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple | ||
return elem_type(*(default_convert(d) for d in data)) | ||
elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes): | ||
return [default_convert(d) for d in data] | ||
else: | ||
return data | ||
|
||
|
||
default_collate_err_msg_format = ( | ||
"default_collate: batch must contain tensors, numpy arrays, numbers, " | ||
"dicts or lists; found {}") | ||
|
||
|
||
def default_collate(batch): | ||
r"""Puts each data field into a tensor with outer dimension batch size""" | ||
|
||
elem = batch[0] | ||
elem_type = type(elem) | ||
if isinstance(elem, torch.Tensor): | ||
out = None | ||
if torch.utils.data.get_worker_info() is not None: | ||
# If we're in a background process, concatenate directly into a | ||
# shared memory tensor to avoid an extra copy | ||
numel = sum([x.numel() for x in batch]) | ||
storage = elem.storage()._new_shared(numel) | ||
out = elem.new(storage) | ||
return torch.stack(batch, 0, out=out) | ||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ | ||
and elem_type.__name__ != 'string_': | ||
elem = batch[0] | ||
if elem_type.__name__ == 'ndarray': | ||
# array of string classes and object | ||
if np_str_obj_array_pattern.search(elem.dtype.str) is not None: | ||
raise TypeError(default_collate_err_msg_format.format(elem.dtype)) | ||
|
||
return default_collate([torch.as_tensor(b) for b in batch]) | ||
elif elem.shape == (): # scalars | ||
return torch.as_tensor(batch) | ||
elif isinstance(elem, float): | ||
return torch.tensor(batch, dtype=torch.float64) | ||
elif isinstance(elem, int_classes): | ||
return torch.tensor(batch) | ||
elif isinstance(elem, string_classes): | ||
return batch | ||
elif isinstance(elem, container_abcs.Mapping): | ||
return {key: default_collate([d[key] for d in batch]) for key in elem} | ||
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple | ||
return elem_type(*(default_collate(samples) for samples in zip(*batch))) | ||
elif isinstance(elem, container_abcs.Sequence): | ||
# check to make sure that the elements in batch have consistent size | ||
it = iter(batch) | ||
elem_size = len(next(it)) | ||
if not all(len(elem) == elem_size for elem in it): | ||
raise RuntimeError('each element in list of batch should be of equal size') | ||
transposed = zip(*batch) | ||
return [default_collate(samples) for samples in transposed] | ||
|
||
raise TypeError(default_collate_err_msg_format.format(elem_type)) |
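For orientation, a small example of calling `default_collate` (not part of the file; the batch contents are made up, and `default_collate` from the module above is assumed to be in scope). It recurses through mappings, stacking tensor fields and tensorizing numeric fields:

```python
# Toy demo: a batch of per-sample dicts becomes one dict of batched tensors.
import torch

batch = [
    {"image": torch.zeros(3, 4, 4), "label": 0},
    {"image": torch.ones(3, 4, 4), "label": 1},
]
out = default_collate(batch)
print(out["image"].shape)  # torch.Size([2, 3, 4, 4])
print(out["label"])        # tensor([0, 1])
```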
actnn/actnn/_utils/fetch.py (new file, 47 lines)
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch | ||
data from an iterable-style or map-style dataset. This logic is shared in both | ||
single- and multi-processing data loading. | ||
""" | ||
|
||
|
||
class _BaseDatasetFetcher(object): | ||
def __init__(self, dataset, auto_collation, collate_fn, drop_last): | ||
self.dataset = dataset | ||
self.auto_collation = auto_collation | ||
self.collate_fn = collate_fn | ||
self.drop_last = drop_last | ||
|
||
def fetch(self, possibly_batched_index): | ||
raise NotImplementedError() | ||
|
||
|
||
class _IterableDatasetFetcher(_BaseDatasetFetcher): | ||
def __init__(self, dataset, auto_collation, collate_fn, drop_last): | ||
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) | ||
self.dataset_iter = iter(dataset) | ||
|
||
def fetch(self, possibly_batched_index): | ||
if self.auto_collation: | ||
data = [] | ||
for _ in possibly_batched_index: | ||
try: | ||
data.append(next(self.dataset_iter)) | ||
except StopIteration: | ||
break | ||
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)): | ||
raise StopIteration | ||
else: | ||
data = next(self.dataset_iter) | ||
return self.collate_fn(data) | ||
|
||
|
||
class _MapDatasetFetcher(_BaseDatasetFetcher): | ||
def __init__(self, dataset, auto_collation, collate_fn, drop_last): | ||
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) | ||
|
||
def fetch(self, possibly_batched_index): | ||
if self.auto_collation: | ||
data = [self.dataset[idx] for idx in possibly_batched_index] | ||
else: | ||
data = self.dataset[possibly_batched_index] | ||
return self.collate_fn(data) |
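A quick illustration of the map-style path (not part of the file; `_MapDatasetFetcher` from the module above is assumed in scope), with a plain list standing in for a dataset and `list` as a trivial collate function:

```python
# Toy demo: a fetcher pulls a batch of indices and hands it to collate_fn.
dataset = [(i, i * i) for i in range(10)]       # a list works as a map-style dataset
fetcher = _MapDatasetFetcher(dataset, auto_collation=True,
                             collate_fn=list, drop_last=False)
print(fetcher.fetch([0, 3, 7]))                 # [(0, 0), (3, 9), (7, 49)]
```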
actnn/actnn/_utils/pin_memory.py (new file, 59 lines)
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put | ||
fetched tensors into pinned memory. | ||
These **needs** to be in global scope since Py2 doesn't support serializing | ||
static methods. | ||
""" | ||
|
||
import torch | ||
from torch._six import queue, container_abcs, string_classes | ||
from . import MP_STATUS_CHECK_INTERVAL | ||
from torch._utils import ExceptionWrapper | ||
|
||
|
||
def _pin_memory_loop(in_queue, out_queue, device_id, done_event): | ||
# This setting is thread local, and prevents the copy in pin_memory from | ||
# consuming all CPU cores. | ||
torch.set_num_threads(1) | ||
|
||
torch.cuda.set_device(device_id) | ||
|
||
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the | ||
# logic of this function. | ||
while not done_event.is_set(): | ||
try: | ||
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) | ||
except queue.Empty: | ||
continue | ||
idx, data = r | ||
if not done_event.is_set() and not isinstance(data, ExceptionWrapper): | ||
try: | ||
data = pin_memory(data) | ||
except Exception: | ||
data = ExceptionWrapper( | ||
where="in pin memory thread for device {}".format(device_id)) | ||
r = (idx, data) | ||
while not done_event.is_set(): | ||
try: | ||
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) | ||
break | ||
except queue.Full: | ||
continue | ||
del r # save memory | ||
|
||
|
||
def pin_memory(data): | ||
if isinstance(data, torch.Tensor): | ||
return data.pin_memory() | ||
elif isinstance(data, string_classes): | ||
return data | ||
elif isinstance(data, container_abcs.Mapping): | ||
return {k: pin_memory(sample) for k, sample in data.items()} | ||
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple | ||
return type(data)(*(pin_memory(sample) for sample in data)) | ||
elif isinstance(data, container_abcs.Sequence): | ||
return [pin_memory(sample) for sample in data] | ||
elif hasattr(data, "pin_memory"): | ||
return data.pin_memory() | ||
else: | ||
return data |
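A short demo of the recursive `pin_memory` helper (not part of the file; `pin_memory` from the module above is assumed in scope). Pinning requires a CUDA-enabled PyTorch build, hence the guard:

```python
# Toy demo: pin_memory recurses through nested containers.
import torch

sample = {"image": torch.zeros(3, 4, 4), "meta": [torch.tensor([1, 2]), "id-0"]}
if torch.cuda.is_available():
    pinned = pin_memory(sample)
    print(pinned["image"].is_pinned())  # True: tensors get pinned
    print(pinned["meta"][1])            # 'id-0': strings pass through unchanged
```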