Skip to content

Commit

Permalink
Merge branch "dev" into dev_device_support
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Feb 12, 2024
2 parents e579648 + 35c6053 commit 358ca20
Show file tree
Hide file tree
Showing 62 changed files with 1,385 additions and 991 deletions.
29 changes: 21 additions & 8 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
Expand All @@ -35,6 +42,7 @@
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)

cache_latents = args.cache_latents

Expand All @@ -47,11 +55,11 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
Expand Down Expand Up @@ -84,7 +92,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
Expand All @@ -95,7 +103,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)

# mixed precisionに対応した型を用意しておき適宜castする
Expand Down Expand Up @@ -219,7 +227,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)

# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
Expand Down Expand Up @@ -283,7 +293,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
Expand Down Expand Up @@ -457,12 +467,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
Expand All @@ -471,7 +482,9 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",
Expand Down
6 changes: 5 additions & 1 deletion finetune/blip/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

class BLIP_Base(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
del state_dict[key]

msg = model.load_state_dict(state_dict,strict=False)
print("load checkpoint from %s"%url_or_filename)
logger.info("load checkpoint from %s"%url_or_filename)
return model,msg

42 changes: 23 additions & 19 deletions finetune/clean_captions_and_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import re

from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
Expand Down Expand Up @@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
tokens = tags.split(", rating")
if len(tokens) == 1:
# WD14 taggerのときはこちらになるのでメッセージは出さない
# print("no rating:")
# print(f"{image_key} {tags}")
# logger.info("no rating:")
# logger.info(f"{image_key} {tags}")
pass
else:
if len(tokens) > 2:
print("multiple ratings:")
print(f"{image_key} {tags}")
logger.info("multiple ratings:")
logger.info(f"{image_key} {tags}")
tags = tokens[0]

tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
Expand Down Expand Up @@ -124,43 +128,43 @@ def clean_caption(caption):

def main(args):
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print("no metadata / メタデータファイルがありません")
logger.error("no metadata / メタデータファイルがありません")
return

print("cleaning captions and tags.")
logger.info("cleaning captions and tags.")
image_keys = list(metadata.keys())
for image_key in tqdm(image_keys):
tags = metadata[image_key].get('tags')
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
logger.info("FROM: " + org)
logger.info("TO: " + tags)

caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
logger.info("FROM: " + org)
logger.info("TO: " + caption)

# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand All @@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser:

args, unknown = parser.parse_known_args()
if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
print("All captions and tags in the metadata are processed.")
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
print("メタデータ内のすべてのキャプションとタグが処理されます。")
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
logger.warning("All captions and tags in the metadata are processed.")
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
args.in_json = args.out_json
args.out_json = unknown[0]
elif len(unknown) > 0:
Expand Down
22 changes: 13 additions & 9 deletions finetune/make_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

DEVICE = get_preferred_device()

Expand Down Expand Up @@ -51,7 +55,7 @@ def __getitem__(self, idx):
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None

return (tensor, img_path)
Expand All @@ -78,21 +82,21 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path

cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
logger.info(f"Current Working Directory is: {cwd}")
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)

print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")

print(f"loading BLIP caption: {args.caption_weights}")
logger.info(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
logger.info("BLIP loaded")

# captioningする
def run_batch(path_imgs):
Expand All @@ -112,7 +116,7 @@ def run_batch(path_imgs):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")

# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
Expand Down Expand Up @@ -142,7 +146,7 @@ def run_batch(path_imgs):
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, img_tensor))
Expand All @@ -152,7 +156,7 @@ def run_batch(path_imgs):
if len(b_imgs) > 0:
run_batch(b_imgs)

print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
24 changes: 14 additions & 10 deletions finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
from transformers.generation.utils import GenerationMixin

import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

DEVICE = get_preferred_device()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
Expand All @@ -38,8 +42,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps

Expand Down Expand Up @@ -73,16 +77,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""

print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")

# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")

# captioningする
def run_batch(path_imgs):
Expand All @@ -100,7 +104,7 @@ def run_batch(path_imgs):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")

# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
Expand Down Expand Up @@ -129,7 +133,7 @@ def run_batch(path_imgs):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, image))
Expand All @@ -140,7 +144,7 @@ def run_batch(path_imgs):
if len(b_imgs) > 0:
run_batch(b_imgs)

print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
Loading

0 comments on commit 358ca20

Please sign in to comment.