initial commit; move to the new repo
merrymercy committed Apr 22, 2021
0 parents, commit 673d93e
Showing 64 changed files with 10,287 additions and 0 deletions.
30 changes: 30 additions & 0 deletions .gitignore
@@ -0,0 +1,30 @@
# Results
results
*.log
*.npz
image_classification/results.tar.gz
image_classification/results
image_classification/results.tsv

# IDE & OS
.idea
.DS_Store

# Documents
*.pdf
*.png
*.jpg
*.pptx

# Python
*.pyc
__pycache__

# VIM
*.swp

# Build
actnn/build
actnn/dist
actnn/actnn.egg-info
actnn/actnn/cpp_extension/*.so
57 changes: 57 additions & 0 deletions README.md
@@ -0,0 +1,57 @@
# ActNN: Activation Compressed Training

## Install
- Requirements
```
torch>=1.7.1
torchvision>=0.8.2
```

- Build
```bash
cd actnn
pip install -v -e .
```

## Usage
[mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) is an example of using ActNN with models from torchvision.

### Basic Usage
- Step 1: Convert the model to use ActNN's layers.
```python
import actnn
model = actnn.QModule(model)
```

- Step 2: Configure the optimization level.
ActNN provides several optimization levels that control the trade-off between memory savings and computational overhead.
You can set the optimization level with
```python
# available choices are ["L0", "L1", "L2", "L3", "L4", "L5"]
actnn.set_optimization_level("L3")
```
See [set_optimization_level](actnn/actnn/conf.py) for more details.
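
Putting the two steps together, a minimal end-to-end sketch (the dataset, optimizer, and loss below are ordinary PyTorch placeholders, not part of ActNN's API):
```python
import torch
import torchvision
import actnn

# Step 1: convert the model to use ActNN's layers
model = torchvision.models.resnet50().cuda()
model = actnn.QModule(model)

# Step 2: configure the memory/speed trade-off
actnn.set_optimization_level("L3")

# Ordinary PyTorch training loop with placeholder data
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.FakeData(transform=torchvision.transforms.ToTensor()),
    batch_size=32)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images.cuda()), labels.cuda())
    loss.backward()
    optimizer.step()
```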

### Advanced Features
- (Optional) Change the data loader
To use per-sample gradient information for adaptive quantization,
the data loader has to return sample indices along with the data;
see `train_loader` in [mem_speed_benchmark/train.py](mem_speed_benchmark/train.py) for an example.
In addition, update the configuration:
```python
from actnn import config, QScheme
config.use_gradient = True
QScheme.num_samples = 1300000 # the size of the training set
```
You can find complete sample code in the script above; a minimal sketch of such a wrapper follows.
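
The snippet below is a sketch of one way to do this, not code from this repository; `IndexedDataset` is a hypothetical wrapper around a map-style dataset that returns the sample index along with each example:
```python
from torch.utils.data import Dataset

class IndexedDataset(Dataset):
    """Hypothetical wrapper that yields (data, label, index) tuples."""
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        data, label = self.dataset[index]
        # Returning the index lets adaptive quantization associate
        # per-sample gradient information with the right sample.
        return data, label, index

    def __len__(self):
        return len(self.dataset)
```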


## Image Classification
See [image_classification](image_classification/)

## Semantic Segmentation
Will be added later.

## Benchmark Memory Usage and Training Speed
See [mem_speed_benchmark](mem_speed_benchmark/)

10 changes: 10 additions & 0 deletions actnn/actnn/__init__.py
@@ -0,0 +1,10 @@
from . import dataloader
from . import ops
from .conf import config, set_optimization_level
from .dataloader import DataLoader
from .layers import QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d, \
QBatchNorm1d, QBatchNorm2d, QBatchNorm3d, QLinear, QReLU, QSyncBatchNorm, QMaxPool2d
from .module import QModule
from .qscheme import QScheme
from .qbnscheme import QBNScheme
from .utils import get_memory_usage, compute_tensor_bytes, exp_recorder
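
For orientation, the exported `Q*` layers are meant as drop-in replacements for their `torch.nn` counterparts (this is what `QModule` relies on when it swaps layers). A small sketch, assuming the constructors mirror `torch.nn`'s signatures:
```python
import torch
import torch.nn as nn
import actnn

# A block built from ActNN layers directly, as an alternative to
# wrapping a whole model with actnn.QModule.
block = nn.Sequential(
    actnn.QConv2d(3, 64, kernel_size=3, padding=1),
    actnn.QBatchNorm2d(64),
    actnn.QReLU(),
).cuda()
out = block(torch.randn(8, 3, 32, 32).cuda())  # activations are stored compressed
```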
45 changes: 45 additions & 0 deletions actnn/actnn/_utils/__init__.py
@@ -0,0 +1,45 @@
r"""Utility classes & functions for data loading. Code in this folder is mostly
used by ../dataloder.py.
A lot of multiprocessing is used in data loading, which only supports running
functions defined in global environment (py2 can't serialize static methods).
Therefore, for code tidiness we put these functions into different files in this
folder.
"""

import sys
import atexit

# old private location of the ExceptionWrapper that some users rely on:
from torch._utils import ExceptionWrapper


IS_WINDOWS = sys.platform == "win32"


MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""


python_exit_status = False
r"""Whether Python is shutting down. This flag is guaranteed to be set before
the Python core library resources are freed, but Python may already be exiting
for some time when this is set.
Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
hook in Python 3.7 multiprocessing library:
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
"""


def _set_python_exit_flag():
global python_exit_status
python_exit_status = True

atexit.register(_set_python_exit_flag)


from . import worker, signal_handling, pin_memory, collate, fetch
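
For context, loops elsewhere in this package consult `python_exit_status` roughly as follows; this is a sketch of the pattern, not a copy of that code:
```python
def _sketch_worker_step(work):
    """Illustrative: bail out quietly if the interpreter is shutting down."""
    try:
        return work()
    except Exception:
        if python_exit_status:  # set by the atexit hook above
            return None  # Python is exiting; don't raise from a dying worker
        raise
```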
86 changes: 86 additions & 0 deletions actnn/actnn/_utils/collate.py
@@ -0,0 +1,86 @@
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import re
from torch._six import container_abcs, string_classes, int_classes

np_str_obj_array_pattern = re.compile(r'[SaUO]')


def default_convert(data):
r"""Converts each NumPy array data field into a tensor"""
elem_type = type(data)
if isinstance(data, torch.Tensor):
return data
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
        # arrays of strings and objects cannot be converted; return them as-is
if elem_type.__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(data.dtype.str) is not None:
return data
return torch.as_tensor(data)
elif isinstance(data, container_abcs.Mapping):
return {key: default_convert(data[key]) for key in data}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return elem_type(*(default_convert(d) for d in data))
elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
return [default_convert(d) for d in data]
else:
return data


default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")


def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""

elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
            # arrays of strings and objects are not supported
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))

return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int_classes):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]

raise TypeError(default_collate_err_msg_format.format(elem_type))
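
To make the recursion concrete, here is a small illustrative call to the `default_collate` defined above (values are made up):
```python
import torch

batch = [
    {"image": torch.zeros(3, 2, 2), "label": 0},
    {"image": torch.ones(3, 2, 2), "label": 1},
]
out = default_collate(batch)
# out["image"] has shape (2, 3, 2, 2); out["label"] == tensor([0, 1])
```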
47 changes: 47 additions & 0 deletions actnn/actnn/_utils/fetch.py
@@ -0,0 +1,47 @@
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""


class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last

def fetch(self, possibly_batched_index):
raise NotImplementedError()


class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)

def fetch(self, possibly_batched_index):
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)


class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
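
As an illustration of the map-style fetcher above (a toy dataset; with `auto_collation=True` the fetcher receives a whole batch of indices at once):
```python
dataset = ["a", "b", "c", "d"]
fetcher = _MapDatasetFetcher(dataset, auto_collation=True,
                             collate_fn=list, drop_last=False)
print(fetcher.fetch([0, 2]))  # ['a', 'c']: one __getitem__ per index, then collated
```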
59 changes: 59 additions & 0 deletions actnn/actnn/_utils/pin_memory.py
@@ -0,0 +1,59 @@
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
fetched tensors into pinned memory.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
from torch._six import queue, container_abcs, string_classes
from . import MP_STATUS_CHECK_INTERVAL
from torch._utils import ExceptionWrapper


def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
torch.set_num_threads(1)

torch.cuda.set_device(device_id)

# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
while not done_event.is_set():
try:
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
idx, data = r
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
try:
data = pin_memory(data)
except Exception:
data = ExceptionWrapper(
where="in pin memory thread for device {}".format(device_id))
r = (idx, data)
while not done_event.is_set():
try:
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
break
except queue.Full:
continue
del r # save memory


def pin_memory(data):
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, container_abcs.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, container_abcs.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
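
For reference, a tiny illustrative call to the `pin_memory` function above (requires a CUDA-enabled build, since pinning needs page-locked host memory):
```python
import torch

sample = {"image": torch.zeros(3, 2, 2), "id": "sample-0"}
pinned = pin_memory(sample)
# pinned["image"].is_pinned() is True; the string is returned unchanged
```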