dev SDXL:multi-GPUs train #994
It seems that the text encoders are wrapped by DistributedDataParallel on multi-GPU, so text_projection is no longer reachable through the wrapper. Maybe you need to change the referenced code (Lines 506 to 515 in 0908c54).
I can't test this myself because I don't have enough VRAM to fine-tune the SDXL model. 😢 |
Thank you for opening this issue.
I think it may stop the gradient synchronization for the text encoders when they are trained. What do you think? If my guess is correct, one idea would be to modify the workaround for the pooled output: make the EOS token keep the largest token id even when new tokens are added, so that the original CLIP pooling still works, instead of the current implementation, which looks up the embedding at the original EOS token id. |
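A minimal sketch of the two pooling strategies being contrasted here (illustrative only, not the repository's pool_workaround):

import torch

def pool_by_argmax(last_hidden_state, input_ids):
    # Original CLIP pooling: take the hidden state at the position of the
    # largest token id, which is the EOS position only while no added token
    # has an id larger than EOS.
    batch_idx = torch.arange(last_hidden_state.shape[0])
    return last_hidden_state[batch_idx, input_ids.argmax(dim=-1)]

def pool_by_eos_id(last_hidden_state, input_ids, eos_token_id):
    # Workaround currently discussed in this thread: locate the first occurrence
    # of the original EOS id explicitly, so added tokens with larger ids are ignored.
    batch_idx = torch.arange(last_hidden_state.shape[0])
    eos_pos = (input_ids == eos_token_id).int().argmax(dim=-1)
    return last_hidden_state[batch_idx, eos_pos]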
Oh! The text encoders are forwarded in get_hidden_states_sdxl. However, I think that if we only unwrap the text encoder when applying the text projection, it won't affect gradient synchronization, because the text projection is applied after the DDP forward. So maybe we can solve this like:

text_projection = accelerator.unwrap_model(text_encoder).text_projection
pooled_output = text_projection(pooled_output.to(text_projection.weight.dtype))

or

pool2 = pool_workaround(
    accelerator.unwrap_model(text_encoder2),
    enc_out["last_hidden_state"],
    input_ids2,
    tokenizer2.eos_token_id,
) |
I made a small test to simulate the workflow:

import time

import torch
import torch.nn as nn
from accelerate import Accelerator, notebook_launcher

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.linear = nn.Linear(16, 16)  # simulate text projection using a linear layer
        self.conv2 = nn.Conv2d(16, 3, kernel_size=3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = x.permute(0, 2, 3, 1)
        x = self.linear(x)
        x = self.act(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv2(x)
        return x


def main():
    accelerator = Accelerator()
    model = TestNet()
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    model, optimizer = accelerator.prepare(model, optimizer)

    # forward pass
    global_step = 0
    with accelerator.accumulate(model):
        for i in range(5):
            outputs = model(torch.rand(1, 3, 64, 64).to(accelerator.device))
            labels = torch.rand_like(outputs)

            # hijack hidden state
            hidden_linear = accelerator.unwrap_model(model).linear
            hidden_outputs = hidden_linear(torch.rand(1, 64, 64, 16).to(accelerator.device))
            hidden_labels = torch.rand_like(hidden_outputs)

            # backward pass and do something to hidden state
            loss = loss_fn(outputs, labels)
            loss += loss_fn(hidden_outputs, hidden_labels)
            accelerator.backward(loss)

            # update parameters
            print(
                f"""
                [Rank {accelerator.device}]
                conv1 grad: {accelerator.unwrap_model(model).conv1.weight.grad.mean():3f}
                linear grad: {accelerator.unwrap_model(model).linear.weight.grad.mean():3f}
                conv2 grad: {accelerator.unwrap_model(model).conv2.weight.grad.mean():3f}
                [Global step: {global_step}]
                """
            )
            global_step += 1
            optimizer.step()


notebook_launcher(main, num_processes=2)

And the outputs:
So it seems that this proves my guess: unwrapping the text encoder only when applying the text projection won't stop the gradient synchronization. |
Thank you for the quick reply and for verifying the issue.
I understand. That's fine! I will add one of your suggestions to the dev branch soon. |
I think I fixed the dev branch (but I did not test it in a multi-GPU environment). @weiyutang Could you please test with the latest version? |
Gradient checkpointing seems to be enabled for both Text Encoders. Lines 273 to 278 in 4a2cef8
I think mixed_precision and/or 8-bit optimizers (because the number of parameters is so large) may help. |
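A minimal sketch of that suggestion (assuming bitsandbytes is installed; the nn.Linear is just a stand-in for the modules being trained, not sd-scripts' actual setup):

import bitsandbytes as bnb
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")   # autocast activations to bf16
model = nn.Linear(16, 16)                           # stand-in for unet / text encoders
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-5)  # 8-bit optimizer state
model, optimizer = accelerator.prepare(model, optimizer)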
The train args look right. Image resolution=1024. |
Could you verify the batch size again? In addition, I can train both Text Encoders and U-Net with the --full_bf16 option on 24GB VRAM, so could you please try specifying that option and see what happens? |
i"m sure batch_size=1. single gpu batch_size=16 can run, multi gpu batch_size=1 can"t run |
I tested training only text_encoder2 on two T4 GPUs, and it seems that gradient sync and checkpointing work normally:
It's strange that DDP training has larger VRAM usage than a single GPU. |
@Isotr0py When unet, te1, and te2 are all trained, the VRAM explodes even with batch_size=1, PagedAdamW8bit, gradient checkpointing, and bf16. I don't know why. |
Emmm, this is beyond me. Can you disable te1/te2 training to see what happens? If the cause is te2, disabling te2 training should prevent the VRAM explosion. |
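For reference, a minimal sketch of what disabling te2 training means at the PyTorch level (the nn.Linear modules are stand-ins; sd-scripts controls this through its own training arguments):

import torch
import torch.nn as nn

unet, text_encoder1, text_encoder2 = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

# Freeze te2 so it contributes no gradients and no optimizer state:
text_encoder2.requires_grad_(False)
text_encoder2.eval()

# Build the optimizer only from parameters that still require gradients (unet + te1):
trainable = [p for m in (unet, text_encoder1) for p in m.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5)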
@Isotr0py |
@Isotr0py |
No, according to the source code, that shouldn't be the cause. @weiyutang Can you replace accelerator.backward(loss) with the following and post the profiler output?

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
) as p:
    accelerator.backward(loss)
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

I guess it's because unet's grad bucket reduce or gradient checkpointing failed for some reason. |
Referring to the PyTorch discussion forum, maybe the significantly increased VRAM is expected behavior of DDP, because the SDXL model is so large. So, following the tips given in that discussion, maybe we can reduce the memory peak like this:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True)
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
looks right |
can reduce memory |
It seems that it's a conflict between DDP and gradient checkpointing in the text encoders (maybe te2), because with checkpointing the model is expected to be forwarded only once per step in DDP training. According to the tip, you can try adding:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True, static_graph=True)
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
No, there is no need to unwrap te1 and te2 anymore. That would break grad sync when they are forwarded in get_hidden_states_sdxl. Since the dev branch now unwraps only when applying the text projection, the wrapped encoders should be passed in as-is. |
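To make that concrete, an illustrative sketch reusing the names from the earlier snippet (not the actual repo code):

# Forward through the prepared (DDP-wrapped) text_encoder2 so DDP can set up
# gradient synchronization for this step:
enc_out = text_encoder2(input_ids2, output_hidden_states=True)

# Unwrapping only afterwards, to reach text_projection inside pool_workaround,
# does not disturb that synchronization:
pool2 = pool_workaround(
    accelerator.unwrap_model(text_encoder2),
    enc_out["last_hidden_state"],
    input_ids2,
    tokenizer2.eos_token_id,
)

# By contrast, forwarding through the unwrapped module instead of text_encoder2
# would bypass DDP's forward hooks and break gradient sync across ranks.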
@Isotr0py After adding this, it can run!!!!!! Thanks. Training SDXL with all of unet+te1+te2 will explode 40GB of memory!!!! Thanks very much for your help |
@kohya-ss Thanks very much for your help |
Traceback (most recent call last):
  File "/app/wyt_train/sd-scripts/sdxl_train.py", line 777, in <module>
    train(args)
  File "/app/wyt_train/sd-scripts/sdxl_train.py", line 508, in train
    encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
  File "/app/wyt_train/sd-scripts/library/train_util.py", line 4066, in get_hidden_states_sdxl
    pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
  File "/app/wyt_train/sd-scripts/library/train_util.py", line 4036, in pool_workaround
    pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
  File "/root/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'text_projection'