Commit

Add train_inpainting to Native Finetuning, DreamBooth, LoRA and even Textual Inversion lol
Fannovel16 committed Feb 10, 2023
1 parent 6b790ba commit 45c0864
Showing 5 changed files with 139 additions and 7 deletions.
19 changes: 19 additions & 0 deletions fine_tune.py
@@ -245,6 +245,22 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

if batch["masks"] is not None:
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215

masks = batch["masks"]
# Resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
b_size = latents.shape[0]

with torch.set_grad_enabled(args.train_text_encoder):
@@ -263,6 +279,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if batch["masks"] is not None:
# Concatenate the noised latents with the mask and the masked latents
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)

# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
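Note: the concatenation in this hunk grows the UNet input from 4 latent channels to 4 + 1 + 4 = 9 (noised latents, downsampled mask, masked-image latents), which is the layout Stable Diffusion inpainting checkpoints expect. A minimal shape sketch, with batch size and resolution assumed rather than taken from the commit:

import torch

b, res = 2, 512                  # assumed batch size and training resolution
latent_res = res // 8            # the VAE downsamples by a factor of 8
noisy_latents = torch.randn(b, 4, latent_res, latent_res)
mask = torch.rand(b, 1, latent_res, latent_res)
masked_latents = torch.randn(b, 4, latent_res, latent_res)

unet_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
print(unet_input.shape)          # torch.Size([2, 9, 64, 64])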
69 changes: 62 additions & 7 deletions library/train_util.py
@@ -22,7 +22,7 @@
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
import cv2
from einops import rearrange
from torch import einsum
@@ -195,7 +195,7 @@ class BucketBatchIndex(NamedTuple):


class BaseDataset(torch.utils.data.Dataset):
-def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
+def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, train_inpainting: bool, debug_dataset: bool) -> None:
super().__init__()
self.tokenizer: CLIPTokenizer = tokenizer
self.max_token_length = max_token_length
@@ -226,6 +226,7 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.train_inpainting = train_inpainting

# augmentation
flip_p = 0.5 if flip_aug else 0.0
@@ -564,6 +565,47 @@ def crop_target(self, image, face_cx, face_cy, face_w, face_h):

return image

@staticmethod
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)

masked_image = image * (mask < 0.5)

return mask, masked_image


# generate random masks
@staticmethod
def random_mask(im_shape, ratio=1, mask_full_image=False):
mask = Image.new("L", im_shape, 0)
draw = ImageDraw.Draw(mask)
size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
# use this to always mask the whole image
if mask_full_image:
size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
draw_type = random.randint(0, 1)
if draw_type == 0 or mask_full_image:
draw.rectangle(
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
fill=255,
)
else:
draw.ellipse(
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
fill=255,
)

return mask

def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
if npz_file is None:
@@ -586,6 +628,8 @@ def __getitem__(self, index):
input_ids_list = []
latents_list = []
images = []
masks = []
masked_images = []

for image_key in bucket[image_index:image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
@@ -625,10 +669,18 @@ def __getitem__(self, index):
if self.aug is not None:
img = self.aug(image=img)['image']

if self.train_inpainting:
pil_image = transforms.functional.to_pil_image(img)
mask = self.random_mask(pil_image.size, 1, False)
mask, masked_image = self.prepare_mask_and_masked_image(pil_image, mask)

masks.append(mask)
masked_images.append(masked_image)

latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる

images.append(image)
latents_list.append(latents)

caption = self.process_caption(image_info.caption)
@@ -652,6 +704,8 @@ def __getitem__(self, index):
else:
images = None
example['images'] = images
example['masks'] = torch.stack(masks) if len(masks) > 0 else None
example['masked_images'] = torch.stack(masked_images) if len(masked_images) > 0 else None

example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None

@@ -662,9 +716,9 @@ def __getitem__(self, index):


class DreamBoothDataset(BaseDataset):
-def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
+def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, train_inpainting, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
+resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, train_inpainting, debug_dataset)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

@@ -794,10 +848,11 @@ def load_dreambooth_dir(dir):
self.num_reg_images = num_reg_images



class FineTuningDataset(BaseDataset):
-def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
+def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, train_inpainting, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
+resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, train_inpainting, debug_dataset)

# メタデータを読み込む
if os.path.exists(json_file_name):
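Note: a hedged usage sketch of the two BaseDataset helpers added above (the file name and the 512x512 size are illustrative assumptions, not from the commit). prepare_mask_and_masked_image binarizes a grayscale mask at 0.5 and blanks the masked region of the image; random_mask draws a single random rectangle or ellipse:

from PIL import Image

mask = BaseDataset.random_mask((512, 512), ratio=1, mask_full_image=False)
image = Image.open("sample.png")  # hypothetical 512x512 RGB training image
mask_t, masked_t = BaseDataset.prepare_mask_and_masked_image(image, mask)
# mask_t:   (1, 1, 512, 512) float tensor of {0, 1}
# masked_t: (1, 3, 512, 512) float tensor in [-1, 1], zeroed where mask_t == 1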
19 changes: 19 additions & 0 deletions train_db.py
@@ -234,6 +234,22 @@ def train(args):
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

if batch["masks"] is not None:
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215

masks = batch["masks"]
# Resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
@@ -251,6 +267,9 @@ def train(args):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if batch["masks"] is not None:
# Concatenate the noised latents with the mask and the masked latents
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)

# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
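Note: the interpolate call in this hunk shrinks each pixel-space mask to the latent grid; torch.nn.functional.interpolate defaults to nearest-neighbor, so the mask stays binary. Stacking B such (1, 1, 64, 64) tensors yields (B, 1, 1, 64, 64), and the subsequent reshape collapses that to the (B, 1, 64, 64) that is concatenated to the latents. A small sketch under an assumed 512px resolution:

import torch
import torch.nn.functional as F

mask = (torch.rand(1, 1, 512, 512) > 0.5).float()       # one binary mask, NCHW
small = F.interpolate(mask, size=(512 // 8, 512 // 8))  # nearest-neighbor by default
print(small.shape)     # torch.Size([1, 1, 64, 64]) -- matches the latent grid
print(small.unique())  # tensor([0., 1.]) -- still binary after resizing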
20 changes: 20 additions & 0 deletions train_network.py
@@ -396,6 +396,23 @@ def train(args):
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

if batch["masks"] is not None:
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215

masks = batch["masks"]
# Resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

b_size = latents.shape[0]

with torch.set_grad_enabled(train_text_encoder):
@@ -413,6 +430,9 @@ def train(args):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if batch["masks"] is not None:
# Concatenate the noised latents with the mask and the masked latents
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)

# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
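Note (an inference by this reading, not stated in the commit): a 9-channel input only matches a UNet whose conv_in was built for inpainting, so the base checkpoint loaded for any of these trainers should be an inpainting model. A hedged sketch for verifying this with diffusers, model id assumed:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="unet"
)
print(unet.config.in_channels)  # 9 for inpainting UNets, 4 for standard SD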
19 changes: 19 additions & 0 deletions train_textual_inversion.py
@@ -311,6 +311,22 @@ def train(args):
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

if batch["masks"] is not None:
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215

masks = batch["masks"]
# Resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
b_size = latents.shape[0]

# Get the text embedding for conditioning
@@ -328,6 +344,9 @@ def train(args):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if batch["masks"] is not None:
# Concatenate the noised latents with the mask and the masked latents
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)

# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
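Note on the masked_images reshape that opens each of these hunks (shapes assumed for a batch of 4 at 512px): prepare_mask_and_masked_image returns tensors with an extra leading axis, so stacking B of them in __getitem__ produces (B, 1, 3, H, W); reshape(batch["images"].shape) collapses that back to the (B, 3, H, W) layout the VAE encoder expects:

import torch

masked = [torch.randn(1, 3, 512, 512) for _ in range(4)]  # per-sample outputs
stacked = torch.stack(masked)           # shape (4, 1, 3, 512, 512)
flat = stacked.reshape(4, 3, 512, 512)  # matches batch["images"].shape
print(flat.shape)                       # torch.Size([4, 3, 512, 512])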
