Request for developer guide: multi-node TPU distributed training with JAX #20356
Labels
- `keras-team-review-pending`: pending review by a Keras team member.
- `type:support`: user is asking for help / an implementation question; Stack Overflow would be better suited.
Multi-node TPU Training with JAX
The multi-GPU JAX training guide is helpful, but it's unclear how to extend this to multi-node TPU setups.
We currently train on a TPU v4-256 using `tf.distribute.cluster_resolver.TPUClusterResolver` and `tf.distribute.TPUStrategy` for data-parallel training. We're transitioning to JAX and need the equivalent approach. Specifically:
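For the transition described above, a minimal sketch of data-parallel training in plain JAX may be useful. It assumes the standard SPMD model on Cloud TPU pod slices, where every host runs the same script; the `MULTIHOST` environment-variable guard is hypothetical, added only so the sketch also runs single-process on CPU:

```python
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# On a multi-host pod slice (e.g. v4-256), every host runs this same
# script. jax.distributed.initialize() discovers the other hosts
# automatically on Cloud TPU, playing roughly the role that
# TPUClusterResolver plays for tf.distribute. MULTIHOST is a
# hypothetical flag guarding the call so the sketch also runs
# single-process on CPU.
if os.environ.get("MULTIHOST"):
    jax.distributed.initialize()

# One mesh axis spanning every device on every host; the global batch
# is split along it, so it must be divisible by jax.device_count().
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
replicated = NamedSharding(mesh, PartitionSpec())

@jax.jit
def train_step(w, x, y):
    # Data parallelism: parameters replicated, batch sharded along
    # "data"; jit inserts the cross-device gradient reduction.
    loss, grads = jax.value_and_grad(
        lambda w: jnp.mean((x @ w - y) ** 2)
    )(w)
    return w - 0.1 * grads, loss

# Toy linear-regression step on a sharded batch of 8 examples.
w = jax.device_put(jnp.zeros((4,)), replicated)
x = jax.device_put(jnp.ones((8, 4)), batch_sharding)
y = jax.device_put(jnp.ones((8,)), batch_sharding)
w, loss = train_step(w, x, y)
```

On a single CPU the mesh degenerates to one device, so the same code can be smoke-tested locally before launching on the pod slice.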
- What replaces multi-host cluster setup in JAX (the role played by `TPUClusterResolver`)?

Detailed examples would be helpful. Thank you.