dev SDXL:multi-GPUs train #994

Closed

weiyutang opened this issue Dec 8, 2023 · 26 comments

@weiyutang
Traceback (most recent call last):
  File "/app/wyt_train/sd-scripts/sdxl_train.py", line 777, in <module>
    train(args)
  File "/app/wyt_train/sd-scripts/sdxl_train.py", line 508, in train
    encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
  File "/app/wyt_train/sd-scripts/library/train_util.py", line 4066, in get_hidden_states_sdxl
    pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
  File "/app/wyt_train/sd-scripts/library/train_util.py", line 4036, in pool_workaround
    pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
  File "/root/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'text_projection'

@weiyutang changed the title from "SDXL:multi-GPUs train" to "dev SDXL:multi-GPUs train" on Dec 8, 2023
@Isotr0py
Contributor

Isotr0py commented Dec 9, 2023

It seems that text_encoder wasn't unwrapped when calling train_util.get_hidden_states_sdxl()

Maybe you need to change text_encoder1 to accelerator.unwrap_model(text_encoder1) and text_encoder2 to accelerator.unwrap_model(text_encoder2) in the code below:

sd-scripts/sdxl_train.py

Lines 506 to 515 in 0908c54

encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
    args.max_token_length,
    input_ids1,
    input_ids2,
    tokenizer1,
    tokenizer2,
    text_encoder1,
    text_encoder2,
    None if not args.full_fp16 else weight_dtype,
)
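
(A minimal sketch of that suggestion — the same call as above with only the two encoder arguments unwrapped; untested, as noted below:)

encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
    args.max_token_length,
    input_ids1,
    input_ids2,
    tokenizer1,
    tokenizer2,
    accelerator.unwrap_model(text_encoder1),  # unwrap so pool_workaround() can reach .text_projection
    accelerator.unwrap_model(text_encoder2),
    None if not args.full_fp16 else weight_dtype,
)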

I can"t test this because I don"t have enough VRAM to fine tune the SDXL model. 😢

@kohya-ss
Owner

kohya-ss commented Dec 9, 2023

Thank you for opening this issue.

@Isotr0py

It seems that text_encoder wasn't unwrapped when calling train_util.get_hidden_states_sdxl()

Maybe you need to change text_encoder1 to accelerator.unwrap_model(text_encoder1) and text_encoder2 to accelerator.unwrap_model(text_encoder2) in the code below:

I think it may stop the gradient synchronization for the text encoders when they are trained. What do you think?

If my guess is correct, one idea would be to modify the workaround for the pooled output: make the EOS token have the largest token id even when new tokens are added, so that the original CLIP pooling works, instead of the current implementation, which gets the embeddings from the original EOS token id.
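
(Editorial note: a rough sketch of the two pooling strategies being discussed, assuming input_ids of shape [batch, seq] and last_hidden_state of shape [batch, seq, dim]; this is illustrative, not the actual pool_workaround implementation.)

import torch

# Original CLIP pooling: assumes the EOS token has the largest token id, so argmax
# over the ids finds its position. This breaks when newly added tokens get larger ids.
def pool_like_original_clip(last_hidden_state, input_ids):
    batch_idx = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    eos_pos = input_ids.to(torch.int).argmax(dim=-1)
    return last_hidden_state[batch_idx, eos_pos]

# Current workaround: look up the first position whose id equals the known eos_token_id,
# which still works when new tokens with larger ids have been added to the tokenizer.
def pool_by_eos_token_id(last_hidden_state, input_ids, eos_token_id):
    batch_idx = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    eos_pos = (input_ids == eos_token_id).int().argmax(dim=-1)
    return last_hidden_state[batch_idx, eos_pos]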

@Isotr0py
Contributor

Isotr0py commented Dec 9, 2023

Oh! The text encoders are forwarded in train_util.get_hidden_states_sdxl(), so unwrapping the text encoders before passing them in will break gradient synchronization.

However, I think that if we only unwrap the text encoder when applying the text projection, it won't affect gradient synchronization, because the text projection is applied after the DDP forward.

So maybe we can solve this like:

text_projection = accelerator.unwrap_model(text_encoder).text_projection
pooled_output = text_projection(pooled_output.to(text_projection.weight.dtype))

or

    pool2 = pool_workaround(
        accelerator.unwrap_model(text_encoder2), 
        enc_out["last_hidden_state"], 
        input_ids2, 
        tokenizer2.eos_token_id
    )

@Isotr0py
Contributor

Isotr0py commented Dec 9, 2023

I made a small test to simulate the workflow:

import time
import torch
import torch.nn as nn
from accelerate import Accelerator, notebook_launcher


seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,16, kernel_size=3, padding=1)
        self.linear = nn.Linear(16,16) # simulate text projection using a linear layer
        self.conv2 = nn.Conv2d(16,3, kernel_size=3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = x.permute(0, 2, 3, 1)
        x = self.linear(x)
        x = self.act(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv2(x)
        return x


def main():
    accelerator = Accelerator()

    model = TestNet()
    
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    model, optimizer = accelerator.prepare(model, optimizer)

    # forward pass
    global_step = 0
    with accelerator.accumulate(model):
        for i in range(5):
            outputs = model(torch.rand(1,3,64,64).to(accelerator.device))
            labels = torch.rand_like(outputs)

            # hijack hidden state
            hidden_linear = accelerator.unwrap_model(model).linear
            hidden_outputs = hidden_linear(torch.rand(1,64,64,16).to(accelerator.device))
            hidden_labels = torch.rand_like(hidden_outputs)

            # backward pass and do something to hidden state
            loss = loss_fn(outputs, labels)
            loss += loss_fn(hidden_outputs, hidden_labels)
            accelerator.backward(loss)

            # update parameters
            print(
                f"""
                [Rank {accelerator.device}] 
                conv1 grad: {accelerator.unwrap_model(model).conv1.weight.grad.mean():3f} 
                linear grad: {accelerator.unwrap_model(model).linear.weight.grad.mean():3f} 
                conv2 grad: {accelerator.unwrap_model(model).conv2.weight.grad.mean():3f} 
                [Global step: {global_step}]
                """
            )
            global_step += 1
            optimizer.step()


notebook_launcher(main, num_processes=2)

And outputs:

[Rank cuda:1] conv1 grad: 0.000249 linear grad: -0.000115 conv2 grad: -0.027779 [Global step: 0]
[Rank cuda:0] conv1 grad: 0.000249 linear grad: -0.000115 conv2 grad: -0.027779 [Global step: 0]
[Rank cuda:1] conv1 grad: -0.000046 linear grad: -0.000561 conv2 grad: -0.054896 [Global step: 1]
[Rank cuda:0] conv1 grad: -0.000046 linear grad: -0.000561 conv2 grad: -0.054896 [Global step: 1]
[Rank cuda:1] conv1 grad: -0.000895 linear grad: -0.001345 conv2 grad: -0.081134 [Global step: 2]
[Rank cuda:0] conv1 grad: -0.000895 linear grad: -0.001345 conv2 grad: -0.081134 [Global step: 2]
[Rank cuda:0] conv1 grad: -0.002218 linear grad: -0.002447 conv2 grad: -0.106917 [Global step: 3]
[Rank cuda:1] conv1 grad: -0.002218 linear grad: -0.002447 conv2 grad: -0.106917 [Global step: 3]
[Rank cuda:0] conv1 grad: -0.003995 linear grad: -0.003858 conv2 grad: -0.131941 [Global step: 4]
[Rank cuda:1] conv1 grad: -0.003995 linear grad: -0.003858 conv2 grad: -0.131941 [Global step: 4]

So this seems to confirm my guess: unwrapping the text encoder when applying the text projection won't stop gradient synchronization.

@kohya-ss
Owner

kohya-ss commented Dec 9, 2023

Thank you for the quick reply and for verifying the issue.

However, I think that if we only unwrap the text encoder when applying the text projection, it won't affect gradient synchronization, because the text projection is applied after the DDP forward.

I understand. That's fine!

I will add one of your suggestions to the dev branch soon.

@kohya-ss
Owner

I think I fixed the dev branch (but did not test in a multi-GPU environment).

@weiyutang Could you please test with the latest version?

@weiyutang
Author

weiyutang commented Dec 10, 2023

@kohya-ss @Isotr0py A new problem has arisen: 40GB of GPU memory cannot train even with batch_size=1. Maybe text encoder 2 does not use gradient checkpointing, or is there some other reason?
[screenshot]

@kohya-ss
Owner

Gradient checkpointing seems to be enabled for both Text Encoders.

sd-scripts/sdxl_train.py

Lines 273 to 278 in 4a2cef8

if args.train_text_encoder:
    # TODO each option for two text encoders?
    accelerator.print("enable text encoder training")
    if args.gradient_checkpointing:
        text_encoder1.gradient_checkpointing_enable()
        text_encoder2.gradient_checkpointing_enable()

I think mixed_precision and/or an 8-bit optimizer may help (because the number of parameters is so large).

@weiyutang
Author

weiyutang commented Dec 10, 2023

The train args look right. Image resolution = 1024.
nohup accelerate launch --num_cpu_threads_per_process 4 sdxl_train.py \
  --pretrained_model_name_or_path="${pretrained_model}" \
  --dataset_config=/app/sd-scripts/a.toml \
  --output_dir="${output_dir}" \
  --logging_dir="${logging_dir}" \
  --output_name="${output_name}" \
  --bucket_reso_steps=32 \
  --save_model_as=safetensors \
  --vae_batch_size=12 \
  --max_train_epochs=20 \
  --save_every_n_epochs=1 \
  --learning_rate=2e-06 \
  --train_text_encoder \
  --learning_rate_te1=4e-7 \
  --learning_rate_te2=4e-7 \
  --optimizer_type="PagedAdamW8bit" \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --lr_scheduler_num_cycles=1 \
  --optimizer_args "weight_decay=0.01" \
  --max_grad_norm=1.0 \
  --mixed_precision="bf16" \
  --save_precision="bf16" \
  --cache_latents_to_disk \
  --sdpa \
  --cache_latents \
  --gradient_checkpointing \
  --ddp_timeout=300 \
  --ip_noise_gamma=0.1 \
  --seed=1026 &

@kohya-ss
Owner

Could you verify the batch size again?

In addition, I can train both Text Encoders and the U-Net with the --full_bf16 option with 24GB of VRAM. So could you please try specifying that option and see what happens?
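
(Sketch only — the flag is simply appended to the launch command quoted above, e.g.:)

accelerate launch --num_cpu_threads_per_process 4 sdxl_train.py \
  --full_bf16 \
  [...all other options unchanged...]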

@weiyutang
Author

weiyutang commented Dec 10, 2023

i"m sure batch_size=1. single gpu batch_size=16 can run, multi gpu batch_size=1 can"t run

@Isotr0py
Contributor

Isotr0py commented Dec 10, 2023

I tested with training only text_encoder2 on two T4 GPUs, and it seems that gradient sync and checkpointing work normally:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------ 
                                     record_param_comms         0.00%      95.000us         0.01%     138.000us     138.000us     453.578ms        27.51%     453.578ms     453.578ms             1  
ncclKernel_AllReduce_RING_LL_Sum_half(ncclDevComm*, ...         0.00%       0.000us         0.00%       0.000us       0.000us     453.578ms        27.51%     453.578ms     453.578ms             1  

                                                    ...

autograd::engine::evaluate_function: CheckpointFunct...         0.11%       2.268ms        75.54%        1.510s      12.278ms       0.000us         0.00%        1.314s      10.681ms           123  
                             CheckpointFunctionBackward         8.25%     164.944ms        75.33%        1.506s      12.245ms       0.000us         0.00%        1.312s      10.665ms           123  

                                                    ...

                               aten::embedding_backward         0.00%       7.000us         0.04%     857.000us     428.500us       0.000us         0.00%     571.000us     285.500us             2  
                                       c10d::allreduce_         0.00%      16.000us         0.01%     154.000us     154.000us       0.000us         0.00%     453.578ms     453.578ms             1  
                                    cudaStreamWaitEvent         0.00%      15.000us         0.00%      15.000us       3.000us       0.000us         0.00%       0.000us       0.000us             5  
                                        nccl:all_reduce         0.00%       0.000us             0     114.000us     114.000us       0.000us         0.00%       0.000us       0.000us             1  
                                                INVALID         0.00%       1.000us         0.00%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us             1  
     torch.distributed.ddp.reducer::copy_bucket_to_grad         0.03%     660.000us         0.31%       6.132ms      11.861us       0.000us         0.00%      12.621ms      24.412us           517  
                                  cudaDeviceSynchronize        23.03%     460.457ms        23.03%     460.457ms     460.457ms       0.000us         0.00%       0.000us       0.000us             1 
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.999s
Self CUDA time total: 1.649s

It"s strange that DDP training has a larger VRAM usage than single GPU.

@weiyutang
Author

@Isotr0py When unet, te1, and te2 are all trained, VRAM will explode even with batch_size=1, PagedAdamW8bit, gradient checkpointing, and bf16. I don't know why.

@Isotr0py
Contributor

Emmm, this is beyond me. Can you disable te1/te2 training to see what happens (e.g. by dropping --train_text_encoder from the launch command)?

If the cause is te2, disabling te2 training should prevent the VRAM explosion.

@weiyutang
Author

weiyutang commented Dec 10, 2023

@Isotr0py
[screenshot]
Only training the unet with batch_size=1 still takes up too much VRAM!

@kohya-ss
Owner

@Isotr0py
Do we need to specify all the models, optimizers, etc. in a single call to prepare, like the following?

t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)

@Isotr0py
Contributor

Isotr0py commented Dec 10, 2023

No, according to the source code, prepare just wraps the objects iteratively, so I think it's fine to call it in multiple steps.
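
(A minimal illustration of that claim, assuming plain DDP rather than DeepSpeed/FSDP — the two forms below behave the same:)

# single call
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# equivalent multiple calls: each object is wrapped independently
model = accelerator.prepare(model)
optimizer, dataloader = accelerator.prepare(optimizer, dataloader)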

@weiyutang Can you replace accelerator.backward(loss) with the code below to see what goes wrong on the GPU?

                with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                ) as p:
                    accelerator.backward(loss)
                print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

I guess it"s because unet"s grad bucket reduce or grad checkpoint failed with some reasons.
If my guess is right, torch.distributed.ddp.reducer::copy_bucket_to_grad or CheckpointFunctionBackward will not be printed in the table.

@Isotr0py
Contributor

Isotr0py commented Dec 10, 2023

Referring to the PyTorch discussion forum, maybe the significantly increased VRAM is expected behavior of DDP, because the SDXL model is so large.

So according to the tips given in the discussion, maybe we can reduce the memory peak like this:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])

@weiyutang
Author

No, according to the source code, prepare just wraps the objects iteratively, so I think it's fine to call it in multiple steps.

@weiyutang Can you replace accelerator.backward(loss) with the code below to see what goes wrong on the GPU?

                with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                ) as p:
                    accelerator.backward(loss)
                print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

I guess it"s because unet"s grad bucket reduce or grad checkpoint failed with some reasons. If my guess is right, torch.distributed.ddp.reducer::copy_bucket_to_grad or CheckpointFunctionBackward will not be printed in the table.

[screenshot]
It looks right.

@weiyutang
Author

Referring to the PyTorch discussion forum, maybe the significantly increased VRAM is expected behavior of DDP, because the SDXL model is so large.

So according to the tips given in the discussion, maybe we can reduce the memory peak like this:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])

This can reduce the memory usage.

@weiyutang
Author

weiyutang commented Dec 11, 2023

[screenshot]
A new issue... >_< ... Such errors are reported when training the text encoders.

@Isotr0py
Contributor

It seems that it"s a conflict between DDP and gradient checkpoint in text_encoders (maybe the te2), because model is expected to forward only once in DDP training for checkpoint.

According to the tip, you can try to add static_graph=True to DistributedDataParallelKwargs to see if it would work.

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True, static_graph=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])

@weiyutang
Author

weiyutang commented Dec 11, 2023

It seems that it"s a conflict between DDP and gradient checkpoint in text_encoders (maybe the te2), because model is expected to forward only once in DDP training for checkpoint.

According to the tip, you can try to add static_graph=True to DistributedDataParallelKwargs to see if it would work.

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True, static_graph=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])

It doesn't seem to work:
[screenshot]
[Rank cuda:0]
te1 grad: -0.000541687011718750000000000000
[Global step: 62]
[Rank cuda:1]
te1 grad: 0.000000000005684341886080801487
[Global step: 62]
The gradient sync is not right. ...>_<..

@Isotr0py
Contributor

Isotr0py commented Dec 11, 2023

No, there's no need to unwrap te1 and te2 anymore. That will break grad sync when they forward in train_util.get_hidden_states_sdxl.

Since the text_encoder.text_projection issue has been fixed in the latest version, we just need to pass te1 and te2 normally.
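
(For clarity, a rough sketch of the intended flow after the fix, following the earlier suggestion in this thread — not the exact dev-branch code:)

# The (possibly DDP-wrapped) text encoder is forwarded as-is, so gradient sync still happens.
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)

# Only the text_projection access is unwrapped, after the DDP forward has finished.
pool2 = pool_workaround(
    accelerator.unwrap_model(text_encoder2),
    enc_out["last_hidden_state"],
    input_ids2,
    tokenizer2.eos_token_id,
)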

@weiyutang
Author

@Isotr0py Adding this makes it run!

text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)  # freeze te1's last transformer block
text_encoder1.text_model.final_layer_norm.requires_grad_(False)    # freeze te1's final layer norm

Thanks. Training SDXL with all three (unet + te1 + te2) will still explode 40GB of memory! I tried unet + te1 only, and that can run.

Thanks very much for your help.

@weiyutang
Author

@kohya-ss Thanks very much for your help
