
Commit

fix multi-epoch finetuning
cwallenwein committed Aug 14, 2024
1 parent b588dac commit 5c5dd6e
Showing 2 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion scripts/train_model.py
@@ -1,7 +1,7 @@
 import argparse
 from trainer import TrainerForPreTraining
 from trainer.arguments import TrainingArguments
-from model import BertModel, BertConfig
+from model.bert import BertModel, BertConfig
 
 
 def train(args):
11 changes: 5 additions & 6 deletions trainer/finetuning.py
@@ -13,7 +13,6 @@
 
 
 class TrainerForSequenceClassificationFinetuning:
-    # TODO: reduce the LR after 80% of the training
     def __init__(self, experiment_name: str, training_args: TrainingArguments, verbose: bool = True):
         self.training_args = training_args
         self.device = self.get_device(training_args.device)
@@ -50,6 +49,7 @@ def train(
 
         # prepare dataset
         dataset.set_format("torch", device=self.device)
+        steps_per_epoch = dataset_size // self.training_args.macro_batch_size
 
         # prepare optimizer
         optimizer = optim.Adam(
@@ -67,12 +67,12 @@ def train(
         )
 
         for epoch in tqdm(range(epochs)):
-            dataset = dataset.iter(batch_size=self.training_args.micro_batch_size)
+            dataset_for_epoch = dataset.iter(batch_size=self.training_args.micro_batch_size)
             for step in tqdm(range(steps)):
                 macro_batch_loss = 0.0
                 assert self.training_args.gradient_accumulation_steps > 0, "Gradient accumulation steps must be greater than 0"
                 for micro_step in range(self.training_args.gradient_accumulation_steps):
-                    micro_batch = next(dataset)
+                    micro_batch = next(dataset_for_epoch)
                     sequence_classification_output = model(**micro_batch)
 
                     # calculate loss
@@ -91,14 +91,14 @@ def train(
                         sequence_classification_output, micro_batch["labels"]
                     )
                     if self.training_args.with_wandb:
-                        wandb.log({"mnli": accuracy}, step=step)
+                        wandb.log({"mnli": accuracy}, step=(epoch * steps_per_epoch + step))
 
                 # log loss and lr
                 if self.training_args.with_wandb:
                     wandb.log({
                         "loss": macro_batch_loss,
                         "learning_rate": scheduler.get_last_lr()[0]
-                    }, step=step)
+                    }, step=epoch * steps_per_epoch + step)
 
                 # gradient accumulation
                 if self.training_args.use_gradient_clipping:
@@ -107,7 +107,6 @@ def train(
                 optimizer.zero_grad()
                 scheduler.step()
 
-                steps_per_epoch = dataset_size // self.training_args.macro_batch_size
                 progress = (epoch * steps_per_epoch + step) / (epochs * steps_per_epoch)
                 if progress >= 0.8 and not scheduler.decaying:
                     scheduler.start_decay(step, 0.8)
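The core of the fix is the step bookkeeping: steps_per_epoch is computed once before the epoch loop, each epoch draws micro-batches from a fresh iterator, and both the wandb logging calls and the 80% decay check use the global step epoch * steps_per_epoch + step instead of the per-epoch step, which previously reset to zero every epoch. Below is a minimal, self-contained sketch of that pattern; make_epoch_iterator, finetune, and the toy dataset are hypothetical stand-ins for illustration, not code from this repository.

def make_epoch_iterator(dataset, batch_size):
    """Yield micro-batches for one pass over the dataset (placeholder, not the repo's dataset.iter)."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]

def finetune(dataset, epochs, macro_batch_size, micro_batch_size, decay_at=0.8):
    grad_accum_steps = macro_batch_size // micro_batch_size
    # Computed once, before the epoch loop, so every epoch shares the same value.
    steps_per_epoch = len(dataset) // macro_batch_size
    total_steps = epochs * steps_per_epoch

    for epoch in range(epochs):
        # A fresh iterator per epoch; the exhausted iterator from the previous
        # epoch is never reused (the multi-epoch bug this commit addresses).
        epoch_iterator = make_epoch_iterator(dataset, micro_batch_size)
        for step in range(steps_per_epoch):
            for _ in range(grad_accum_steps):
                micro_batch = next(epoch_iterator)
                # ... forward/backward pass on micro_batch ...

            # The global step keeps increasing across epochs, so logged metrics
            # and the learning-rate decay trigger do not reset every epoch.
            global_step = epoch * steps_per_epoch + step
            progress = global_step / total_steps
            if progress >= decay_at:
                pass  # e.g. switch the scheduler into its decay phase here
            print(f"epoch={epoch} step={step} global_step={global_step} progress={progress:.2f}")

if __name__ == "__main__":
    toy_dataset = list(range(32))
    finetune(toy_dataset, epochs=2, macro_batch_size=8, micro_batch_size=4)

Running the sketch with 2 epochs over 32 examples prints 8 strictly increasing global steps, which is what keeps the logged metrics and the decay trigger consistent across epochs.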
