FSDP Doesn"t Work with model.generate() #30228

Closed
QiyaoWei opened this issue Apr 12, 2024 · 12 comments · Fixed by #33483

Comments

@QiyaoWei

System Info


  • transformers version: 4.39.3
  • Platform: Linux-5.15.0-1059-azure-x86_64-with-glibc2.31
  • Python version: 3.10.11
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.2
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: FSDP
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 2
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - fsdp_config: {"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_backward_prefetch": "BACKWARD_PRE", "fsdp_cpu_ram_efficient_loading": True, "fsdp_forward_prefetch": False, "fsdp_offload_params": True, "fsdp_sharding_strategy": "FULL_SHARD", "fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_sync_module_states": True, "fsdp_use_orig_params": True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.2.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: YES
  • Using distributed or parallel set-up in script?: YES

Who can help?

@ArthurZucker @younesbelkada @gante for the relevance with text models and generate()

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am trying to use FSDP, but for some reason there is an error when I call model.generate(). MWE below:

import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    StateDictType,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

class BasicTrainer(object):
    def __init__(self):

        model_name_or_path = "openai-community/gpt2-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.policy = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        
        tokenized = self.tokenizer("hi there", return_tensors="pt").to(self.policy.device)
        print(self.policy.generate(**tokenized))
        return
    
    def train(self):
        pass
    
def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module:
    """Get the class of a block from a model, using the block"s class name."""
    for module in model.modules():
        if module.__class__.__name__ == block_class_name:
            return module.__class__
    raise ValueError(f"Could not find block class {block_class_name} in model {model}")

def init_distributed(rank: int, world_size: int, master_addr: str = "localhost", port: int = 12355, backend: str = "nccl"):
    print(rank, "initializing distributed")
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def worker_main(rank: int, world_size: int):

    init_distributed(rank, world_size)
    print(f"Creating trainer on process {rank} with world size {world_size}")
    trainer = FSDPTrainer()

    # trainer.train()
    # trainer.save()


def main():

    world_size = torch.cuda.device_count()
    print("starting", world_size, "processes for FSDP training")
    torch.multiprocessing.spawn(worker_main, nprocs=world_size, args=(world_size,), join=True)
        
class FSDPTrainer(BasicTrainer):
    def __init__(self):

        super().__init__()
        
        model_name_or_path = "openai-community/gpt2-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)#.to("cuda")
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.policy = AutoModelForCausalLM.from_pretrained(model_name_or_path).to("cuda")

        wrap_class = get_block_class_from_model(self.policy, "GPT2Block")
        model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class},)

        shared_fsdp_kwargs = dict(
            auto_wrap_policy=model_auto_wrap_policy,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            cpu_offload=CPUOffload(offload_params=False),
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            ignored_modules=None,
            limit_all_gathers=False,
            use_orig_params=False,
            sync_module_states=False
        )
        mp_dtype = None
        policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype)
        self.policy = FSDP(self.policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
        
        tokenized = self.tokenizer("hi there", return_tensors="pt").to(self.policy.device)
        print(self.policy.generate(**tokenized))
        return
    
if __name__ == "__main__":

    main()  # BasicTrainer works, but FSDPTrainer errors

Error below

starting 2 processes for FSDP training
1 initializing distributed
0 initializing distributed
Creating trainer on process 0 with world size 2
Creating trainer on process 1 with world size 2
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py:1132: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py:1132: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
tensor([[5303,  612,  318,  257, 1256,  286,  670,  284,  307, 1760,   13,  198,
          198,    1, 1135,  423,  284,  787, 1654,  326]])
tensor([[5303,  612,  318,  257, 1256,  286,  670,  284,  307, 1760,   13,  198,
          198,    1, 1135,  423,  284,  787, 1654,  326]])
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Traceback (most recent call last):
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 98, in <module>
    main()
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 61, in main
    torch.multiprocessing.spawn(worker_main, nprocs=world_size, args=(world_size,), join=True)
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 158, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 51, in worker_main
    trainer = FSDPTrainer()
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 92, in __init__
    print(self.policy.generate(**tokenized))
  File "/anaconda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1527, in generate
    result = self._greedy_search(
  File "/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2411, in _greedy_search
    outputs = self(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1074, in forward
    transformer_outputs = self.transformer(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 837, in forward
    inputs_embeds = self.wte(input_ids)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

Expected behavior

The code provided should not error

@younesbelkada
Contributor

Hi @QiyaoWei
If I am not mistaken, FSDP is not compatible with generate - cc @pacman100 @SunMarc (as I think we've discussed this offline at some point, but not 100% sure)

@sahilsuneja1

Hi @QiyaoWei,
Have a look at pacman100's model(**batch) workaround here and here


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@zzxslp

zzxslp commented Aug 31, 2024

Hi @QiyaoWei If I am not mistaken, FSDP is not compatible with generate - cc @pacman100 @SunMarc (as I think we've discussed this offline at some point, but not 100% sure)

May I know why FSDP is not compatible with generate?

@QiyaoWei
Author

I think the model(**dummy_batch) workaround is the way to go. Basically, the FSDP model would need an additional forward pass for generate to run.
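
For reference, here is a minimal sketch of that idea. It is purely illustrative: generate_with_warmup is not a transformers API, the FSDP-wrapped model, tokenizer, and device are assumed to come from a setup like the reproduction above, and later comments in this thread describe further requirements (use_orig_params, summon_full_params, synced_gpus).

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def generate_with_warmup(fsdp_model: FSDP, tokenizer, prompt: str, device: torch.device):
    batch = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        # Extra forward pass through the FSDP wrapper before calling generate(),
        # mirroring the model(**batch) workaround referenced above.
        _ = fsdp_model(**batch)
        return fsdp_model.module.generate(**batch, max_new_tokens=20)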

@gante
Member

gante commented Sep 5, 2024

@SunMarc do you know the status of generate 🤜 🤛 FSDP? If not working well together: should we reopen this issue, or is it being tracked somewhere else?

@ArthurZucker
Collaborator

👁️

@ringohoffman
Contributor

Related (and also auto-closed due to lack of activity):

@ringohoffman
Contributor

ringohoffman commented Sep 13, 2024

RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

This error is solved by using use_orig_params=True:

fsdp_model = torch.distributed.fsdp.FullyShardedDataParallel(
    model,
    use_orig_params=True,
)

From the FSDP docs:

use_orig_params (bool) – ... True is required to use torch.compile(). ... (Default: False)

I am reasonably certain the reason for this error is from generate calling a torch.compile-wrapped function.

@ringohoffman
Contributor

So after a lot longer than I would like to admit, I have uncovered all the gotchas of using generate with FSDP.

  1. As I mentioned above, torch.distributed.fsdp.FullyShardedDataParallel(use_orig_params=True) is required when instantiating your FSDP instance. Otherwise, you get the error The tensor has a non-zero number of elements, but its data is not allocated yet. It seems likely that generate is calling a torch.compile-wrapped function.
  2. Calls to generate must be inside a torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(fsdp_model) context, otherwise you get the error 'weight' must be 2-D because the input embeddings are still flattened by FSDP.
  3. Lastly, and most trickily, you must use generate(synced_gpus=True) when using differently-sized data across ranks. Otherwise, the different ranks will pause on different synchronization points leading to a deadlock.

Here is a minimum reproducible example to show these off:

OMP_NUM_THREADS=2 \
TOKENIZERS_PARALLELISM=false \
CUDA_VISIBLE_DEVICES=6,7 \
torchrun \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nproc-per-node=2 \
    fsdp_generate.py

fsdp_generate.py

import functools

import torch
import torch.distributed
import torch.distributed.fsdp
import torch.distributed.fsdp.wrap
import transformers
import transformers.models.gpt_neo.modeling_gpt_neo


def main() -> None:
    torch.distributed.init_process_group(world_size=2)
    device = torch.device(torch.distributed.get_rank())
    torch.cuda.set_device(device)

    pretrained_model_name_or_path = "EleutherAI/gpt-neo-125m"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device_map=device,
    )
    assert isinstance(model, transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    fsdp_model = torch.distributed.fsdp.FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(
            torch.distributed.fsdp.wrap.transformer_auto_wrap_policy,
            transformer_layer_cls={
                transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
            },
        ),
        limit_all_gathers=True,
        use_orig_params=True,  # required to overcome the error "The tensor has a non-zero number of elements, but its data is not allocated yet" ... PreTrainedModel.generate is probably using some torch.compile-wrapped function
    )

    data_by_rank = {  # differently-sized causes FSDP to hang
        0: "Hello world!",
        1: "The quick brown fox jumps over the lazy dog."
    }

    batch = tokenizer(
        data_by_rank[torch.distributed.get_rank()],
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(fsdp_model):  # required to overcome the error "'weight' must be 2-D"
        generated_text = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=20,
            # synced_gpus=True,  # True is required to use differently sized data with FSDP + generate (current default is False)
        )

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

My takeaways from this:

  1. Improve the synced_gpus documentation

synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.

It makes no mention of FullyShardedDataParallel.

  2. Can synced_gpus=None default to True when using FSDP, the same as it does for DeepSpeed ZeRO-3?

default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
gen_kwargs["synced_gpus"] = (
    gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)

if synced_gpus is None:
    if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
        synced_gpus = True
    else:
        synced_gpus = False

It would be nice to save people the trouble of figuring this out if we could automatically detect FSDP usage (a sketch of what that detection could look like follows this list).

  3. Document use_orig_params and summon_full_params.
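
On the second takeaway, here is a hedged sketch (not the actual transformers implementation) of how the default could be derived by detecting FSDP, relying on the _is_fsdp_managed_module attribute that FSDP sets on managed submodules for torch.compile (see the commit message below). The helper names are illustrative.

import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def is_fsdp_managed(module: nn.Module) -> bool:
    # True for an FSDP wrapper itself, or for a submodule that FSDP has
    # annotated with the _is_fsdp_managed_module attribute used by torch.compile.
    return isinstance(module, FSDP) or getattr(module, "_is_fsdp_managed_module", False)


def resolve_synced_gpus(model: nn.Module, synced_gpus=None) -> bool:
    # An explicit user setting always wins.
    if synced_gpus is not None:
        return synced_gpus
    # Default to True only for multi-process runs of an FSDP-managed model,
    # mirroring the existing DeepSpeed ZeRO-3 behavior quoted above.
    multi_process = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
    return multi_process and any(is_fsdp_managed(m) for m in model.modules())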

ringohoffman added a commit to ringohoffman/transformers that referenced this issue Sep 13, 2024

muellerzr pushed a commit that referenced this issue Oct 10, 2024
…#33483)

* Default synced_gpus to True when using FullyShardedDataParallel

Fixes #30228

Related:

* pytorch/pytorch#100069
* pytorch/pytorch#123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock

To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

* Remove test file

* ruff formatting

* ruff format

* Update copyright year

Co-authored-by: Arthur <[email protected]>

* Add test for FSDP-wrapped model generation

Before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error

* Ruff format

* Move argparse import

* Remove barrier

I think this might cause more problems if one of the workers was killed

* Move import into function to decrease load time

#33483 (comment)

* Add test for accelerate and Trainer

#33483 (comment)

* Refactor imports

* Ruff format

* Use nullcontext

---------

Co-authored-by: Arthur <[email protected]>
@ShengYun-Peng

Issue with FSDP + HuggingFace generate pytorch/pytorch#100069

Thank you so much! It's working with 7B and 8B models on my side. However, I noticed extremely high GPU memory consumption after summon_full_params. It also leads to OOM while running inference with 70B models on 8 nodes. I can successfully finetune 70B on 8 nodes, so technically, inference should take fewer nodes. I'm curious if you have any workarounds.

@ringohoffman
Contributor

Issue with FSDP + HuggingFace generate pytorch/pytorch#100069

Thank you so much! It's working with 7B and 8B models on my side. However, I noticed extremely high GPU memory consumption after summon_full_params. It also leads to OOM while running inference with 70B models on 8 nodes. I can successfully finetune 70B on 8 nodes, so technically, inference should take fewer nodes. I'm curious if you have any workarounds.

Check out FSDP2:

import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh
# FSDP2 API; on some older PyTorch releases these live under torch.distributed._composable.fsdp
from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block


def fsdp2_generate():
    # Assumes the process group is already initialized (e.g. via torchrun plus
    # torch.distributed.init_process_group) and that `data` is a per-rank
    # mapping of prompts, as in the earlier example.
    torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
    for submodule in model.modules():
        if isinstance(submodule, GPT2Block):
            fully_shard(submodule, mesh=mesh)  # shard each transformer block separately
    fully_shard(model, mesh=mesh)  # then shard the root module
    register_fsdp_forward_method(model, "generate")
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
    _ = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=30,
    )

It solves a lot of the problems with FSDP, like having to use summon_full_params.
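
For reference, a script along those lines can be launched the same way as the earlier FSDP example, assuming it also initializes the process group, calls fsdp2_generate(), and is saved as fsdp2_generate.py (an illustrative filename):

torchrun \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nproc-per-node=2 \
    fsdp2_generate.py

The piece that removes the need for summon_full_params is register_fsdp_forward_method(model, "generate"), which has FSDP run its usual all-gather and reshard hooks around generate just as it does around forward.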

BernardZach pushed a commit to BernardZach/transformers that referenced this issue Dec 5, 2024