
Add dropout options
forestsource committed Feb 6, 2023
1 parent d591891 commit 7db98ba
Showing 4 changed files with 43 additions and 5 deletions.
7 changes: 7 additions & 0 deletions fine_tune.py
@@ -171,6 +171,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

# set the dropout rate for the training data
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# prepare the lr scheduler
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -226,6 +230,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")

train_dataset.epoch_current = epoch + 1

for m in training_models:
m.train()
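
The same wiring recurs in train_db.py and train_network.py below: the parsed options are copied onto the dataset as plain attributes, and the epoch counter is pushed in at the top of each epoch. A minimal, self-contained sketch of that pattern (the dataset class and args here are stand-ins for the real ones, not code from this commit):

from types import SimpleNamespace

class TrainDataset:
    def __init__(self):
        # defaults mirror the attributes added to library/train_util.py
        self.epoch_current: int = 0
        self.dropout_rate: float = 0.0
        self.dropout_every_n_epochs: int = 0

args = SimpleNamespace(dropout_rate=0.1, dropout_every_n_epochs=0)  # stand-in for argparse output

train_dataset = TrainDataset()
# copy the parsed options onto the dataset, as the trainer scripts do
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

for epoch in range(3):  # stand-in for num_train_epochs
    train_dataset.epoch_current = epoch + 1  # keep the dataset in sync with the loop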

20 changes: 19 additions & 1 deletion library/train_util.py
@@ -223,6 +223,10 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to

self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

self.epoch_current: int = 0
self.dropout_rate: float = 0.0
self.dropout_every_n_epochs: int = 0

# augmentation
flip_p = 0.5 if flip_aug else 0.0
if color_aug:
@@ -598,7 +602,17 @@ def __getitem__(self, index):
images.append(image)
latents_list.append(latents)

# decide whether to drop out the caption
is_drop_out = False
if self.dropout_rate > 0 and random.random() < self.dropout_rate:  # drop with probability dropout_rate
    is_drop_out = True
if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0:
    is_drop_out = True

if is_drop_out:
    caption = ""
else:
    caption = self.process_caption(image_info.caption)
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
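
The decision itself is self-contained. To drop a caption with probability dropout_rate, the random draw must be compared as random.random() < dropout_rate; the sketch below (the helper name is illustrative, not part of the commit) mirrors the corrected logic from __getitem__ above:

import random

def should_drop_caption(dropout_rate: float, dropout_every_n_epochs: int, epoch_current: int) -> bool:
    # probabilistic dropout: drop with probability dropout_rate
    if dropout_rate > 0 and random.random() < dropout_rate:
        return True
    # periodic dropout: drop every caption on each N-th epoch
    if dropout_every_n_epochs > 0 and epoch_current % dropout_every_n_epochs == 0:
        return True
    return False

# with dropout_rate=0.1, roughly 10% of captions become "" over a long run
caption = "" if should_drop_caption(0.1, 0, 1) else "a photo of a cat"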
@@ -1407,6 +1421,10 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")

if support_dreambooth:
# DreamBooth dataset
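
Both options default to 0, which leaves dropout disabled. A small parsing sketch with an illustrative command line (the flag names, types, and defaults come from the additions above):

import argparse

parser = argparse.ArgumentParser()
# same types and defaults as the two options added by this commit
parser.add_argument("--dropout_rate", type=float, default=0)
parser.add_argument("--dropout_every_n_epochs", type=int, default=0)

# e.g. python fine_tune.py ... --dropout_rate 0.05 --dropout_every_n_epochs 3
args = parser.parse_args(["--dropout_rate", "0.05", "--dropout_every_n_epochs", "3"])
assert args.dropout_rate == 0.05
assert args.dropout_every_n_epochs == 3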
6 changes: 6 additions & 0 deletions train_db.py
@@ -136,6 +136,10 @@ def train(args):
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

# set the dropout rate for the training data
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -204,6 +208,8 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")

train_dataset.epoch_current = epoch + 1

# train the Text Encoder up to the specified number of steps: state at the start of each epoch
unet.train()
# train==True is required to enable gradient_checkpointing
15 changes: 11 additions & 4 deletions train_network.py
@@ -120,16 +120,16 @@ def train(args):
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
train_dataset.make_buckets()
@@ -219,6 +219,10 @@ def train(args):
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

# set the dropout rate for the training data
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -376,6 +380,9 @@ def train(args):

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")

train_dataset.epoch_current = epoch + 1

metadata["ss_epoch"] = str(epoch+1)

network.on_epoch_start(text_encoder, unet)
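
One caveat worth noting: all three scripts create their DataLoader with persistent_workers=args.persistent_data_loader_workers, and a persistent worker keeps the copy of the dataset it was started with, so per-epoch updates to epoch_current made in the main process may never reach __getitem__ when persistent workers are enabled with num_workers > 0. A minimal sketch of the effect (assumes PyTorch; the class name is illustrative):

from torch.utils.data import DataLoader, Dataset

class EpochAwareDataset(Dataset):
    def __init__(self):
        self.epoch_current = 0

    def __len__(self):
        return 2

    def __getitem__(self, index):
        # the value actually observed during loading, possibly in a worker process
        return self.epoch_current

if __name__ == "__main__":  # guard needed where workers are spawned
    dataset = EpochAwareDataset()
    loader = DataLoader(dataset, num_workers=1, persistent_workers=True)
    for epoch in range(1, 3):
        dataset.epoch_current = epoch  # updates the main-process copy only
        # the persistent worker still holds the copy from startup, so the
        # second epoch prints the first epoch's value
        print(epoch, [int(x) for x in loader])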
