diff --git a/keras_core/backend/tensorflow/distribution_lib.py b/keras_core/backend/tensorflow/distribution_lib.py index 0eae4cc1d..2d62419b4 100644 --- a/keras_core/backend/tensorflow/distribution_lib.py +++ b/keras_core/backend/tensorflow/distribution_lib.py @@ -25,6 +25,11 @@ def list_devices(device_type=None): device_type = ( device_type.upper() if device_type else dtensor.preferred_device_type() ) + + # DTensor doesn't support getting global devices, even when the Mesh is + # known. Use TF API instead to get global devices. Coordinator service is + # enabled by default with DTensor, so that list_logical_devices() returns + # a list of global devices. More context can be found in b/254911601. return tf.config.list_logical_devices(device_type=device_type)