Commit: Initial commit

BenjaminBossan committed Jul 24, 2023
1 parent c3e40f4 commit fe98027
Showing 9 changed files with 1,116 additions and 0 deletions.
24 changes: 24 additions & 0 deletions .gitignore
@@ -0,0 +1,24 @@
__pycache__/
dist/
docs/_build/
bin/
.ipynb_checkpoints/
Untitled*.ipynb
data/
*.egg-info
.coverage
*~
*#
*#*
.*
*.pyc
*.pkl
*.gz
*.log
*.c
*.so
*.py~
*.pt
*.pth
*.pickle
*.mat
122 changes: 122 additions & 0 deletions README.md
@@ -0,0 +1,122 @@
# Fine-tuning lib

Implements some parameter-efficient fine-tuning techniques for neural networks in PyTorch.

Started out as a refactoring attempt for [PEFT](https://github.com/huggingface/peft) and now lives in this cave.

## Initializing from a config file

This is similar to how PEFT does it.

```python
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 300)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(300, 1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.lin1(X)
        return X

X, y = get_data()
# only train the LoRA weights for layer 'lin0'
config = LoraConfig(target_modules="lin0")
model = AdapterWrapper.from_config(MLP(), config)
# use whatever training method or custom train function you like
train(model, X, y)
```
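The snippet elides imports and leaves `get_data` and `train` undefined. As a minimal sketch of what such helpers could look like (hypothetical, matching the toy `MLP` above, not part of the library):

```python
import torch
from torch import nn


def get_data(n=64):
    # toy regression data, purely for illustration
    X = torch.randn(n, 10)
    y = torch.randn(n, 1)
    return X, y


def train(model, X, y, lr=1e-2, steps=100):
    # only parameters left trainable by the adapter wrapper receive updates
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
```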

To train an additional layer fully, without LoRA, list it in `modules_to_save`:

```python
# 'lin1' is trained normally
config = LoraConfig(target_modules="lin0", modules_to_save=["lin1"])
model = AdapterWrapper.from_config(MLP(), config)
train(model, X, y)
```

## Mixing adapters

These things are not possible in PEFT.

### Mixing LoRA layers with different settings

```python
model = AdapterWrapper(MLP())
model.add_adapter()

# pass the uninitialized layer to add_adapter_layer
lin0_lora = partial(LinearLoraLayer, r=4)
model.add_adapter_layer("lin0", lin0_lora)
lin1_lora = partial(LinearLoraLayer, r=16)
model.add_adapter_layer("lin1", lin1_lora)
```

### Mixing different types of adapters

For instance, mix LoRA and IA³:

```python
model = AdapterWrapper(MLP())
model.add_adapter()

lin_lora = partial(LinearLoraLayer, r=4)
model.add_adapter_layer("lin0", lin_lora)
lin_ia3 = LinearIA3Layer
model.add_adapter_layer("lin1", lin_ia3)
```

### Custom adapters

```python
class MyLinearLayer(AdapterLayer):
    ...

model = AdapterWrapper(MLP())
model.add_adapter()
model.add_adapter_layer("lin0", MyLinearLayer)
```
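
A custom layer fills in the hooks that `AdapterLayer` declares in `src/fein/base.py` (`reset_params`, `reset_device`, `reset_requires_grad`, `merge`, `unmerge`, plus the optional `_pre_forward`/`_post_forward` pair). A minimal sketch, loosely modeled on `LinearIA3Layer`, assuming the package is importable as `fein` and the wrapped `nn.Linear` has a bias:

```python
import torch
from torch import nn

from fein.base import AdapterLayer


class MyLinearLayer(AdapterLayer):
    """Learns a per-output additive offset on top of a frozen nn.Linear."""

    def reset_params(self) -> None:
        self.offset = nn.Parameter(torch.zeros(self.base_module.out_features))

    def reset_device(self) -> None:
        self.to(self.base_module.weight.device)

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.offset.requires_grad_(True)

    def merge(self) -> None:
        if self.merged:
            return
        # assumes the wrapped layer was created with bias=True
        self.base_module.bias.data += self.offset.data
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return
        self.base_module.bias.data -= self.offset.data
        self.merged = False

    def _post_forward(self, output, *args, **kwargs):
        if self.merged or not self.active:
            return output
        return output + self.offset
```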

## Utilities

This is the same as in PEFT:

```python
# create a model with the "default" adapter
config = ...
model = AdapterWrapper.from_config(MLP(), config, adapter_name="default")

# create a new adapter
model.add_adapter("other-adapter")

# add new adapter layer
model.add_adapter_layer("lin1", LinearLoraLayer)

# delete adapter layer
model.delete_adapter_layer("lin1")

# switch to a different adapter
model.set_adapter("default")

# merging
model.merge_adapter()

# undo the merge
model.unmerge_adapter()

# return the base model
base_model = model.unload()

# merge adapter into base model and return it
base_model_merged = model.merge_and_unload()
```

## Status

Not seriously maintained; don't use this in prod.
Empty file added src/fein/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions src/fein/base.py
@@ -0,0 +1,107 @@
from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import Any

from torch import nn


@dataclass
class AdapterConfig:
    target_modules: str | list[str] | None = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with Lora. "
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'"
        },
    )

    modules_to_save: list[str] = field(
        default_factory=list,
        metadata={
            "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` is randomly initialized and as such needs to be trainable and saved."
        },
    )


class AdapterLayer(nn.Module):
    def __init__(self, base_module: nn.Module) -> None:
        super().__init__()
        self.base_module = base_module  # TODO rename to base_layer

        self.active: bool = True
        self.merged: bool = False
        self.reset_params()
        self.reset_device()

    @classmethod
    def from_config(cls, config: AdapterConfig, base_module: nn.Module) -> AdapterLayer:
        raise NotImplementedError

    def set_active(self, active: bool) -> None:
        self.active = active

    def reset_params(self) -> None:
        raise NotImplementedError

    def reset_device(self) -> None:
        raise NotImplementedError

    def reset_requires_grad(self) -> None:
        raise NotImplementedError

    def merge(self) -> None:
        raise NotImplementedError

    def unmerge(self) -> None:
        raise NotImplementedError

    def _pre_forward(self, *args, **kwargs):
        return args, kwargs

    def forward(self, *args, **kwargs) -> Any:
        args, kwargs = self._pre_forward(*args, **kwargs)
        output = self.base_module(*args, **kwargs)
        return self._post_forward(output, *args, **kwargs)

    def _post_forward(self, output, *args, **kwargs):
        return output


class ModulesToSaveWrapper(AdapterLayer):
    def reset_params(self) -> None:
        self.new_module = copy.deepcopy(self.base_module)

    def reset_device(self) -> None:
        pass

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.new_module.requires_grad_(True)

    def merge(self) -> None:
        if self.merged:
            return

        self.base_module, self.new_module = self.new_module, self.base_module
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return

        self.base_module, self.new_module = self.new_module, self.base_module
        self.merged = False

    def forward(self, *args, **kwargs) -> Any:
        args, kwargs = self._pre_forward(*args, **kwargs)

        # active == merged means the base_module slot currently holds the module we want to run
        if self.active is self.merged:
            output = self.base_module(*args, **kwargs)
        else:
            output = self.new_module(*args, **kwargs)

        return self._post_forward(output, *args, **kwargs)
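
A quick sketch of the swap behavior (illustrative only, not part of this commit, assuming the package is importable as `fein`): the wrapper runs a deep copy of the base module while unmerged and swaps the two modules in place on `merge()`.

```python
import torch
from torch import nn

from fein.base import ModulesToSaveWrapper

wrapped = ModulesToSaveWrapper(nn.Linear(4, 2))
wrapped.reset_requires_grad()  # freeze the original, train the copy

X = torch.randn(3, 4)
out_new = wrapped(X)           # active and not merged -> runs new_module

wrapped.merge()                # swaps new_module into the base_module slot
out_merged = wrapped(X)        # the same (copied) module answers after the merge
assert torch.allclose(out_new, out_merged)

wrapped.unmerge()              # swap back to the original layout
```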
68 changes: 68 additions & 0 deletions src/fein/construction.py
@@ -0,0 +1,68 @@
"""Function and classes that construct an adapter from a config"""

from __future__ import annotations

import re
from typing import Any, Callable, Iterator

from torch import nn

from .base import AdapterConfig, AdapterLayer
from .ia3 import IA3Config, LinearIA3Layer
from .lora import EmbeddingLoraLayer, LinearLoraLayer, LoraConfig


def _get_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Any:  # TODO
    if isinstance(config.target_modules, str):
        return _regex_selection_strategy(config, base_model)
    if isinstance(config.target_modules, list):
        return _list_match_selection_strategy(config, base_model)
    raise ValueError("TODO")


def _regex_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Iterator[tuple[str, Any]]:
    assert isinstance(config.target_modules, str)
    for name, _ in base_model.named_modules():
        if re.fullmatch(config.target_modules, name):
            yield name, None


def _list_match_selection_strategy(config: AdapterConfig, base_model: nn.Module) -> Iterator[tuple[str, Any]]:
    assert isinstance(config.target_modules, list)
    assert isinstance(config.target_modules[0], str)
    for name, _ in base_model.named_modules():
        if any(name.endswith(key) for key in config.target_modules):
            yield name, None


def _get_adaptation_strategy(
    config: AdapterConfig, layer_specific_args: Any | None
) -> Callable[[nn.Module, str], AdapterLayer]:
    # TODO: more complex strategies, e.g. allowing to provide user defined layers
    # as replacement layers, or even mixing things up, like:
    # - one is LoRA layer with r=8, another with r=16
    # - one is LoRA layer, another is IA³ layer
    if layer_specific_args is None:
        return _OneToOneMappingStrategy(config)
    raise ValueError("TODO")


class _OneToOneMappingStrategy:
    # TODO could be partial-ed function
    def __init__(self, config: AdapterConfig) -> None:
        self.config = config

    def __call__(self, base_model: nn.Module, name: str) -> AdapterLayer:
        layer = getattr(base_model, name)

        if isinstance(layer, nn.Linear):
            if isinstance(self.config, LoraConfig):
                return LinearLoraLayer.from_config(self.config, layer)
            if isinstance(self.config, IA3Config):
                return LinearIA3Layer.from_config(self.config, layer)

        if isinstance(layer, nn.Embedding):
            if isinstance(self.config, LoraConfig):
                return EmbeddingLoraLayer.from_config(self.config, layer)

        raise TypeError(f"Could not find a suitable adapter layer for {type(layer)} and config {type(self.config)}")
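
These helpers are module-private and the wrapper class that drives them is not part of this diff, but a rough sketch of how selection and adaptation could be composed (illustrative only, using the IA³ pieces visible in this commit and assuming the package is importable as `fein`):

```python
import torch
from torch import nn

from fein.construction import _get_adaptation_strategy, _get_selection_strategy
from fein.ia3 import IA3Config


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)

    def forward(self, X):
        return self.lin0(X)


model = Tiny()
config = IA3Config(target_modules="lin0")

# pick the matching module names, then build and swap in the adapter layers
selected = list(_get_selection_strategy(config, model))
for name, layer_args in selected:
    adapt = _get_adaptation_strategy(config, layer_args)
    setattr(model, name, adapt(model, name))

out = model(torch.randn(2, 10))  # forward now passes through the IA³-scaled linear
```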
53 changes: 53 additions & 0 deletions src/fein/ia3.py
@@ -0,0 +1,53 @@
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn

from .base import AdapterConfig, AdapterLayer


@dataclass
class IA3Config(AdapterConfig):
    pass


class LinearIA3Layer(AdapterLayer):
    @classmethod
    def from_config(cls, config: AdapterConfig, base_module: nn.Module) -> LinearIA3Layer:
        return cls(base_module)

    def reset_device(self) -> None:
        self.to(self.base_module.weight.device)  # type: ignore

    def reset_params(self) -> None:
        if not isinstance(self.base_module, nn.Linear):
            raise ValueError(f"{self.__class__.__name__} must be applied to an nn.Linear layer")
        self.ia3_weight = nn.Parameter(torch.ones_like(self.base_module.weight[:1]))

    def reset_requires_grad(self) -> None:
        self.base_module.requires_grad_(False)
        self.ia3_weight.requires_grad_(True)

    def _pre_forward(self, X, *args, **kwargs):
        if self.merged or not self.active:
            return (X,) + args, kwargs

        return (X * self.ia3_weight,) + args, kwargs

    def merge(self) -> None:
        if self.merged:
            return

        self.base_module.weight.data *= self.ia3_weight
        # self.base_module.weight.data = torch.mul(self.base_module.weight.data, self.ia3_weight.data)
        self.merged = True

    def unmerge(self) -> None:
        if not self.merged:
            return

        self.base_module.weight.data /= self.ia3_weight
        # self.base_module.weight.data = torch.div(self.base_module.weight.data, self.ia3_weight.data + 1e-8)
        self.merged = False
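
A small usage sketch of the layer on its own, including the merge/unmerge round trip (illustrative only, not part of this commit, assuming the package is importable as `fein`):

```python
import torch
from torch import nn

from fein.ia3 import IA3Config, LinearIA3Layer

base = nn.Linear(8, 4)
layer = LinearIA3Layer.from_config(IA3Config(), base)
layer.reset_requires_grad()  # only ia3_weight stays trainable

X = torch.randn(2, 8)
out = layer(X)               # input is scaled by ia3_weight before the linear

layer.merge()                # fold the scaling into base.weight
out_merged = layer(X)
assert torch.allclose(out, out_merged, atol=1e-6)

layer.unmerge()              # restore the original weight
```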