Disable XLA with TensorFlow determinism #20315

Merged
9 changes: 9 additions & 0 deletions keras/src/trainers/trainer.py
@@ -1122,5 +1122,14 @@ def model_supports_jit(model):
                 return False
     # XLA not supported by some layers
     if all(x.supports_jit for x in model._flatten_layers()):
+        if backend.backend() == "tensorflow":
+            from tensorflow.python.framework.config import (
+                is_op_determinism_enabled,
+            )
+
+            if is_op_determinism_enabled():
+                # Disable XLA when op determinism is enabled, since not
+                # all ops are supported by XLA in deterministic mode.
+                return False
         return True
     return False
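
From user code, the flag this change checks is reachable through the public tf.config API. A minimal sketch (not part of the diff) of the resulting behavior, assuming TF >= 2.9 (where enable_op_determinism was added), Keras 3, and a setup where XLA would otherwise be selected:

    import tensorflow as tf
    import keras

    # Request deterministic ops globally; this sets the same flag that
    # is_op_determinism_enabled() reads inside model_supports_jit().
    tf.config.experimental.enable_op_determinism()

    model = keras.Sequential([keras.layers.Dense(3)])
    model.compile(optimizer="sgd", loss="mse")  # jit_compile defaults to "auto"

    # With determinism enabled, "auto" now resolves to False instead of True.
    print(model.jit_compile)  # expected: False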
21 changes: 21 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -1718,6 +1718,27 @@ def call(self, x, training=None):
         for v in model._compile_loss.variables:
             self.assertAllClose(v, 0.0)
 
+    @pytest.mark.skipif(
+        backend.backend() != "tensorflow",
+        reason="This test is only applicable to TensorFlow.",
+    )
+    @pytest.mark.requires_trainable_backend
+    def test_jit_compile_with_tf_determinism(self):
+        from tensorflow.python.framework.config import disable_op_determinism
+        from tensorflow.python.framework.config import enable_op_determinism
+
+        enable_op_determinism()
+
+        model = ExampleModel(units=3)
+        model.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+        )
+
+        self.assertFalse(model.jit_compile)
+        disable_op_determinism()
+
+
 class TrainerDistributeTest(testing.TestCase):
     @pytest.mark.skipif(
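
One design note on the test: enable_op_determinism() flips process-global state, and the disable_op_determinism() call at the end only runs if every assertion before it passes. A try/finally (or an equivalent pytest fixture) would restore the default even on failure; a minimal sketch of that pattern, reusing the same internal imports:

    from tensorflow.python.framework.config import disable_op_determinism
    from tensorflow.python.framework.config import enable_op_determinism

    enable_op_determinism()
    try:
        # ... compile the model and assert on model.jit_compile ...
        pass
    finally:
        # Restore the default even if an assertion above fails, so the
        # global flag does not leak into unrelated tests.
        disable_op_determinism()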