Skip to content

Commit

Permalink
corrected jax expression, numpy function
Browse files Browse the repository at this point in the history
  • Loading branch information
sqali committed Sep 21, 2023
1 parent 7f76459 commit 5efabf5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion keras_core/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,4 @@ def rsqrt(x):


def erf(x):
return jnp.erf(x)
return jax.lax.erf(x)
2 changes: 1 addition & 1 deletion keras_core/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,4 @@ def rsqrt(x):


def erf(x):
return scipy.special.erf(x)
return np.array(scipy.special.erf(x))
6 changes: 3 additions & 3 deletions keras_core/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def test_erf_operation_basic(self):
)

# Output from the erf operation in keras_core
output_from_erf_op = kmath.erf(sample_values).numpy()
output_from_erf_op = kmath.erf(sample_values)

# Assert that the outputs are close
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)
Expand All @@ -860,7 +860,7 @@ def test_erf_operation_dtype(self):
expected_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
sample_values
)
output_from_erf_op = kmath.erf(sample_values).numpy()
output_from_erf_op = kmath.erf(sample_values)
self.assertAllClose(expected_output, output_from_erf_op, atol=1e-5)

def test_erf_operation_edge_cases(self):
Expand All @@ -869,7 +869,7 @@ def test_erf_operation_edge_cases(self):
expected_edge_output = (2 / np.sqrt(np.pi)) * np.vectorize(math.erf)(
edge_values
)
output_from_edge_erf_op = kmath.erf(edge_values).numpy()
output_from_edge_erf_op = kmath.erf(edge_values)
self.assertAllClose(
expected_edge_output, output_from_edge_erf_op, atol=1e-5
)

0 comments on commit 5efabf5

Please sign in to comment.