FSDP Doesn't Work with model.generate() #30228
Comments
Hi @QiyaoWei …
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
May I know why …
I think the …
@SunMarc do you know the status of …
👁️
Related (and also auto-closed due to lack of activity): …
This error is solved by using:

```python
fsdp_model = torch.distributed.fsdp.FullyShardedDataParallel(
    model,
    use_orig_params=True,
)
```

From the FSDP docs: …
I am reasonably certain the reason for this error is from …
So after a lot longer than I would like to admit, I have uncovered all the gotchas of using model.generate() with FSDP.
Here is a minimum reproducible example to show these off:

```shell
OMP_NUM_THREADS=2 \
TOKENIZERS_PARALLELISM=false \
CUDA_VISIBLE_DEVICES=6,7 \
torchrun \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nproc-per-node=2 \
    fsdp_generate.py
```

fsdp_generate.py:

```python
import functools

import torch
import torch.distributed
import torch.distributed.fsdp
import torch.distributed.fsdp.wrap
import transformers
import transformers.models.gpt_neo.modeling_gpt_neo


def main() -> None:
    torch.distributed.init_process_group(world_size=2)
    device = torch.device(torch.distributed.get_rank())
    torch.cuda.set_device(device)

    pretrained_model_name_or_path = "EleutherAI/gpt-neo-125m"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device_map=device,
    )
    assert isinstance(model, transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    fsdp_model = torch.distributed.fsdp.FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(
            torch.distributed.fsdp.wrap.transformer_auto_wrap_policy,
            transformer_layer_cls={
                transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
            },
        ),
        limit_all_gathers=True,
        # required to overcome the error "The tensor has a non-zero number of elements,
        # but its data is not allocated yet" ... PreTrainedModel.generate is probably
        # using some torch.compile-wrapped function
        use_orig_params=True,
    )

    data_by_rank = {  # differently sized data causes FSDP to hang
        0: "Hello world!",
        1: "The quick brown fox jumps over the lazy dog.",
    }
    batch = tokenizer(
        data_by_rank[torch.distributed.get_rank()],
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    # summon_full_params is required to overcome the error '"weight" must be 2-D'
    with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(fsdp_model):
        generated_text = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=20,
            # synced_gpus=True,  # True is required to use differently sized data with FSDP + generate (the current default is False)
        )

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```

My takeaways from this:
- transformers/src/transformers/generation/utils.py, lines 1742 to 1745 in a05ce55: it makes no mention of FSDP
- transformers/src/transformers/trainer_seq2seq.py, lines 294 to 297 in a05ce55
- transformers/src/transformers/generation/utils.py, lines 1788 to 1792 in a05ce55
It would be nice to save people the trouble of figuring this out if we could automatically detect FSDP usage.
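One possible shape for that detection, as a minimal sketch only (the helper name is mine, and this is not necessarily the exact check that transformers ended up adopting), is to look for a FullyShardedDataParallel wrapper anywhere in the module tree and let `synced_gpus` default to True when one is found, similar in spirit to the existing DeepSpeed ZeRO-3 handling:

```python
import torch.distributed.fsdp
import torch.nn


def is_fsdp_wrapped(model: torch.nn.Module) -> bool:
    """Return True if `model` or any of its submodules is an FSDP wrapper."""
    return any(
        isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel)
        for module in model.modules()
    )


# Hypothetical use inside generate(): only override an unset synced_gpus so an
# explicit user choice still wins.
# if synced_gpus is None:
#     synced_gpus = is_fsdp_wrapped(self)
```

On transformers versions that predate the change below, passing `synced_gpus=True` explicitly to `generate()` works around the hang.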
Default synced_gpus to True when using FullyShardedDataParallel (#33483)

* Default synced_gpus to True when using FullyShardedDataParallel

  Fixes #30228

  Related:
  * pytorch/pytorch#100069
  * pytorch/pytorch#123962

  Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock.

  To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

* Remove test file
* ruff formatting
* ruff format
* Update copyright year (Co-authored-by: Arthur <[email protected]>)
* Add test for FSDP-wrapped model generation (before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error)
* Ruff format
* Move argparse import
* Remove barrier (I think this might cause more problems if one of the workers was killed)
* Move import into function to decrease load time (#33483 (comment))
* Add test for accelerate and Trainer (#33483 (comment))
* Refactor imports
* Ruff format
* Use nullcontext

Co-authored-by: Arthur <[email protected]>
Thank you so much! It's working with 7B and 8B models on my side. However, I noticed extremely high GPU memory consumption after summon_full_params. It also leads to OOM while running inference with a 70B model on 8 nodes. I can successfully finetune the 70B model on 8 nodes, so technically inference should take fewer nodes. I'm curious if you have any workarounds.
Check out FSDP2: transformers/tests/generation/test_fsdp.py, lines 81 to 101 in 3033509
It solves a lot of the problems with FSDP, like having to use …
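For anyone who wants to try that route, here is a minimal sketch of the FSDP2 approach, assuming torch >= 2.6 (where `fully_shard` is exported from `torch.distributed.fsdp`) and the GPT-Neo module layout from the reproducer above; the helper name is mine:

```python
from torch.distributed.fsdp import fully_shard  # FSDP2; assumes torch >= 2.6

import transformers


def shard_for_generation(model: transformers.PreTrainedModel) -> None:
    """Shard a GPT-Neo style model in place with FSDP2.

    Assumes the decoder blocks live at model.transformer.h; other
    architectures keep their block list under a different attribute.
    """
    for block in model.transformer.h:
        fully_shard(block)  # shard each decoder block's parameters
    fully_shard(model)  # shard the remaining root parameters (embeddings, lm_head, ...)


# Hypothetical usage, one rank per GPU with an initialized process group:
#   model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m").cuda()
#   shard_for_generation(model)
#   outputs = model.generate(**batch, max_length=20, synced_gpus=True)
```

The appeal, as the comment above suggests, is that no extra wrapping flags or parameter-gathering contexts are needed around `generate()`.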
System Info
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- transformers version: 4.39.3
- distributed_type: FSDP
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- fsdp_config: {'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_offload_params': True, 'fsdp_sharding_strategy': 'FULL_SHARD', 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Who can help?
@ArthurZucker @younesbelkada @gante for the relevance with text models and generate()
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
I am trying to use FSDP, but for some reason there is an error when I do model.generate(). MWE below:
…
Error below:
…
Expected behavior
The code provided should not error