Inference with FSDP during training affects checkpoints #34530

Closed
pandrei7 opened this issue Oct 31, 2024 · 11 comments

@pandrei7

System Info

Output from transformers-cli env:

- `transformers` version: 4.45.2
- Platform: Linux-6.1.0-21-cloud-amd64-x86_64-with-glibc2.36
- Python version: 3.12.5
- Huggingface_hub version: 0.25.0
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.5.0 cu124 (True)
- Tensorflow version (GPU?): 2.17.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: just using the Trainer, and running with accelerate
- Using GPU in script?: I'm running on GPUs
- GPU type: NVIDIA H100 80GB HBM3

Relevant environment and library versions:

Linux Debian 6.1.90-1
CUDA version: 12.4

accelerate==1.0.1
datasets==3.0.1
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
transformers==4.45.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello! I'm running into an issue with the checkpoints saved when training an LLM with FSDP and the default Hugging Face Trainer, if I also run inference during training. I provide code at the end of this post for clarity.

I also asked this on the forum before coming here, but I haven't found a solution yet.

What I'm trying to achieve

I want to write a callback to monitor model outputs on a validation set throughout the training process. This requires doing inference with model.generate(). Since I'm also using FSDP, I need to summon all weights on a single device, as described in this GitHub issue.

My issue

The callback I provide below seems to work fine for evaluation, but it affects the checkpoints that get saved. Specifically, when unsharding the final checkpoint and trying to replicate the results I see from my training script, I get different, much worse results from the checkpoint.

To test this, I trained an LLM to memorize a simple phrase: "Two times 10 equals 20.". At the end of training, my callback reports the completions I expect, meaning the model trained well. However, if I load the checkpoint from disk and feed it the same prompts, I get this:

# With callback
# Outputs from the training script, after training.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "               "
"Two times"           -> "equals               "
"Two times 10"        -> "               "
"Two times 10 equals" -> "               "

This does not happen if I don't run the callback during training. If I remove it, the saved checkpoint produces the expected results:

# Without callback
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."

To make extra sure, I also ran this experiment with DDP instead of FSDP (removing the summon_full_params call). The DDP checkpoint is correct regardless of whether I use my callback or not.

# With DDP
# Outputs from the training script, after training.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."

I believe this points to summon_full_params being the problem. Do you think this could be a problem with the library, or maybe with my implementation? Any ideas or advice? Thank you!
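
For completeness, this is roughly what the "load the checkpoint from disk" step above looks like. It is a sketch rather than my exact script: the paths are placeholders, and it assumes the sharded checkpoint has already been merged into a single safetensors file (e.g. with accelerate merge-weights).

# Sketch only: checkpoint paths and file names are placeholders.
import torch
import transformers
from safetensors.torch import load_file

# Recreate the architecture exactly as in main.py (including the resized
# embeddings for the added [PAD] token), then load the merged weights into it.
base = "mistralai/Mistral-7B-v0.3"
tokenizer = transformers.AutoTokenizer.from_pretrained(base)
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = transformers.AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))

# Weights previously merged from the SHARDED_STATE_DICT checkpoint, e.g. with
# `accelerate merge-weights output-minimal/checkpoint-<step>/pytorch_model_fsdp_0 merged/`.
state_dict = load_file("merged/model.safetensors")
model.load_state_dict(state_dict)
model.to("cuda").eval()

for prompt in ["Two", "Two times", "Two times 10", "Two times 10 equals"]:
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=16,
        do_sample=False,
    )
    completion = tokenizer.decode(outputs[0, encoding.input_ids.shape[1]:], skip_special_tokens=True)
    print(f"{prompt!r} -> {completion!r}")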

Minimal example

main.py
from typing import cast

import accelerate
import datasets
import torch
import transformers
from torch.distributed import fsdp


class ValidCallback(transformers.TrainerCallback):
    def __init__(self, tokenizer: transformers.PreTrainedTokenizerBase, dataset: datasets.Dataset) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.dataset = dataset

    def on_epoch_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)

    @torch.no_grad()
    def run(self, model: transformers.PreTrainedModel) -> None:
        model.eval()

        for batch in self.dataset.iter(batch_size=7):
            encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)

            with fsdp.FullyShardedDataParallel.summon_full_params(model):
                outputs = model.generate(
                    inputs=encoding.input_ids,
                    attention_mask=encoding.attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=16,
                    do_sample=False,
                )

            predictions = self.tokenizer.batch_decode(
                outputs[:, encoding.input_ids.shape[1] :],  # Skip the returned prompt.
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            if accelerate.PartialState().is_main_process:
                print(predictions)


def main() -> None:
    # Load model and tokenizer.
    checkpoint = "mistralai/Mistral-7B-v0.3"
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.padding_side = "left"
    if not tokenizer.pad_token:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    model.resize_token_embeddings(len(tokenizer))

    # Load and prepare a toy dataset.
    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], max_length=32, padding="max_length", truncation=True)
        tokenized["labels"] = cast(list, tokenized["input_ids"]).copy()
        return tokenized

    train_dataset = datasets.Dataset.from_dict({"text": ["Two times 10 equals 20."] * 100})
    valid_dataset = datasets.Dataset.from_dict(
        {"text": ["Two", "Two times", "Two times 10", "Two times 10 equals", "Two times 10 equals 20."]}
    )
    train_dataset = train_dataset.map(
        tokenize_function, batched=True, remove_columns=list(train_dataset.features)
    )

    # Train.
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_dataset,
        args=transformers.TrainingArguments(
            output_dir="./output-minimal",
            save_strategy="steps",
            save_steps=1_000_000,
            overwrite_output_dir=True,
            remove_unused_columns=False,
            optim="adamw_torch_fused",
            bf16=True,
            learning_rate=1e-2,
            num_train_epochs=100,
            per_device_train_batch_size=1,
            ddp_timeout=9999999,
            report_to=[],
        ),
        callbacks=[
            ValidCallback(tokenizer, valid_dataset),
        ],
    )
    trainer.train()


if __name__ == "__main__":
    main()

fsdp.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

I run my code on Slurm, using this command:

srun bash -c "accelerate launch \
    --config_file fsdp.yaml \
    --main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \
    --main_process_port 6000 \
    --machine_rank \$SLURM_PROCID \
    main.py"

Expected behavior

I would expect the checkpoint saved on disk to produce the same outputs as those shown by the script after training.

@pandrei7 pandrei7 added the bug label Oct 31, 2024
@Rocketknight1
Member

cc @muellerzr @SunMarc!

@SunMarc
Member

SunMarc commented Nov 5, 2024

Thanks for the nice report! This is indeed very strange behavior. Could you try to see if you get the same model at the end with/without the callback. At first glance, it looks like with fsdp.FullyShardedDataParallel.summon_full_params(model) is the potential culprit. Could you try calling it on its own in on_epoch_end, without the generation?
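
Something along these lines, just to isolate the context manager (a rough sketch, not tested):

    def on_epoch_end(self, args, state, control, **kwargs) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = kwargs["model"]
        # Enter and exit the context without generating anything, to check
        # whether summon_full_params on its own already corrupts the checkpoint.
        with fsdp.FullyShardedDataParallel.summon_full_params(model):
            pass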

@alexandru-dinu

alexandru-dinu commented Nov 5, 2024

Hey @SunMarc! Just a note re:

Could you try to see if you get the same model at the end with/without the callback.

I am following this issue and also replied on the Hugging Face forum. TL;DR: when unsharding the model, only the *.safetensors files differ between runs with and without the callback -- so we don't get the same model.
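
In tensor terms, the check amounts to something like this (a sketch; the file names are illustrative):

import torch
from safetensors.torch import load_file

# Compare the unsharded weights of the two runs, tensor by tensor.
with_cb = load_file("run-with-callback/model-00001-of-00003.safetensors")
without_cb = load_file("run-without-callback/model-00001-of-00003.safetensors")

for name in with_cb:
    if not torch.equal(with_cb[name], without_cb[name]):
        print("Differs:", name)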

@pandrei7
Author

pandrei7 commented Nov 6, 2024

Hi @SunMarc! Thanks a lot for looking into this!

I confirm that I get different model weights depending on whether I use the callback or not. All three safetensors files show differences under diff.

I tried to run the generation in on_epoch_end without calling summon_full_params, but I get this error when I reach model.generate:

RuntimeError: 'weight' must be 2-D

I assume this behaviour is expected, based on this comment. I hope this is what you were asking, but do tell if I got it wrong.

I looked a bit more into PyTorch's documentation for summon_full_params and tried setting writeback=False, just to make sure. But it had no effect: predictions after training look fine, but the checkpoint is wrong.
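
For reference, the writeback=False attempt was just this variation of the generation block inside the callback (a sketch):

# writeback=False should, in principle, stop FSDP from writing the
# materialized full parameters back to the local shards on exit.
with fsdp.FullyShardedDataParallel.summon_full_params(model, writeback=False):
    outputs = model.generate(
        inputs=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        pad_token_id=self.tokenizer.eos_token_id,
        max_new_tokens=16,
        do_sample=False,
    )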

github-actions bot commented Dec 2, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@SunMarc
Member

SunMarc commented Dec 2, 2024

cc @XuehaiPan if you have some time!

@XuehaiPan
Contributor

    def on_epoch_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)

    @torch.no_grad()
    def run(self, model: transformers.PreTrainedModel) -> None:
        model.eval()

        for batch in self.dataset.iter(batch_size=7):
            encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)

            with fsdp.FullyShardedDataParallel.summon_full_params(model):
                outputs = model.generate(
                    inputs=encoding.input_ids,
                    attention_mask=encoding.attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=16,
                    do_sample=False,
                )

            predictions = self.tokenizer.batch_decode(
                outputs[:, encoding.input_ids.shape[1] :],  # Skip the returned prompt.
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            if accelerate.PartialState().is_main_process:
                print(predictions)

Hi @pandrei7, I wonder whether you have tried removing the model.eval() statement, or switching the model back to training mode after validation?
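
For instance, the second option could look like this in the callback (a sketch):

    def on_epoch_end(self, args, state, control, **kwargs) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)
        model.train()  # Switch back to training mode after validation.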

@alexandru-dinu

alexandru-dinu commented Dec 3, 2024

Hey @XuehaiPan, thanks! I tried both options:

  1. removing model.eval() call
  2. keeping model.eval() and adding model.train() at the end of the callback

Unfortunately, both options give corrupted checkpoints. Here's the inference on the dummy dataset from @pandrei7's example:

  1. ['.20 0 0 20 20 0 ', '.0.20 0 0 20 0 0', '0 0 0 0 0 0 0 0', '0 0 0 0 0 0 0 0 ', '0 0 0 0 0 0 0 0 ']
  2. ['times 0...... ', '0.......', 'equals.....0 equals. ', '.....0 equals..', '...0 equals....']

The correct inference result should be:

"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."

@pandrei7
Author

pandrei7 commented Dec 4, 2024

Hi @XuehaiPan! Thanks for responding.

I confirm what @alexandru-dinu said: I tried both approaches, but the issue persists. Predictions are correct while training, but the checkpoint I load from disk generates incorrect completions.

I'm wondering if this might be an issue with the way I use PyTorch... Do you think I should try asking there as well?

github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 7, 2025