Inference with FSDP during training affects checkpoints #34530
Comments
cc @muellerzr @SunMarc!
Thanks for the nice report! This is indeed very strange behavior. Could you try to see if you get the same model at the end with/without the callback? At first glance, it looks like `with fsdp.FullyShardedDataParallel.summon_full_params(model)` is the potential culprit. Could you try to just call it alone in the callback?
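Something like this hypothetical control callback (not code from this thread) would isolate the context manager from the generation:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import TrainerCallback


class SummonOnlyCallback(TrainerCallback):
    """Control experiment: enter summon_full_params without generating,
    to check whether the context manager alone corrupts the checkpoint."""

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        with FSDP.summon_full_params(model):
            pass  # gather and immediately release the full parameters
```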
Hey @SunMarc! Just a note: I am following this issue and also replied on the HuggingFace forum. TL;DR: when unsharding the model, only the […]
Hi @SunMarc! Thanks a lot for looking into this! I confirm that I get different model weights depending on whether I use the callback or not. All three Safetensors files show up with […]. I tried to run the generation in […]. I assume this behaviour is expected, based on this comment. I hope this is what you were asking, but do tell if I got it wrong. I looked a bit more into PyTorch's documentation for `summon_full_params`.
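For what it's worth, the context manager exposes a `writeback` flag; a minimal sketch of a read-only variant I could try (the helper name is mine, not from my script):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def generate_unsharded(model, inputs, **gen_kwargs):
    """Illustrative helper: gather the full parameters only for generation.
    With writeback=False, any modification of the gathered parameters is
    discarded on exit instead of being written back into the local shards."""
    with torch.no_grad(), FSDP.summon_full_params(model, writeback=False):
        return model.generate(**inputs, **gen_kwargs)
```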
cc @XuehaiPan if you have some time!
Hi @pandrei7, I wonder, have you ever tried to remove the […]?
Hey @XuehaiPan, thanks! I tried both options:
Unfortunately, both options give corrupted checkpoints. Here's the inference on the dummy dataset from @pandrei7's example:
The correct inference result should be:
Hi @XuehaiPan! Thanks for responding. I confirm what @alexandru-dinu said: I tried both approaches, but the issue persists. Predictions are correct while training, but the checkpoint I load from disk generates incorrect completions. I'm wondering if this might be an issue with the way I use PyTorch... Do you think I should try asking there as well?
System Info
Output from `transformers-cli env`:
Relevant environment and library versions:
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Hello! I'm running into an issue with checkpoints saved when training an LLM with FSDP and the default HuggingFace trainer, if I also do inference during training. I provide code at the end of this post for clarity.
I also asked this on the forum before coming here, but I haven't found a solution yet.
What I'm trying to achieve
I want to write a callback to monitor model outputs on a validation set throughout the training process. This requires doing inference with `model.generate()`. Since I'm also using FSDP, I need to summon all weights on a single device, as described in this GitHub issue.
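The callback is conceptually something like the sketch below (prompt text, names, and generation settings are illustrative, not the exact code in my main.py):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import TrainerCallback


class GenerationMonitorCallback(TrainerCallback):
    """Illustrative sketch: generate completions for a few fixed prompts
    whenever the Trainer runs evaluation."""

    def __init__(self, tokenizer, prompts, max_new_tokens=20):
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.max_new_tokens = max_new_tokens

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        model.eval()
        # Gather the full (unsharded) parameters so generate() works under FSDP.
        with FSDP.summon_full_params(model):
            for prompt in self.prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(args.device)
                outputs = model.generate(**inputs, max_new_tokens=self.max_new_tokens)
                if state.is_world_process_zero:
                    print(self.tokenizer.decode(outputs[0], skip_special_tokens=True))
```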
My issue
The callback I provide below seems to work fine for evaluation, but it affects the checkpoints that get saved. Specifically, when unsharding the final checkpoint and trying to replicate the results I see from my training script, I get different, much worse results from the checkpoint.
To test this, I trained an LLM to memorize a simple phrase: "Two times 10 equals 20." At the end of training, my callback reports the completions I expect, meaning the model trained well. However, if I load the checkpoint from disk and feed it the same prompts, I get this:
This does not happen if I don't run the callback during training. If I remove it, the checkpoint produced outputs the expected results:
To make extra sure, I also tried this experiment with DDP instead of FSDP (I removed the summon instruction). The DDP checkpoint is correct regardless of whether I use my callback or not.
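For reference, a minimal way to sanity-check a saved checkpoint from disk looks roughly like this (the checkpoint path and prompt are placeholders, not my actual values):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "outputs/checkpoint-final"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir).eval()

inputs = tokenizer("Two times 10 equals", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```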
I believe this points to `summon_full_params` being the problem. Do you think this could be a problem with the library, or maybe with my implementation? Any ideas or advice? Thank you!

Minimal example
main.py
fsdp.yaml
I run my code on Slurm, using this command:
Expected behavior
I would expect the checkpoint saved on disk to produce the same outputs as those shown by the script after training.