Commit fe98027 (1 parent: c3e40f4)
Showing 9 changed files with 1,116 additions and 0 deletions.
@@ -0,0 +1,24 @@
__pycache__/
dist/
docs/_build/
bin/
.ipynb_checkpoints/
Untitled*.ipynb
data/
*.egg-info
.coverage
*~
*#
*#*
.*
*.pyc
*.pkl
*.gz
*.log
*.c
*.so
*.py~
*.pt
*.pth
*.pickle
*.mat
@@ -0,0 +1,122 @@
# Fine tuning lib

Implements some parameter-efficient fine-tuning techniques for neural networks in PyTorch.

Started out as a refactoring attempt for [PEFT](https://github.com/huggingface/peft), now lives in this cave.

## Initializing from a config file

This is similar to how PEFT does it.

```python
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 300)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(300, 1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.lin1(X)
        return X


X, y = get_data()
# only train the LoRA weights for layer 'lin0'
config = LoraConfig(target_modules="lin0")
model = AdapterWrapper.from_config(MLP(), config)
# use whatever training method or custom train function you like
train(model, X, y)
```

Train an additional layer, but without LoRA:

```python
# 'lin1' is trained normally
config = LoraConfig(target_modules="lin0", modules_to_save=["lin1"])
model = AdapterWrapper.from_config(MLP(), config)
train(model, X, y)
```

## Mixing adapters

These things are not possible in PEFT.

### Mixing LoRA layers with different settings

```python
model = AdapterWrapper(MLP())
model.add_adapter()

# pass the uninitialized layer to add_adapter_layer
lin0_lora = partial(LinearLoraLayer, r=4)
model.add_adapter_layer("lin0", lin0_lora)
lin1_lora = partial(LinearLoraLayer, r=16)
model.add_adapter_layer("lin1", lin1_lora)
```

### Mixing different types of adapters

For instance, mix LoRA and IA³:

```python
model = AdapterWrapper(MLP())
model.add_adapter()

lin_lora = partial(LinearLoraLayer, r=4)
model.add_adapter_layer("lin0", lin_lora)
lin_ia3 = LinearIA3Layer
model.add_adapter_layer("lin1", lin_ia3)
```

### Custom adapters

```python
class MyLinearLayer(AdapterLayer):
    ...


model = AdapterWrapper(MLP())
model.add_adapter()
model.add_adapter_layer("lin0", MyLinearLayer)
```
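
A custom layer fills in the hooks that the `AdapterLayer` base class (included further down in this commit) leaves unimplemented: `reset_params`, `reset_device`, `reset_requires_grad`, `merge`, and `unmerge`, plus optionally `_pre_forward`/`_post_forward`. The scalar output gate below is only an illustrative sketch, not something the library ships:

```python
import torch
from torch import nn


class MyLinearLayer(AdapterLayer):
    """Toy adapter: learn a single scalar gate on the wrapped layer's output."""

    def reset_params(self) -> None:
        # one extra trainable parameter; the base weights stay untouched
        self.gate = nn.Parameter(torch.ones(1))

    def reset_device(self) -> None:
        self.to(self.base_module.weight.device)

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.gate.requires_grad_(True)

    def merge(self) -> None:
        if self.merged:
            return
        # folding an output gate into nn.Linear scales both weight and bias
        self.base_module.weight.data *= self.gate.data
        if self.base_module.bias is not None:
            self.base_module.bias.data *= self.gate.data
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return
        self.base_module.weight.data /= self.gate.data
        if self.base_module.bias is not None:
            self.base_module.bias.data /= self.gate.data
        self.merged = False

    def _post_forward(self, output, *args, **kwargs):
        if self.merged or not self.active:
            return output
        return output * self.gate
```

The base class already wraps `forward` around `_pre_forward` and `_post_forward`, so only the pieces above need to be provided.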

## Utilities

This is the same as in PEFT:

```python
# create a model with the "default" adapter
config = ...
model = AdapterWrapper.from_config(MLP(), config, adapter_name="default")

# create a new adapter
model.add_adapter("other-adapter")

# add a new adapter layer
model.add_adapter_layer("lin1", LinearLoraLayer)

# delete an adapter layer
model.delete_adapter_layer("lin1")

# switch to a different adapter
model.set_adapter("default")

# merging
model.merge_adapter()

# undo the merge
model.unmerge_adapter()

# return the base model
base_model = model.unload()

# merge the adapter into the base model and return it
base_model_merged = model.merge_and_unload()
```

## Status

Not seriously maintained; don't use this in production.
Empty file.
@@ -0,0 +1,107 @@
from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import Any

from torch import nn


@dataclass
class AdapterConfig:
    target_modules: str | list[str] | None = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with LoRA. "
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'"
        },
    )

    modules_to_save: list[str] = field(
        default_factory=list,
        metadata={
            "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` is randomly initialized and as such needs to be trainable and saved."
        },
    )

class AdapterLayer(nn.Module):
    def __init__(self, base_module: nn.Module) -> None:
        super().__init__()
        self.base_module = base_module  # TODO rename to base_layer

        self.active: bool = True
        self.merged: bool = False
        self.reset_params()
        self.reset_device()

    @classmethod
    def from_config(cls, config: AdapterConfig, base_module: nn.Module) -> AdapterLayer:
        raise NotImplementedError

    def set_active(self, active: bool) -> None:
        self.active = active

    def reset_params(self) -> None:
        raise NotImplementedError

    def reset_device(self) -> None:
        raise NotImplementedError

    def reset_requires_grad(self) -> None:
        raise NotImplementedError

    def merge(self) -> None:
        raise NotImplementedError

    def unmerge(self) -> None:
        raise NotImplementedError

    def _pre_forward(self, *args, **kwargs):
        return args, kwargs

    def forward(self, *args, **kwargs) -> Any:
        args, kwargs = self._pre_forward(*args, **kwargs)
        output = self.base_module(*args, **kwargs)
        return self._post_forward(output, *args, **kwargs)

    def _post_forward(self, output, *args, **kwargs):
        return output

class ModulesToSaveWrapper(AdapterLayer):
    def reset_params(self) -> None:
        self.new_module = copy.deepcopy(self.base_module)

    def reset_device(self) -> None:
        pass

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.new_module.requires_grad_(True)

    def merge(self) -> None:
        if self.merged:
            return

        self.base_module, self.new_module = self.new_module, self.base_module
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return

        self.base_module, self.new_module = self.new_module, self.base_module
        self.merged = False

    def forward(self, *args, **kwargs) -> Any:
        args, kwargs = self._pre_forward(*args, **kwargs)

        # merging swaps base_module and new_module, so this single check routes an
        # active adapter to the trainable copy and an inactive one to the original,
        # whether or not a merge has happened
        if self.active is self.merged:
            output = self.base_module(*args, **kwargs)
        else:
            output = self.new_module(*args, **kwargs)

        return self._post_forward(output, *args, **kwargs)
@@ -0,0 +1,68 @@
"""Function and classes that construct an adapter from a config""" | ||
|
||
from __future__ import annotations | ||
|
||
import re | ||
from typing import Any, Callable, Iterator | ||
|
||
from torch import nn | ||
|
||
from .base import AdapterConfig, AdapterLayer | ||
from .ia3 import IA3Config, LinearIA3Layer | ||
from .lora import EmbeddingLoraLayer, LinearLoraLayer, LoraConfig | ||
|
||
|
||
def _get_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Any: # TODO | ||
if isinstance(config.target_modules, str): | ||
return _regex_selection_strategy(config, base_model) | ||
if isinstance(config.target_modules, list): | ||
return _list_match_selection_strategy(config, base_model) | ||
raise ValueError("TODO") | ||
|
||
|
||
def _regex_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Iterator[tuple[str, Any]]: | ||
assert isinstance(config.target_modules, str) | ||
for name, _ in base_model.named_modules(): | ||
if re.fullmatch(config.target_modules, name): | ||
yield name, None | ||
|
||
|
||
def _list_match_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Iterator[tuple[str, Any]]: | ||
assert isinstance(config.target_modules, list) | ||
assert isinstance(config.target_modules[0], str) | ||
for name, _ in base_model.named_modules(): | ||
if any(name.endswith(key) for key in config.target_modules): | ||
yield name, None | ||
|
||
|
||
def _get_adaptation_strategy( | ||
config: AdapterConfig, layer_specific_args: Any | None | ||
) -> Callable[[nn.Module, str], AdapterLayer]: | ||
# TODO: more complex strategies, e.g. allowing to provide user defined layers | ||
# as replacement layers, or even mixing things up, like: | ||
# - one is LoRA layer with r=8, another with r=16 | ||
# - one is LoRA layer, another is IA³ layer | ||
if layer_specific_args is None: | ||
return _OneToOneMappingStrategy(config) | ||
raise ValueError("TODO") | ||
|
||
|
||
class _OneToOneMappingStrategy: | ||
# TODO could be partial-ed function | ||
def __init__(self, config: AdapterConfig) -> None: | ||
self.config = config | ||
|
||
def __call__(self, base_model: nn.Module, name: str) -> AdapterLayer: | ||
layer = getattr(base_model, name) | ||
|
||
if isinstance(layer, nn.Linear): | ||
if isinstance(self.config, LoraConfig): | ||
return LinearLoraLayer.from_config(self.config, layer) | ||
if isinstance(self.config, IA3Config): | ||
return LinearIA3Layer.from_config(self.config, layer) | ||
|
||
if isinstance(layer, nn.Embedding): | ||
if isinstance(self.config, LoraConfig): | ||
return EmbeddingLoraLayer.from_config(self.config, layer) | ||
|
||
raise TypeError(f"Could not find a suitable adapter layer for {type(layer)} and config {type(self.config)}") |
@@ -0,0 +1,53 @@
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn

from .base import AdapterConfig, AdapterLayer


@dataclass
class IA3Config(AdapterConfig):
    pass


class LinearIA3Layer(AdapterLayer):
    @classmethod
    def from_config(cls, config: AdapterConfig, base_module: nn.Module) -> LinearIA3Layer:
        return cls(base_module)

    def reset_device(self) -> None:
        self.to(self.base_module.weight.device)  # type: ignore

    def reset_params(self) -> None:
        if not isinstance(self.base_module, nn.Linear):
            raise ValueError(f"{self.__class__.__name__} must be applied to an nn.Linear layer")
        self.ia3_weight = nn.Parameter(torch.ones_like(self.base_module.weight[:1]))

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.ia3_weight.requires_grad_(True)

    def _pre_forward(self, X, *args, **kwargs):
        if self.merged or not self.active:
            return (X,) + args, kwargs

        return (X * self.ia3_weight,) + args, kwargs

    def merge(self) -> None:
        if self.merged:
            return

        self.base_module.weight.data *= self.ia3_weight
        # self.base_module.weight.data = torch.mul(self.base_module.weight.data, self.ia3_weight.data)
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return

        self.base_module.weight.data /= self.ia3_weight
        # self.base_module.weight.data = torch.div(self.base_module.weight.data, self.ia3_weight.data + 1e-8)
        self.merged = False