diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 98526678305..f9a2f41a20f 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -799,12 +799,9 @@ def average_pool( data_format = standardize_data_format(data_format) padding = padding.lower() if any_symbolic_tensors((inputs,)): - return AveragePool( - pool_size, - strides, - padding, - data_format, - ).symbolic_call(inputs) + return operation_utils.compute_pooling_output_shape( + inputs.shape, pool_size, strides, padding, data_format + ) return backend.nn.average_pool( inputs, pool_size, strides, padding, data_format ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index f71f51e71c8..e43a6c51586 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -167,11 +167,11 @@ def test_average_pool(self): input_shape = (None, 3, 8) x = KerasTensor(input_shape) self.assertEqual( - knn.average_pool(x, 2, 1).shape, + knn.average_pool(x, 2, 1), (None, 7, 3) if data_format == "channels_last" else (None, 3, 7), ) self.assertEqual( - knn.average_pool(x, 2, 2, padding="same").shape, + knn.average_pool(x, 2, 2, padding="same"), (None, 4, 3) if data_format == "channels_last" else (None, 3, 4), ) @@ -181,7 +181,7 @@ def test_average_pool(self): input_shape = (None, 3, 8, None) x = KerasTensor(input_shape) self.assertEqual( - knn.average_pool(x, 2, 1).shape, + knn.average_pool(x, 2, 1), ( (None, 7, None, 3) if data_format == "channels_last" @@ -189,7 +189,7 @@ def test_average_pool(self): ), ) self.assertEqual( - knn.average_pool(x, 2, 2, padding="same").shape, + knn.average_pool(x, 2, 2, padding="same"), ( (None, 4, None, 3) if data_format == "channels_last" @@ -197,7 +197,7 @@ def test_average_pool(self): ), ) self.assertEqual( - knn.average_pool(x, (2, 2), (2, 2), padding="same").shape, + knn.average_pool(x, (2, 2), (2, 2), padding="same"), ( (None, 4, None, 3) if data_format == "channels_last" @@ -780,11 +780,11 @@ def test_average_pool(self): input_shape = (1, 3, 8) x = KerasTensor(input_shape) self.assertEqual( - knn.average_pool(x, 2, 1).shape, + knn.average_pool(x, 2, 1), (1, 7, 3) if data_format == "channels_last" else (1, 3, 7), ) self.assertEqual( - knn.average_pool(x, 2, 2, padding="same").shape, + knn.average_pool(x, 2, 2, padding="same"), (1, 4, 3) if data_format == "channels_last" else (1, 3, 4), ) @@ -794,15 +794,15 @@ def test_average_pool(self): input_shape = (1, 3, 8, 8) x = KerasTensor(input_shape) self.assertEqual( - knn.average_pool(x, 2, 1).shape, + knn.average_pool(x, 2, 1), (1, 7, 7, 3) if data_format == "channels_last" else (1, 3, 7, 7), ) self.assertEqual( - knn.average_pool(x, 2, 2, padding="same").shape, + knn.average_pool(x, 2, 2, padding="same"), (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), ) self.assertEqual( - knn.average_pool(x, (2, 2), (2, 2), padding="same").shape, + knn.average_pool(x, (2, 2), (2, 2), padding="same"), (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), )