conditional caption dropout (in progress)
kohya-ss committed Feb 7, 2023
commit e42b2f7 (1 parent: f9478f0)
Showing 5 changed files with 40 additions and 28 deletions.
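
In short: the dataset now carries a caption dropout rate and an optional every-N-epochs schedule, set via a new set_caption_dropout() method, and each __getitem__ call decides whether to drop that example's caption. A minimal standalone sketch of the decision, using the names from the diff below (drop_caption is a hypothetical helper; in the actual code the logic is inlined in BaseDataset.__getitem__):

    import random

    def drop_caption(dropout_rate: float, dropout_every_n_epochs, epoch_current: int) -> bool:
        # random dropout: drop this example's caption with probability dropout_rate
        is_drop_out = dropout_rate > 0 and random.random() < dropout_rate
        # scheduled dropout: drop every caption in epochs that are a multiple of N (None disables this)
        if dropout_every_n_epochs and epoch_current % dropout_every_n_epochs == 0:
            is_drop_out = True
        return is_drop_out

When the result is True, the example is trained with its caption dropped (an empty caption).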
fine_tune.py: 10 changes (5 additions, 5 deletions)

@@ -36,6 +36,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)

+ # set the caption dropout rates for the training data
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)

train_dataset.make_buckets()

if args.debug_dataset:
@@ -171,10 +175,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

- # set the caption dropout rates for the training data
- train_dataset.dropout_rate = args.dropout_rate
- train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# prepare the lr scheduler
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -339,7 +339,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
parser = argparse.ArgumentParser()

train_util.add_sd_models_arguments(parser)
- train_util.add_dataset_arguments(parser, False, True)
+ train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)

library/train_util.py: 35 changes (23 additions, 12 deletions)

@@ -113,7 +113,7 @@ def set_predefined_resos(self, resos):
# store resolution and aspect ratio info used when choosing from the predefined sizes
self.predefined_resos = resos.copy()
self.predefined_resos_set = set(resos)
- self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
+ self.predefined_aspect_ratios = np.array([w / h for w, h in resos])

def add_if_new_reso(self, reso):
if reso not in self.reso_to_id:
@@ -135,7 +135,7 @@ def select_bucket(self, image_width, image_height):
if reso in self.predefined_resos_set:
pass
else:
- ar_errors = self.predifined_aspect_ratios - aspect_ratio
+ ar_errors = self.predefined_aspect_ratios - aspect_ratio
predefined_bucket_id = np.abs(ar_errors).argmin()  # the resolution with the smallest aspect ratio error among the others
reso = self.predefined_resos[predefined_bucket_id]

@@ -223,9 +223,10 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to

self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

- self.epoch_current:int = int(0)
- self.dropout_rate:float = 0
- self.dropout_every_n_epochs:int = 0
+ # TODO passing these in from outside would be safer, but computing them automatically spares the callers extra code, so this seems fine
+ self.epoch_current: int = int(0)
+ self.dropout_rate: float = 0
+ self.dropout_every_n_epochs: int = None

# augmentation
flip_p = 0.5 if flip_aug else 0.0
@@ -251,6 +252,12 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to

self.replacements = {}

+ def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
+   # exposed as a method because tag dropout should be supported in the future as well
+   # it is not passed via the constructor so that Textual Inversion does not have to care about it (that is the official story, anyway)
+   self.dropout_rate = dropout_rate
+   self.dropout_every_n_epochs = dropout_every_n_epochs

def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
@@ -604,9 +611,9 @@ def __getitem__(self, index):

# decide whether to apply dropout
is_drop_out = False
- if self.dropout_rate > 0 and self.dropout_rate < random.random() :
+ if self.dropout_rate > 0 and random.random() < self.dropout_rate:
is_drop_out = True
- if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0 :
+ if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
is_drop_out = True

if is_drop_out:
@@ -1391,7 +1398,7 @@ def verify_training_args(args: argparse.Namespace):
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")


- def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
+ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true",
@@ -1421,10 +1428,14 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")

if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
parser.add_argument("--caption_dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")

if support_dreambooth:
# DreamBooth dataset
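With support_caption_dropout enabled, the new options can be passed on the command line. An illustrative invocation (values chosen for the example; all other required arguments omitted):

    python fine_tune.py --caption_dropout_rate=0.1 --caption_dropout_every_n_epochs=3 [other arguments]

Here each caption is dropped with 10% probability per fetch, and in epochs where the epoch counter is a multiple of 3 every caption is dropped.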
train_db.py: 11 changes (6 additions, 5 deletions)

@@ -38,8 +38,13 @@ def train(args):
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)

if args.no_token_padding:
train_dataset.disable_token_padding()

+ # set the caption dropout rates for the training data
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)

train_dataset.make_buckets()

if args.debug_dataset:
@@ -136,10 +141,6 @@ def train(args):
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

- # set the caption dropout rates for the training data
- train_dataset.dropout_rate = args.dropout_rate
- train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -333,7 +334,7 @@ def train(args):
parser = argparse.ArgumentParser()

train_util.add_sd_models_arguments(parser)
- train_util.add_dataset_arguments(parser, True, False)
+ train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)

train_network.py: 10 changes (5 additions, 5 deletions)

@@ -132,6 +132,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)

+ # set the caption dropout rates for the training data
+ train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)

train_dataset.make_buckets()

if args.debug_dataset:
@@ -219,10 +223,6 @@ def train(args):
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

- # set the caption dropout rates for the training data
- train_dataset.dropout_rate = args.dropout_rate
- train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs

# compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -516,7 +516,7 @@ def remove_old_func(old_epoch_no):
parser = argparse.ArgumentParser()

train_util.add_sd_models_arguments(parser)
- train_util.add_dataset_arguments(parser, True, True)
+ train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)

parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
train_textual_inversion.py: 2 changes (1 addition, 1 deletion)

@@ -478,7 +478,7 @@ def load_weights(file):
parser = argparse.ArgumentParser()

train_util.add_sd_models_arguments(parser)
- train_util.add_dataset_arguments(parser, True, True)
+ train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)

parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
