[Draft] Fix optimizer state_dict compatibility between Apex's fused_adam and distributed_fused_adam #10750
+33
−2
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
The use case I am supporting is the ability to take model checkpoints trained with fused_adam optimizer and then resume training with Apex's distributed_fused_adam while maintaining optimizer state.
There is mismatch in the optimizer state_dict where MoE-specific keys, which are ShardedTensorFactory objects and not ShardedTensor's, are not caught in the rename_fp32_params function, causing the moe keys to fall under the missing and unexpected keys when loading the optimizer state_dict.
Additionally, when resuming training from a model using fused_adam, state_dict['param_groups'] does not contain any keys aside from params. The param_groups need to be pulled from state_dict['optimizer'].
I override load_state_dict in NeMo's MegatronDistributedFusedAdam and manipulate the optimizer's state_dict if we see state_dict['param_groups'] only contains params key.
This behavior can be reproduced by training any MoE model with fused_adam and then attempt to resume training with distributed_fused_adam.
Additionally,
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Collection: [Note which collection this PR will affect]
Changelog
Usage
# Add a code snippet demonstrating how to use this
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information