Skip to content

Commit

Permalink
update numpy test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Sep 20, 2023
1 parent d4cc4e7 commit 08d5979
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
11 changes: 7 additions & 4 deletions keras_core/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ def add(x1, x2):

def bincount(x, weights=None, minlength=0):
if len(x.shape) == 2:
bincounts = [
jnp.bincount(arr, weights=weights, minlength=minlength)
for arr in list(x)
]
if weights is None:
bincounts = [jnp.bincount(arr, minlength=minlength) for arr in x]
else:
bincounts = [
jnp.bincount(arr, weights=w, minlength=minlength)
for arr, w in zip(x, weights)
]
return jnp.stack(bincounts)
return jnp.bincount(x, weights=weights, minlength=minlength)

Expand Down
11 changes: 7 additions & 4 deletions keras_core/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@ def average(x, axis=None, weights=None):

def bincount(x, weights=None, minlength=0):
if len(x.shape) == 2:
bincounts = [
np.bincount(arr, weights=weights, minlength=minlength)
for arr in list(x)
]
if weights is None:
bincounts = [np.bincount(arr, minlength=minlength) for arr in x]
else:
bincounts = [
np.bincount(arr, weights=w, minlength=minlength)
for arr, w in zip(x, weights)
]
return np.stack(bincounts)
return np.bincount(x, weights, minlength)

Expand Down
12 changes: 8 additions & 4 deletions keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,14 @@ def bincount(x, weights=None, minlength=0):
if weights is not None:
weights = convert_to_tensor(weights)
if len(x.shape) == 2:
bincounts = [
torch.bincount(arr, weights=weights, minlength=minlength)
for arr in list(x)
]
if weights is None:
bincounts = [torch.bincount(arr, minlength=minlength) for arr in x]
else:
bincounts = [
torch.bincount(arr, weights=w, minlength=minlength)
for arr, w in zip(x, weights)
]

return torch.stack(bincounts)
return torch.bincount(x, weights, minlength)

Expand Down
17 changes: 17 additions & 0 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,23 @@ def test_bincount(self):
knp.Bincount(weights=weights, minlength=minlength)(x),
np.bincount(x, weights=weights, minlength=minlength),
)
x = np.array([[1, 1, 2, 3, 2, 4, 4, 5]])
weights = np.array([[0, 0, 3, 2, 1, 1, 4, 2]])
expected_output = np.array([[0, 0, 4, 2, 5, 2]])
self.assertAllClose(
knp.bincount(x, weights=weights, minlength=minlength),
expected_output,
)
self.assertAllClose(
knp.Bincount(weights=weights, minlength=minlength)(x),
expected_output,
)
# test with weights=None
expected_output = np.array([[0, 2, 2, 1, 2, 1]])
self.assertAllClose(
knp.Bincount(weights=None, minlength=minlength)(x),
expected_output,
)

def test_broadcast_to(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down

0 comments on commit 08d5979

Please sign in to comment.