
[FSDP] FSDP with CPU offload consumes 1.65X more GPU memory when training models with most of the params frozen #91165

Open
pacman100 opened this issue Dec 20, 2022 · 20 comments
Labels
high priority module: fsdp oncall: distributed Add this issue/PR to distributed oncall triage queue triage review triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@pacman100

pacman100 commented Dec 20, 2022

🐛 Describe the bug

Context: We have more and more situations where a large part of the model being trained is frozen. Since these are very large LLMs, we want to leverage FSDP with CPU offloading so that training with only a tiny fraction of trainable params fits on consumer GPUs. To this end, below is an example of fine-tuning bigscience/mt0-small with the LoRA parameter-efficient fine-tuning method, with and without FSDP.

To fine-tune with FSDP:

  1. Following How do I freeze weights when using FSDP? huggingface/accelerate#807, to avoid AssertionError: expects all parameters to have same requires_grad, we created a custom auto_wrap_policy such that the layers with trainable params end up in separate FSDP units from the frozen ones (a sketch of such a policy is shown right after the model printout below). The resulting model and the FSDP options are given below; we are using Accelerate's FSDP integration, and the trainable params have the lora_ prefix:
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=<BackwardPrefetch.BACKWARD_PRE: 1>, mixed_precision_policy=None, auto_wrap_policy=functools.partial(<function _or_policy at 0x7f022c1a8430>, policies=[functools.partial(<function lambda_auto_wrap_policy at 0x7f022c1a8160>, lambda_fn=<function fsdp_auto_wrap_policy.<locals>.lambda_policy_fn at 0x7f01ec31e3b0>), functools.partial(<function transformer_auto_wrap_policy at 0x7f022c1a8310>, transformer_layer_cls=(<class 'pet.tuners.prefix_tuning.PrefixEncoder'>, <class 'pet.tuners.p_tuning.PromptEncoder'>, <class 'pet.tuners.prompt_tuning.PromptEmbedding'>, <class 'transformers.models.t5.modeling_t5.T5Block'>))]), cpu_offload=CPUOffload(offload_params=True), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=False)


FullyShardedDataParallel(
  (_fsdp_wrapped_module): PETModelForSeq2SeqLM(
    (base_model): LoRAModel(
      (model): MT5ForConditionalGeneration(
        (shared): Embedding(250112, 512)
        (encoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (decoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (lm_head): Linear(in_features=512, out_features=250112, bias=False)
      )
    )
  )
)
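
For reference, here is a minimal sketch of the kind of auto_wrap_policy used above. It illustrates the approach (combining a lambda policy for the trainable lora_ layers with a transformer policy for T5Block via the private helper _or_policy from torch.distributed.fsdp.wrap); it is not the exact code from the Accelerate/PET integration:

import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.t5.modeling_t5 import T5Block


def lambda_policy_fn(module):
    # Put a leaf module into its own FSDP unit if its weight is trainable,
    # so that trainable LoRA layers do not share a FlatParameter with
    # frozen base-model weights.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


def make_auto_wrap_policy():
    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
    )
    # Wrap a module if either sub-policy matches it.
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])
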
  2. The number of trainable params is given below (a small helper to reproduce this count follows the output):
trainable params: 344064 || all params: 300520832 || trainable%: 0.11448923447676333
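
(The count above comes from a small helper along the following lines; this is a generic sketch, not the exact function used:)

def print_trainable_parameters(model):
    # Count parameters that will actually receive gradients vs. the total.
    trainable, total = 0, 0
    for param in model.parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    print(
        f"trainable params: {trainable} || all params: {total} || "
        f"trainable%: {100 * trainable / total}"
    )
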
  3. Now, the issue is that, in comparison to plain PyTorch, FSDP consumes 1.65X more GPU memory instead of reducing it by a large amount (the expectation), while also greatly increasing the memory consumed on the CPU. Below are screenshots for both runs. The hardware used is 1 A100 80GB GPU.

Plain PyTorch:
[screenshot: Screenshot 2022-12-20 at 3 45 16 PM]

FSDP Full Shard with CPU offloading:
[screenshot: Screenshot 2022-12-20 at 4 14 04 PM]

  4. When trying to use FSDP with CPU offloading for the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU, it results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.

  5. Expected behaviour: handle frozen weights efficiently during training so that large models can be offloaded to CPU / sharded across GPUs properly, with optimizer state stored only for the trainable parameters. For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if CPU offloading with FSDP allowed this training to run on a 16GB or 24GB GPU.

Versions

Collecting environment information...
PyTorch version: 1.14.0.dev20221117+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221117+cu117
[pip3] torchaudio==0.14.0.dev20221117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221117
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.14.0.dev20221117+cu117 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221117 py310_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221117 py310_cu117 pytorch-nightly

cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin

@awgu awgu added oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: fsdp labels Dec 20, 2022
@awgu
Contributor

awgu commented Dec 20, 2022

Thanks for providing so many details! This helps us out a lot.

We may need more time to investigate why CPU offload consumes so much more GPU memory. For now, have you tried setting limit_all_gathers=True? We plan to make this the default soon. Without it, memory usage may be unexpectedly high as long as the CPU thread runs ahead of GPU execution, which it almost always does. This may or may not be related to this issue.
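
For reference, a minimal sketch of how that flag could be passed when wrapping manually (assuming the model and process group are already set up; Accelerate users would set the corresponding limit_all_gathers field on the FSDP plugin instead):

from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    limit_all_gathers=True,  # rate-limits the CPU thread so it cannot queue
                             # all-gathers far ahead of the GPU and over-allocate
)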

IIUC, CPU offload was required for the 1T parameter training run (cc: @rohan-varma @zhaojuanmao), so the memory savings were proven there. Conceptually, I am not seeing how frozen parameters/layers affect the GPU memory usage aside from the optimizer state as you mentioned. In that case, if you think it is helpful, could you verify what optimizer states are being created when using FSDP? Are you seeing that optimizer states are being created even for FlatParameters that do not require gradient?
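
A hedged sketch of that check (assuming model is the FSDP-wrapped module with use_orig_params left at its default of False, optimizer is the AdamW instance, and at least one optimizer.step() has run):

for name, flat_param in model.named_parameters():    # FlatParameters when use_orig_params=False
    state = optimizer.state.get(flat_param, {})
    print(
        f"{name}: requires_grad={flat_param.requires_grad}, "
        f"numel={flat_param.numel()}, "
        f"optim_state={list(state.keys())}"           # expected to be [] for frozen FlatParameters
    )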

@pacman100
Author

Hello @awgu, thanks for your quick reply. Apart from storing optimizer states only for the tiny portion of trainable params, offloading to CPU should further reduce GPU memory usage. With CPU offloading, the expected behaviour would be to move model shards (one T5 block at a time, as they belong to separate FSDP units) from CPU to GPU, do the forward pass computation, and then move them back from GPU to CPU (and similarly for the backward pass); hence the max memory consumption on the GPU should be that of a single T5 block only. Let me know if there are gaps in my understanding.
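
A minimal runnable sketch (plain PyTorch, not FSDP internals) of the streaming behaviour described above -- every block kept on CPU and moved to the GPU only for its own forward, so peak GPU parameter memory is one block rather than the whole stack:

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(512, 512, bias=False) for _ in range(24)])  # stand-in for T5 blocks
x = torch.randn(4, 512, device="cuda")

with torch.no_grad():
    for block in blocks:
        block.to("cuda", non_blocking=True)  # H2D copy of this block's params only
        x = block(x)
        block.to("cpu")                      # free the GPU copy again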

@awgu
Contributor

awgu commented Apr 27, 2023

@pacman100 Do you have a recommendation for how to try to reproduce this?

@pengyanghua

@awgu From the source code, it seems that another problem with FSDP is that the frozen parameters are not deallocated until the end of the backward pass. If most parameters are frozen, as in this case, the GPU ends up holding nearly a whole copy of the model parameters.

@awgu
Contributor

awgu commented May 11, 2023

@awgu From the source code, it seems that another problem with FSDP is that the frozen parameters are not deallocated until the end of the backward pass. If most parameters are frozen, as in this case, the GPU ends up holding nearly a whole copy of the model parameters.

This is unfortunately a known problem :(

We would need to add special hooks to reshard the frozen parameters after they are used for gradient computation. We have not had any bandwidth to look into this though.

@hscspring

@awgu From the source code, it seems that another problem with FSDP is that the frozen parameters are not deallocated until the end of the backward pass. If most parameters are frozen, as in this case, the GPU ends up holding nearly a whole copy of the model parameters.

This is unfortunately a known problem :(

We would need to add special hooks to reshard the frozen parameters after they are used for gradient computation. We have not had any bandwidth to look into this though.

Is it possible to use a sized_policy to deal with those frozen parameters?

@awgu
Contributor

awgu commented May 22, 2023

@hscspring I do not think so. The issue is more fundamental than the choice of wrapping policy: it arises when a parameter belongs to a module whose output activations require gradients but the parameter itself does not require gradients -- in that case, FSDP's pre-backward hook runs but its post-backward hook does not.
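
A minimal example of the situation described above (plain PyTorch, no FSDP): the frozen layer's output requires grad because its input does, so backward still flows through it, but nothing is ever accumulated into the frozen weight -- which is why a hook keyed on gradient accumulation never fires for it:

import torch
import torch.nn as nn

trainable = nn.Linear(512, 512, bias=False)   # e.g. a LoRA adapter
frozen = nn.Linear(512, 512, bias=False)
frozen.weight.requires_grad_(False)

x = torch.randn(4, 512)
out = frozen(trainable(x))
print(out.requires_grad)                  # True  -> pre-backward work still happens

out.sum().backward()
print(frozen.weight.grad)                 # None  -> no post-backward (grad accumulation) event
print(trainable.weight.grad is not None)  # True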

@hscspring

@awgu thanks for the quick reply. You are right, the wrapping policy only affects how the model is sharded.
Hope it'll be optimized in the next version. ;)

@pengyanghua

We would need to add special hooks to reshard the frozen parameters after they are used for gradient computation. We have not had any bandwidth to look into this though.

@awgu any idea how to add such a hook? e.g., via a torch.autograd.Function.
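
Not FSDP's actual implementation -- just a hedged illustration of where such a hook could attach: torch.autograd.graph.register_multi_grad_hook fires once gradients w.r.t. all of a module's inputs have been computed, i.e. after the module's (possibly frozen) weights were last needed in backward. reshard_fn below is a hypothetical placeholder, not a real FSDP API.

import torch
from torch.autograd.graph import register_multi_grad_hook

def attach_post_grad_callback(module, reshard_fn):
    # reshard_fn(module) is hypothetical: whatever frees/reshards the unsharded params.
    def forward_hook(mod, inputs, output):
        grad_inputs = [t for t in inputs if isinstance(t, torch.Tensor) and t.requires_grad]
        if grad_inputs:
            # Fires after grads w.r.t. all inputs are computed. A real implementation
            # would remove/re-register the handle each iteration; omitted for brevity.
            register_multi_grad_hook(grad_inputs, lambda grads: reshard_fn(mod))
    return module.register_forward_hook(forward_hook)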

@awgu
Contributor

awgu commented May 23, 2023

@pengyanghua I am testing #101982 out. I am seeing memory savings in my experiments, but @HamidShojanazeri is not at the moment. We are investigating.

@HamidShojanazeri

@awgu, just to update the thread: as we discussed offline, I can also observe the memory savings.

@awgu
Contributor

awgu commented Jun 15, 2023

Since #101982 landed, I would recommend people try things out again (e.g. with a nightly). There should be memory savings now.

@pacman100
Author

Thank you @awgu and @HamidShojanazeri. I will try this out this week and report back. This should work with CPU offloading too, right?

@awgu
Contributor

awgu commented Jun 15, 2023

@pacman100 Yes, it should work with CPU offloading. (It turns out that CPU offloading is connected to "freeing the parameters", so once we fixed "freeing the parameters", the parameters should be offloaded to CPU as well if enabled.)

@pacman100
Author

pacman100 commented Jun 21, 2023

When trying to use FSDP with CPU offloading for the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU, it results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.

Tried the latest nightly '2.1.0.dev20230620'

Using 2 A100 GPUs, I am able to run bigscience/mt0-xxl, but it doesn't result in any significant GPU memory savings compared to plain PyTorch, while consuming a whole lot of CPU memory.

GPU Memory before entering the train : 0
GPU Memory consumed at the end of the train (end-begin): 315
GPU Peak Memory consumed during the train (max-begin): 53360
GPU Total Peak Memory consumed during the train (max): 53360
CPU Memory before entering the train : 26816
CPU Memory consumed at the end of the train (end-begin): 33817
CPU Peak Memory consumed during the train (max-begin): 57895
CPU Total Peak Memory consumed during the train (max): 84711

As I stated earlier, the expected behaviour is

Expected behaviour: efficiently handle frozen weights during training so that large models can be offloaded to CPU / sharded across GPUs properly, storing optimizer state only for the trainable parameters. For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if CPU offloading made training possible on a 16GB or 24GB GPU using FSDP.

Steps to reproduce:

  1. Install the latest nightly ('2.1.0.dev20230620').
  2. Install peft, transformers and accelerate from source.
  3. Remove line 1316 (the ignored_parameters entry) from accelerator.py, as FSDP in the nightly doesn't support it:
kwargs = {
                        "sharding_strategy": fsdp_plugin.sharding_strategy,
                        "cpu_offload": fsdp_plugin.cpu_offload,
                        "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
                        "mixed_precision": fsdp_plugin.mixed_precision_policy,
                        "sync_module_states": fsdp_plugin.sync_module_states,
                        "backward_prefetch": fsdp_plugin.backward_prefetch,
                        "forward_prefetch": fsdp_plugin.forward_prefetch,
                        "use_orig_params": fsdp_plugin.use_orig_params,
                        "param_init_fn": fsdp_plugin.param_init_fn,
                        "ignored_modules": fsdp_plugin.ignored_modules,
-                        "ignored_parameters": fsdp_plugin.ignored_parameters,
                        "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                        "device_id": self.device,
                    }
                    model = FSDP(model, **kwargs)
  4. Follow https://github.com/huggingface/peft#caveats related to FSDP. Use the example below to track GPU and CPU memory usage:

FSDP config:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: false
  fsdp_transformer_layer_cls_to_wrap: MT5Block
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Code

import gc
import os
import sys
import threading

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from peft import LoraConfig, TaskType, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy

# Converting Bytes to Megabytes
def b2mb(x):
    return int(x / 2**20)


# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


def main():
    accelerator = Accelerator()
    # model_name_or_path = "bigscience/T0_3B"
    model_name_or_path = "bigscience/mt0-xxl"#"facebook/bart-large"
    dataset_name = "twitter_complaints"
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )
    text_column = "Tweet text"
    label_column = "text_label"
    lr = 3e-3
    num_epochs = 5
    batch_size = 8
    seed = 42
    do_test = False
    set_seed(seed)

    dataset = load_dataset("ought/raft", dataset_name)
    classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
    dataset = dataset.map(
        lambda x: {"text_label": [classes[label] for label in x["Label"]]},
        batched=True,
        num_proc=1,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    target_max_length = max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])

    def preprocess_function(examples):
        inputs = examples[text_column]
        targets = examples[label_column]
        model_inputs = tokenizer(inputs, truncation=True)
        labels = tokenizer(
            targets, max_length=target_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        labels = labels["input_ids"]
        labels[labels == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels
        return model_inputs

    with accelerator.main_process_first():
        processed_datasets = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on dataset",
        )
    accelerator.wait_for_everyone()

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["train"]
    test_dataset = processed_datasets["test"]

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)

    # creating model
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )
    
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)


    model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)

    for epoch in range(num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            total_loss = 0
            for step, batch in enumerate(tqdm(train_dataloader)):
                outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
        accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
        accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
        accelerator.print(
            "GPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.peaked + b2mb(tracemalloc.begin)
            )
        )

        accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
        accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
        accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
        accelerator.print(
            "CPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
            )
        )
        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        accelerator.print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

if __name__ == "__main__":
    main()


output logs:

trainable params: 9,437,184 || all params: 12,930,494,464 || trainable%: 0.072983937515106
None
trainable params: 9,437,184 || all params: 12,930,494,464 || trainable%: 0.072983937515106
Found cached dataset json (/raid/sourab/.cache/huggingface/datasets/json/default-45aab44cd4d75e35/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1410.32it/s]
Found cached dataset json (/raid/sourab/.cache/huggingface/datasets/json/default-45aab44cd4d75e35/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1349.52it/s]
FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer     
FullyShardedDataParallel(                                                                                                      
  (_fsdp_wrapped_module): PeftModelForSeq2SeqLM(
    (base_model): LoraModel(
      (model): MT5ForConditionalGeneration(
        (shared): Embedding(250112, 4096)
        (encoder): MT5Stack(
          (embed_tokens): Embedding(250112, 4096)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                      (relative_attention_bias): Embedding(32, 64)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1-23): 23 x FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): MT5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (decoder): MT5Stack(
          (embed_tokens): Embedding(250112, 4096)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                      (relative_attention_bias): Embedding(32, 64)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerCrossAttention(
                    (EncDecAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1-23): 23 x FullyShardedDataParallel(
              (_fsdp_wrapped_module): MT5Block(
                (layer): ModuleList(
                  (0): MT5LayerSelfAttention(
                    (SelfAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): MT5LayerCrossAttention(
                    (EncDecAttention): MT5Attention(
                      (q): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (k): Linear(in_features=4096, out_features=4096, bias=False)
                      (v): Linear(
                        in_features=4096, out_features=4096, bias=False
                        (lora_dropout): ModuleDict(
                          (default): Dropout(p=0.1, inplace=False)
                        )
                        (lora_A): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=4096, out_features=8, bias=False)
                          )
                        )
                        (lora_B): ModuleDict(
                          (default): FullyShardedDataParallel(
                            (_fsdp_wrapped_module): Linear(in_features=8, out_features=4096, bias=False)
                          )
                        )
                        (lora_embedding_A): ParameterDict()
                        (lora_embedding_B): ParameterDict()
                      )
                      (o): Linear(in_features=4096, out_features=4096, bias=False)
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): MT5LayerFF(
                    (DenseReluDense): MT5DenseGatedActDense(
                      (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
                      (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
                      (wo): Linear(in_features=10240, out_features=4096, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): MT5LayerNorm()
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): MT5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (lm_head): Linear(in_features=4096, out_features=250112, bias=False)
      )
    )
  )
)
/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/torch/cuda/memory.py:307: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/torch/cuda/memory.py:307: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
GPU Memory before entering the train : 0
GPU Memory consumed at the end of the train (end-begin): 315
GPU Peak Memory consumed during the train (max-begin): 53360
GPU Total Peak Memory consumed during the train (max): 53360
CPU Memory before entering the train : 26816
CPU Memory consumed at the end of the train (end-begin): 33817
CPU Peak Memory consumed during the train (max-begin): 57895
CPU Total Peak Memory consumed during the train (max): 84711
epoch=0: train_ppl=tensor(16.8360, device='cuda:0') train_epoch_loss=tensor(2.8235, device='cuda:0')

@HongyanZhi

@pacman100 Can you please share the manual wrap code for LoRA? I'm encountering the same problem. Many thanks!

@SapirW

SapirW commented Oct 1, 2023

@pacman100 Is this issue solved in recent PyTorch versions?
Should I not bother trying FSDP to reduce memory in a fine-tuning task with most parameters frozen?

@vTuanpham

vTuanpham commented Nov 4, 2023

Still having OOM issues with CPU offload on a 7B model in mixed precision.

@hanyin88

hanyin88 commented Dec 26, 2023

Hello, I came across this important discussion. In the latest PyTorch nightlies (for example in the LLaMA-recipes implementation), running PEFT with FSDP indeed saves memory. For example, when using a mini-batch size of 4 to train a 13B model on 4 GPUs (with pure BF16), VRAM per GPU is 25GB (PEFT) vs 55GB (full model). Could we confirm this issue has been resolved in the latest PyTorch nightlies? Many thanks.

Also note to future readers, there are some excellent discussions on this led by @awgu:
#104690
#100945

Also, looking at the LLaMA-recipes implementation, it seems they largely followed the recommendation from this post: they use the customized lambda_policy_fn only when using PEFT.
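
For readers looking for the policy referred to above, here is a hedged sketch along the lines of peft's fsdp_auto_wrap_policy / the llama-recipes approach (not copied from either): wrap transformer blocks as usual, and additionally give any leaf module with a trainable weight (the LoRA A/B matrices) its own FSDP unit so trainable and frozen parameters never share a FlatParameter. MT5Block is assumed as the transformer layer class here.

import functools
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.mt5.modeling_mt5 import MT5Block  # assumed layer class

def lambda_policy_fn(module):
    # leaf modules whose weight is trainable get their own FSDP unit
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

auto_wrap_policy = functools.partial(
    _or_policy,
    policies=[
        functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn),
        functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={MT5Block}),
    ],
)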

@vikram71198

@awgu There are a lot of issues open regarding FSDP + PEFT. I've been looking at many of these discussion threads and, unfortunately, even with CPU offloading and the wrapping policy suggested here, I'm still facing persistent OOM issues. This was another issue we'd like to hear updates on.

Is it possible to consolidate everything regarding this topic (current workarounds, PT nightlies, current problems, WIP if any) either on this thread or in a new issue on this repo?

That would really help a lot of us currently struggling to evade OOMs with FSDP PEFT.
