[Question] Does TorchRec support checkpointing (load/save)? #2534

Open
JacoCheung opened this issue Nov 4, 2024 · 1 comment

Comments

@JacoCheung

Hi team, I would like to know how to load and save a sharded embedding collection via state_dict. Specifically:

  1. How many files should I save? Should each rank save an exclusive shard file, or should a single rank collectively gather the whole embedding and store it as one file? How should I handle the case where both DP and MP are applied?

  2. If each rank maintains a shard file, how can I load and re-shard in a new distributed environment where the number of GPUs differs from the one the model was saved in?

  3. If there is a single saved file, how should I load and re-shard, especially in a multi-node environment?

It would be very helpful if anyone could provide sample code. Thanks!

@iamzainhuda
Contributor

  1. You should have one checkpoint per rank, since we do not collectively gather the whole embedding. If you wanted to, you could gather it and then reconstruct the sharded state dict from the original sharding plan, although I wouldn't recommend this. You should be able to use the torch.distributed.checkpoint utilities for TorchRec models (see the first sketch after this list).

  2. For changing the number of GPUs, you would need to understand how the sharding changes. Am I correct in understanding that you want to go from a model sharded on 8 GPUs to loading it onto 16 GPUs? Resharding is the important step here, and you would have to handle it yourself before you load; TorchRec doesn't have any utilities for this.

  3. You can broadcast the parameters/state to the other ranks as you load, as a pre_load_state_dict_hook on top of DMP (see the second sketch after this list).
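
To make 1 concrete, here is a minimal sketch of the save/load flow with torch.distributed.checkpoint (DCP, using the `dcp.save`/`dcp.load` API from recent PyTorch releases). It assumes `model` is a DistributedModelParallel instance and the process group is already initialized; the checkpoint path is illustrative:

```python
import torch.distributed.checkpoint as dcp

CKPT_DIR = "/tmp/torchrec_ckpt"  # illustrative path

# Save: every rank participates, and each rank writes only the shards it
# owns, so the checkpoint is naturally "one set of files per rank".
state_dict = model.state_dict()
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter(CKPT_DIR))

# Load: construct the sharded model for the current world size first, then
# load in place. DCP resolves each rank's local shards against the saved
# checkpoint and fills them in.
state_dict = model.state_dict()
dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CKPT_DIR))
model.load_state_dict(state_dict)
```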
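
And for 3, a rough sketch of the broadcast-on-load idea via the public register_load_state_dict_pre_hook API. The hook body is my assumption of what you'd typically want (broadcast only the dense, replicated tensors from rank 0 and skip ShardedTensors, since each rank owns its own shards); it is not a TorchRec utility:

```python
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor

def broadcast_dense_state(module, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    # Assumes rank 0 holds the authoritative copy of the dense (replicated)
    # parameters; sharded tensors are skipped because each rank loads only
    # the shards it owns. Broadcast happens in place on every rank.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor) and not isinstance(value, ShardedTensor):
            dist.broadcast(value, src=0)

# `model` is the DMP-wrapped module; the hook runs before load_state_dict
# copies values into the module's parameters.
model.register_load_state_dict_pre_hook(broadcast_dense_state)
model.load_state_dict(checkpoint_state_dict)  # state dict assumed loaded earlier
```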
