Skip to content

Commit

Permalink
allow setting torch dtype during training
Browse files Browse the repository at this point in the history
  • Loading branch information
cwallenwein committed Aug 15, 2024
1 parent 4f50609 commit 8fc4a4d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
20 changes: 18 additions & 2 deletions trainer/arguments.py
Original file line number Diff line number Diff line change
@@ -1,3 1,4 @@
import torch
from dataclasses import dataclass, fields


Expand All @@ -8,17 9,32 @@ class TrainingArguments:
device: str = "mps"
save_model_after_training: bool = True
with_wandb: bool = True

use_torch_compile: bool = True
use_gradient_clipping: bool = True
gradient_clipping_value: float = 0.5
model_dtype: str = "float32"

def __post_init__(self):
assert self.micro_batch_size % 8 == 0
assert self.macro_batch_size % self.micro_batch_size == 0
self.model_dtype: torch.dtype = parse_dtype(self.model_dtype)
self.gradient_accumulation_steps = self.macro_batch_size // self.micro_batch_size

@classmethod
def from_dict(cls, arguments: dict):
field_names = {field.name for field in fields(cls)}
return TrainingArguments(**{key: value for key, value in arguments.items() if key in field_names})
return TrainingArguments(**{key: value for key, value in arguments.items() if key in field_names})


def parse_dtype(dtype: str) -> torch.dtype:
dtype_mapping = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}

try:
return dtype_mapping[dtype]
except KeyError:
raise ValueError(f"Unsupported dtype: {dtype}")
2 changes: 1 addition & 1 deletion trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 56,7 @@ def train(
self.initialize_wandb(model.config, self.training_args)

# prepare model
model = model.to(self.device)
model = model.to(device=self.device, dtype=self.training_args.model_dtype)
if self.training_args.use_torch_compile and self.device != "mps":
model = torch.compile(model)
model.train()
Expand Down

0 comments on commit 8fc4a4d

Please sign in to comment.