Skip to content

Commit

Permalink
Fix gradients synchronization for multi-GPUs training (kohya-ss#989)
Browse files Browse the repository at this point in the history
* delete DDP wrapper

* fix train_db vae and train_network

* fix train_db vae and train_network unwrap

* network grad sync

---------

Co-authored-by: Kohya S <[email protected]>
  • Loading branch information
Isotr0py and kohya-ss authored Dec 7, 2023
1 parent 72bbaac commit db84530
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 78 deletions.
3 changes: 0 additions & 3 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
Expand Down
2 changes: 0 additions & 2 deletions library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
torch.cuda.empty_cache()
accelerator.wait_for_everyone()

text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


Expand Down
13 changes: 0 additions & 13 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3914,17 +3914,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
return text_encoder, vae, unet, load_stable_diffusion_format


# TODO remove this function in the future
def transform_if_model_is_DDP(text_encoder, unet, network=None):
# Transform text_encoder, unet and network from DistributedDataParallel
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


def transform_models_if_DDP(models):
# Transform text_encoder, unet and network from DistributedDataParallel
return [model.module if type(model) == DDP else model for model in models if model is not None]


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
Expand All @@ -3948,8 +3937,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
torch.cuda.empty_cache()
accelerator.wait_for_everyone()

text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)

return text_encoder, vae, unet, load_stable_diffusion_format


Expand Down
3 changes: 0 additions & 3 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
(unet,) = train_util.transform_models_if_DDP([unet])
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

Expand Down
3 changes: 0 additions & 3 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,6 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
Expand Down
3 changes: 0 additions & 3 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,6 @@ def train(args):
)
network: control_net_lllite.ControlNetLLLite

# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
Expand Down
11 changes: 7 additions & 4 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def train(args):

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
Expand All @@ -136,7 +137,7 @@ def train(args):

# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
Expand Down Expand Up @@ -225,9 +226,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error

Expand Down Expand Up @@ -484,6 +482,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)

return parser

Expand Down
63 changes: 23 additions & 40 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

try:
import intel_extension_for_pytorch as ipex
Expand Down Expand Up @@ -127,6 +128,11 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred

def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")

def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)

Expand Down Expand Up @@ -390,47 +396,23 @@ def train(self, args):

# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]

unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)

# transform DDP after prepare (train_network here only)
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])

if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
Expand All @@ -451,7 +433,7 @@ def train(self, args):

del t_enc

network.prepare_grad_etc(text_encoder, unet)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)

if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
Expand Down Expand Up @@ -714,8 +696,8 @@ def train(self, args):
del train_dataset_group

# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None

Expand Down Expand Up @@ -749,10 +731,10 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)

# For --sample_at_first
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
network.on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
Expand Down Expand Up @@ -825,16 +807,17 @@ def remove_model(old_ckpt_name):
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
Expand Down
4 changes: 0 additions & 4 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,11 @@ def train(self, args):
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)

elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)

text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]

Expand Down
3 changes: 0 additions & 3 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,6 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)

# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
Expand Down

0 comments on commit db84530

Please sign in to comment.