
Request for developer guide: multi-node TPU distributed training with JAX #20356

Open
rivershah opened this issue Oct 15, 2024 · 0 comments

@rivershah

Multi-node TPU Training with JAX

The multi-GPU distributed training guide for JAX is helpful, but it's unclear how to extend that approach to multi-node TPU setups.

We are currently doing data-parallel training on a TPU v4-256 slice with tf.distribute.cluster_resolver.TPUClusterResolver and tf.distribute.TPUStrategy. We are transitioning to JAX and need the equivalent approach.
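For context, our current setup looks roughly like the following (a sketch only; the TPU name and the model are hypothetical placeholders):

```python
import tensorflow as tf

# Hypothetical TPU name; the real one comes from our deployment config.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu-v4-256")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Data-parallel strategy across all TPU cores in the slice.
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Placeholder model; the real one is much larger.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
```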

Specifically:

  1. How to configure the TPU runtime for JAX (see the sketch after this list).
  2. How to handle cluster resolution for TPUs, similar to TPUClusterResolver (also covered by the sketch below).
  3. Examples of multi-node TPU data-parallel training with JAX (our current attempt is sketched at the end).
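For items 1 and 2, our current understanding (which we would like confirmed) is that jax.distributed.initialize() plays both roles on Cloud TPU: it auto-detects the coordinator address, process count, and process index, so no explicit cluster resolver appears to be needed. A minimal sketch, run on every host in the slice:

```python
import jax

# On Cloud TPU VMs this auto-detects the coordinator address and the
# process topology from the TPU environment; no arguments needed.
jax.distributed.initialize()

print("process", jax.process_index(), "of", jax.process_count())
print("local devices:", jax.local_device_count())   # chips attached to this host
print("global devices:", jax.device_count())        # chips across the whole slice
```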

Detailed examples would be helpful (our own rough attempt at item 3 is included below for reference). Thank you.
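For reference, the item-3 pattern we are experimenting with uses the Keras 3 distribution API on the JAX backend. This is a sketch under our own assumptions (placeholder model, synthetic data), and we are not certain it is the intended multi-host pattern, which is exactly what a guide would clarify:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import numpy as np
import keras

jax.distributed.initialize()  # as in the sketch above, on every host

# DataParallel with no arguments replicates model weights across all
# available devices and shards the data batch along the first axis.
keras.distribution.set_distribution(keras.distribution.DataParallel())

# Placeholder model and synthetic data for illustration only.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(1024, 8).astype("float32")
y = np.random.rand(1024, 1).astype("float32")
model.fit(x, y, batch_size=256, epochs=1)
```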

@mehtamansi29 added the type:support and keras-team-review-pending labels on Oct 16, 2024