Commit
Move TE/UN loss calc to train script
kohya-ss committed Jan 21, 2023
1 parent 17089b1 commit 22ee0ac
Showing 2 changed files with 15 additions and 14 deletions.
12 changes: 0 additions & 12 deletions library/train_util.py
@@ -1423,17 +1423,5 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
     model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
     accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
 
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
-    logs = {"loss/current": current_loss, "loss/average": avr_loss}
-
-    if args.network_train_unet_only:
-        logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
-    elif args.network_train_text_encoder_only:
-        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    else:
-        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-        logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]
-
-    return logs
 
 # endregion
17 changes: 15 additions & 2 deletions train_network.py
@@ -21,6 +21,20 @@ def collate_fn(examples):
     return examples[0]
 
 
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+    logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+    if args.network_train_unet_only:
+        logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    elif args.network_train_text_encoder_only:
+        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    else:
+        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+        logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as the text encoder lr
+
+    return logs
+
+
 def train(args):
     session_id = random.randint(0, 2**32)
     training_started_at = time.time()
@@ -353,8 +367,7 @@ def train(args):
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
-                logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
-
+                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
                 accelerator.log(logs, step=global_step)
 
             if global_step >= args.max_train_steps:
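
For context, a minimal sketch of how the relocated helper behaves, assuming generate_step_logs as defined in the diff above; the FakeScheduler stub and the learning-rate values are illustrative, not part of the commit:

import argparse

class FakeScheduler:
    # Hypothetical stand-in for the real lr scheduler; only get_last_lr() is exercised here.
    def __init__(self, lrs):
        self.lrs = lrs

    def get_last_lr(self):
        return self.lrs

# Training both networks: the else branch logs index 0 as the text encoder lr
# and index -1 as the U-Net lr.
args = argparse.Namespace(network_train_unet_only=False, network_train_text_encoder_only=False)
scheduler = FakeScheduler([1e-4, 5e-5])
print(generate_step_logs(args, current_loss=0.12, avr_loss=0.10, lr_scheduler=scheduler))
# -> {'loss/current': 0.12, 'loss/average': 0.1, 'lr/textencoder': 0.0001, 'lr/unet': 5e-05}

The [0]/[-1] indexing presumably mirrors the order of the optimizer's param groups (text encoder first, then U-Net) when both are trained; with a single param group, both keys report the same value, as the inline comment notes.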
