Hi team, I would like to know how to save (dump) and load a sharded embedding collection via `state_dict`. Basically:

1. How many files should I save? Should each rank write its own shard file, or should a single rank gather the whole embedding and store it as one file? How should I handle the case where both DP and MP are applied?
2. If each rank keeps its own shard file, how can I load and re-shard in a new distributed environment where the number of GPUs differs from the one the model was saved with?
3. If there is one saved file, how should I load and re-shard it, especially in a multi-node environment?

It would be even more helpful if anyone could provide sample code! Thanks!
You should save one checkpoint per rank, since we do not collectively gather the whole embedding. If you wanted to, you could gather it yourself and then reconstruct the sharded state dict using the original sharding plan, although I wouldn't recommend this. You should be able to use the `torch.distributed.checkpoint` utilities for TorchRec models.
For changing the number of GPUs, you would need to understand how the sharding changes. Am I correct in understanding that you want to go from a model sharded on 8 GPUs to loading it onto 16 GPUs? Resharding would be necessary there, and you would have to do it yourself before loading; TorchRec doesn't have any utilities for this.
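To make "understanding how the sharding changes" concrete, here is a pure-Python sketch of a row-wise resharding plan. It assumes each table is split row-wise into contiguous, ceil-divided blocks (rank `r` owns rows `[r*block, (r+1)*block)`); TorchRec's planner may instead choose table-wise, column-wise, or other layouts, and its exact block sizes may differ, so treat `row_boundaries`, `reshard_plan`, and `reshard` as hypothetical helpers. For each new rank, the plan lists which slice of which old shard it must read.

```python
from typing import List, Sequence, Tuple

# (old_rank, start row within old shard, end row within old shard, offset within new shard)
Piece = Tuple[int, int, int, int]


def row_boundaries(num_rows: int, world_size: int) -> List[Tuple[int, int]]:
    """[start, end) global row range per rank under an assumed ceil-divided split."""
    block = -(-num_rows // world_size)  # ceil division
    return [
        (min(r * block, num_rows), min((r + 1) * block, num_rows))
        for r in range(world_size)
    ]


def reshard_plan(num_rows: int, old_ws: int, new_ws: int) -> List[List[Piece]]:
    """For every new rank, list the old-shard slices that overlap its row range."""
    old = row_boundaries(num_rows, old_ws)
    new = row_boundaries(num_rows, new_ws)
    plan: List[List[Piece]] = []
    for n_start, n_end in new:
        pieces: List[Piece] = []
        for old_rank, (o_start, o_end) in enumerate(old):
            lo, hi = max(n_start, o_start), min(n_end, o_end)
            if lo < hi:  # the two ranges overlap
                pieces.append((old_rank, lo - o_start, hi - o_start, lo - n_start))
        plan.append(pieces)
    return plan


def reshard(old_shards: Sequence[Sequence], num_rows: int, new_ws: int) -> List[list]:
    """Rebuild new shards from old ones; pieces arrive in row order, so extend() suffices."""
    plan = reshard_plan(num_rows, len(old_shards), new_ws)
    new_shards: List[list] = []
    for pieces in plan:
        shard: list = []
        for old_rank, src_start, src_end, _ in pieces:
            shard.extend(old_shards[old_rank][src_start:src_end])
        new_shards.append(shard)
    return new_shards
```

For example, going from 8 to 16 ranks with 100 rows, new rank 1 owns global rows 7 to 13, so it needs local rows 7 to 12 of old shard 0 and local row 0 of old shard 1. In a real job, the old shards would be loaded from the per-rank files rather than held in memory on one process.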
For the single-file case, you can broadcast the parameters/state to the other ranks as you load, for example in a `pre_load_state_dict_hook` on top of DMP (`DistributedModelParallel`).
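A minimal sketch of the broadcast approach for a single consolidated file. It assumes the process group is already initialized, that the file was written with `torch.save` as a flat dict of full (unsharded) tensors, and that tables are split row-wise into ceil-divided blocks. `load_and_broadcast` and `local_rows` are hypothetical helpers; `dist.broadcast_object_list` is the real PyTorch collective. The pre-load hook that would call this logic is not shown.

```python
import torch
import torch.distributed as dist


def load_and_broadcast(path: str, src: int = 0) -> dict:
    """Read a full state_dict on rank `src` only and broadcast it to every rank."""
    # Only the source rank touches the file; the others pass a placeholder.
    if dist.get_rank() == src:
        payload = [torch.load(path, map_location="cpu")]
    else:
        payload = [None]
    # Pickles the dict on `src` and fills in payload[0] on all other ranks.
    dist.broadcast_object_list(payload, src=src)
    return payload[0]


def local_rows(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Slice the rows this rank owns under an assumed ceil-divided row-wise split."""
    num_rows = full.shape[0]
    block = -(-num_rows // world_size)  # ceil division
    start = min(rank * block, num_rows)
    end = min(start + block, num_rows)
    # clone() so the local shard does not keep the full table alive.
    return full[start:end].clone()
```

Broadcasting the whole dict as one pickled object is the simplest choice, but every rank briefly holds every full table in CPU memory; for large tables, broadcasting one tensor at a time with `dist.broadcast` and slicing immediately keeps peak memory lower.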