From c32c3c63c8133798f05e510dfa306436f4503999 Mon Sep 17 00:00:00 2001 From: Adithya Kamath Date: Sun, 16 Jul 2023 02:42:27 +0530 Subject: [PATCH] Constant Initializer Error fixed (#479) * Constant Init tests * Comments fixed --- .../initializers/constant_initalizers_test.py | 43 +++++++++++++++++++ .../initializers/constant_initializers.py | 1 + 2 files changed, 44 insertions(+) create mode 100644 keras_core/initializers/constant_initalizers_test.py diff --git a/keras_core/initializers/constant_initalizers_test.py b/keras_core/initializers/constant_initalizers_test.py new file mode 100644 index 000000000..e1a7b43dd --- /dev/null +++ b/keras_core/initializers/constant_initalizers_test.py @@ -0,0 +1,43 @@ +import numpy as np + +from keras_core import backend +from keras_core import initializers +from keras_core import testing + + +class ConstantInitializersTest(testing.TestCase): + def test_zeros_initializer(self): + shape = (3, 3) + + initializer = initializers.Zeros() + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values).data + self.assertEqual(np_values, np.zeros(shape=shape)) + + self.run_class_serialization_test(initializer) + + def test_ones_initializer(self): + shape = (3, 3) + + initializer = initializers.Ones() + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values).data + self.assertEqual(np_values, np.ones(shape=shape)) + + self.run_class_serialization_test(initializer) + + def test_constant_initializer(self): + shape = (3, 3) + constant_value = 6.0 + + initializer = initializers.Constant(value=constant_value) + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values).data + self.assertEqual( + np_values, np.full(shape=shape, fill_value=constant_value) + ) + + self.run_class_serialization_test(initializer) diff --git a/keras_core/initializers/constant_initializers.py b/keras_core/initializers/constant_initializers.py index 62b4b7243..9dda16de1 100644 --- a/keras_core/initializers/constant_initializers.py +++ b/keras_core/initializers/constant_initializers.py @@ -30,6 +30,7 @@ def __init__(self, value=0.0): self.value = float(value) def __call__(self, shape, dtype=None): + dtype = standardize_dtype(dtype) return self.value * ops.ones(shape=shape, dtype=dtype) def get_config(self):