diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index eb3720a61..088ffcf8b 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -13,10 +13,13 @@ def add(x1, x2): def bincount(x, weights=None, minlength=0): if len(x.shape) == 2: - bincounts = [ - jnp.bincount(arr, weights=weights, minlength=minlength) - for arr in list(x) - ] + if weights is None: + bincounts = [jnp.bincount(arr, minlength=minlength) for arr in x] + else: + bincounts = [ + jnp.bincount(arr, weights=w, minlength=minlength) + for arr, w in zip(x, weights) + ] return jnp.stack(bincounts) return jnp.bincount(x, weights=weights, minlength=minlength) diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py index 01a924b70..4f5f86c78 100644 --- a/keras_core/backend/numpy/numpy.py +++ b/keras_core/backend/numpy/numpy.py @@ -134,10 +134,13 @@ def average(x, axis=None, weights=None): def bincount(x, weights=None, minlength=0): if len(x.shape) == 2: - bincounts = [ - np.bincount(arr, weights=weights, minlength=minlength) - for arr in list(x) - ] + if weights is None: + bincounts = [np.bincount(arr, minlength=minlength) for arr in x] + else: + bincounts = [ + np.bincount(arr, weights=w, minlength=minlength) + for arr, w in zip(x, weights) + ] return np.stack(bincounts) return np.bincount(x, weights, minlength) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index 85497413a..0d27c819d 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -248,10 +248,14 @@ def bincount(x, weights=None, minlength=0): if weights is not None: weights = convert_to_tensor(weights) if len(x.shape) == 2: - bincounts = [ - torch.bincount(arr, weights=weights, minlength=minlength) - for arr in list(x) - ] + if weights is None: + bincounts = [torch.bincount(arr, minlength=minlength) for arr in x] + else: + bincounts = [ + torch.bincount(arr, weights=w, minlength=minlength) + for arr, w in zip(x, weights) + ] + return torch.stack(bincounts) return torch.bincount(x, weights, minlength) diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py index 2ae7f12c6..ad612fd81 100644 --- a/keras_core/ops/numpy_test.py +++ b/keras_core/ops/numpy_test.py @@ -2699,6 +2699,23 @@ def test_bincount(self): knp.Bincount(weights=weights, minlength=minlength)(x), np.bincount(x, weights=weights, minlength=minlength), ) + x = np.array([[1, 1, 2, 3, 2, 4, 4, 5]]) + weights = np.array([[0, 0, 3, 2, 1, 1, 4, 2]]) + expected_output = np.array([[0, 0, 4, 2, 5, 2]]) + self.assertAllClose( + knp.bincount(x, weights=weights, minlength=minlength), + expected_output, + ) + self.assertAllClose( + knp.Bincount(weights=weights, minlength=minlength)(x), + expected_output, + ) + # test with weights=None + expected_output = np.array([[0, 2, 2, 1, 2, 1]]) + self.assertAllClose( + knp.Bincount(weights=None, minlength=minlength)(x), + expected_output, + ) def test_broadcast_to(self): x = np.array([[1, 2, 3], [3, 2, 1]])