Skip to content

Commit

Permalink
Revert "Replace RandomGenerator with SeedGenerator (#2150)" (#2161)
Browse files Browse the repository at this point in the history
This reverts commit 365a675.
  • Loading branch information
sampathweb authored Nov 18, 2023
1 parent 365a675 commit e9b3d34
Show file tree
Hide file tree
Showing 61 changed files with 190 additions and 366 deletions.
5 changes: 2 additions & 3 deletions benchmarks/vectorized_jittered_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import JitteredResize
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -259,8 +258,8 @@ def test_consistency_with_old_impl(self):

# makes offsets fixed to (0.5, 0.5)
with unittest.mock.patch.object(
random,
"uniform",
layer._random_generator,
"random_uniform",
return_value=tf.convert_to_tensor([[0.5, 0.5]]),
):
output = layer(image)
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/vectorized_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import Mosaic
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -102,7 +101,7 @@ def _batch_augment(self, inputs):
minval=0,
maxval=batch_size,
dtype=tf.int32,
seed=random.make_seed(seed=self._seed_generator),
seed=self._random_generator.make_legacy_seed(),
)
# concatenate the batches with permutation order to get all 4 images of
# the mosaic
Expand Down
5 changes: 1 addition & 4 deletions benchmarks/vectorized_random_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -73,9 +72,7 @@ def get_random_transformation(self, image=None, **kwargs):
h_diff = image_shape[H_AXIS] - self.height
w_diff = image_shape[W_AXIS] - self.width
dtype = image_shape.dtype
rands = random.uniform(
[2], 0, dtype.max, dtype, seed=self._seed_generator
)
rands = self._random_generator.random_uniform([2], 0, dtype.max, dtype)
h_start = rands[0] % (h_diff + 1)
w_start = rands[1] % (w_diff + 1)
return {"top": h_start, "left": w_start}
Expand Down
13 changes: 6 additions & 7 deletions benchmarks/vectorized_random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomFlip
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -103,11 +102,11 @@ def get_random_transformation(self, **kwargs):
flip_vertical = False
if self.horizontal:
flip_horizontal = (
random.uniform(shape=[], seed=self._seed_generator) > 0.5
self._random_generator.random_uniform(shape=[]) > 0.5
)
if self.vertical:
flip_vertical = (
random.uniform(shape=[], seed=self._seed_generator) > 0.5
self._random_generator.random_uniform(shape=[]) > 0.5
)
return {
"flip_horizontal": tf.cast(flip_horizontal, dtype=tf.bool),
Expand Down Expand Up @@ -237,14 +236,14 @@ def test_consistency_with_old_impl(self):
)

with unittest.mock.patch.object(
random,
"uniform",
layer._random_generator,
"random_uniform",
return_value=tf.convert_to_tensor([[0.6]]),
):
output = layer(image)
with unittest.mock.patch.object(
random,
"uniform",
old_layer._random_generator,
"random_uniform",
return_value=tf.convert_to_tensor(0.6),
):
old_output = old_layer(image)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/vectorized_random_hue.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, factor, value_range, seed=None, **kwargs):
self.seed = seed

def get_random_transformation(self, **kwargs):
invert = preprocessing_utils.random_inversion(self._seed_generator)
invert = preprocessing_utils.random_inversion(self._random_generator)
# We must scale self.factor() to the range [-0.5, 0.5]. This is because
# the tf.image operation performs rotation on the hue saturation value
# orientation. This can be thought of as an angle in the range
Expand Down
8 changes: 2 additions & 6 deletions benchmarks/vectorized_random_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomRotation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -123,11 +122,8 @@ def __init__(
def get_random_transformation(self, **kwargs):
min_angle = self.lower * 2.0 * np.pi
max_angle = self.upper * 2.0 * np.pi
angle = random.uniform(
shape=[1],
minval=min_angle,
maxval=max_angle,
seed=self._seed_generator,
angle = self._random_generator.random_uniform(
shape=[1], minval=min_angle, maxval=max_angle
)
return {"angle": angle}

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/vectorized_random_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _get_shear_amount(self, constraint):
if constraint is None:
return None

invert = preprocessing.random_inversion(self._seed_generator)
invert = preprocessing.random_inversion(self._random_generator)
return invert * constraint()

def augment_image(self, image, transformation=None, **kwargs):
Expand Down
7 changes: 2 additions & 5 deletions benchmarks/vectorized_random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomTranslation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -218,19 +217,17 @@ def augment_image(self, image, transformation, **kwargs):

def get_random_transformation(self, image=None, **kwargs):
batch_size = 1
height_translation = random.uniform(
height_translation = self._random_generator.random_uniform(
shape=[batch_size, 1],
minval=self.height_lower,
maxval=self.height_upper,
dtype=tf.float32,
seed=self._seed_generator,
)
width_translation = random.uniform(
width_translation = self._random_generator.random_uniform(
shape=[batch_size, 1],
minval=self.width_lower,
maxval=self.width_upper,
dtype=tf.float32,
seed=self._seed_generator,
)
return {
"height_translation": height_translation,
Expand Down
7 changes: 2 additions & 5 deletions benchmarks/vectorized_random_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomZoom
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -144,18 +143,16 @@ def __init__(
self.seed = seed

def get_random_transformation(self, image=None, **kwargs):
height_zoom = random.uniform(
height_zoom = self._random_generator.random_uniform(
shape=[1, 1],
minval=1.0 + self.height_lower,
maxval=1.0 + self.height_upper,
seed=self._seed_generator,
)
if self.width_factor is not None:
width_zoom = random.uniform(
width_zoom = self._random_generator.random_uniform(
shape=[1, 1],
minval=1.0 + self.width_lower,
maxval=1.0 + self.width_upper,
seed=self._seed_generator,
)
else:
width_zoom = height_zoom
Expand Down
7 changes: 2 additions & 5 deletions benchmarks/vectorized_randomly_zoomed_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tensorflow import keras

from keras_cv import core
from keras_cv.backend import random
from keras_cv.layers import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -110,20 +109,18 @@ def get_random_transformation(

new_width = crop_size[1] * tf.sqrt(aspect_ratio)

height_offset = random.uniform(
height_offset = self._random_generator.random_uniform(
(),
minval=tf.minimum(0.0, original_height - new_height),
maxval=tf.maximum(0.0, original_height - new_height),
dtype=tf.float32,
seed=self._seed_generator,
)

width_offset = random.uniform(
width_offset = self._random_generator.random_uniform(
(),
minval=tf.minimum(0.0, original_width - new_width),
maxval=tf.maximum(0.0, original_width - new_width),
dtype=tf.float32,
seed=self._seed_generator,
)

new_height = new_height / original_height
Expand Down
36 changes: 21 additions & 15 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,23 @@ def __init__(self, seed=None, **kwargs):
seed=seed, **kwargs
)
else:
self._current_seed = [seed, 0]
self._current_seed = [0, seed]

def next(self, ordered=True):
if keras_3():
return self._seed_generator.next(ordered=ordered)
else:
self._current_seed[1] += 1
self._current_seed[0] += 1
return self._current_seed[:]


def make_seed(seed=None):
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed_0, seed_1 = seed.next()
if seed_0 is None:
init_seed = seed_1
else:
init_seed = seed_0 + seed_1
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
return init_seed


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -75,7 +68,11 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
init_seed = make_seed(seed)
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -100,7 +97,12 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):


def shuffle(x, axis=0, seed=None):
init_seed = make_seed(seed)
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
Expand All @@ -110,7 +112,11 @@ def shuffle(x, axis=0, seed=None):


def categorical(logits, num_samples, dtype=None, seed=None):
init_seed = make_seed(seed)
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand Down
30 changes: 10 additions & 20 deletions keras_cv/layers/preprocessing/aug_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from keras_cv import layers
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -107,41 +106,32 @@ def _sample_from_dirichlet(self, alpha):
gamma_sample = tf.random.gamma(
shape=(),
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
seed=self._random_generator.make_legacy_seed(),
)
return gamma_sample / tf.reduce_sum(
gamma_sample, axis=-1, keepdims=True
)

def _sample_from_beta(self, alpha, beta):
sample_alpha = tf.random.gamma(
(),
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
(), alpha=alpha, seed=self._random_generator.make_legacy_seed()
)
sample_beta = tf.random.gamma(
(),
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
(), alpha=beta, seed=self._random_generator.make_legacy_seed()
)
return sample_alpha / (sample_alpha + sample_beta)

def _sample_depth(self):
return random.uniform(
return self._random_generator.random_uniform(
shape=(),
minval=self.chain_depth[0],
maxval=self.chain_depth[1] + 1,
dtype=tf.int32,
seed=self._seed_generator,
)

def _loop_on_depth(self, depth_level, image_aug):
op_index = random.uniform(
shape=(),
minval=0,
maxval=8,
dtype=tf.int32,
seed=self._seed_generator,
op_index = self._random_generator.random_uniform(
shape=(), minval=0, maxval=8, dtype=tf.int32
)
image_aug = self._apply_op(image_aug, op_index)
depth_level += 1
Expand Down Expand Up @@ -214,7 +204,7 @@ def _solarize(self, image):

def _shear_x(self, image):
x = tf.cast(self.severity_factor() * 0.3, tf.float32)
x *= preprocessing.random_inversion(self._seed_generator)
x *= preprocessing.random_inversion(self._random_generator)
transform_x = layers.RandomShear._format_transform(
[1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
)
Expand All @@ -224,7 +214,7 @@ def _shear_x(self, image):

def _shear_y(self, image):
y = tf.cast(self.severity_factor() * 0.3, tf.float32)
y *= preprocessing.random_inversion(self._seed_generator)
y *= preprocessing.random_inversion(self._random_generator)
transform_x = self._format_random_shear_transform(
[1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]
)
Expand All @@ -241,7 +231,7 @@ def _translate_x(self, image):
shape = tf.cast(tf.shape(image), tf.float32)
x = tf.cast(self.severity_factor() * shape[1] / 3, tf.float32)
x = tf.expand_dims(tf.expand_dims(x, axis=0), axis=0)
x *= preprocessing.random_inversion(self._seed_generator)
x *= preprocessing.random_inversion(self._random_generator)
x = tf.cast(x, tf.int32)

translations = tf.cast(
Expand All @@ -256,7 +246,7 @@ def _translate_y(self, image):
shape = tf.cast(tf.shape(image), tf.float32)
y = tf.cast(self.severity_factor() * shape[0] / 3, tf.float32)
y = tf.expand_dims(tf.expand_dims(y, axis=0), axis=0)
y *= preprocessing.random_inversion(self._seed_generator)
y *= preprocessing.random_inversion(self._random_generator)
y = tf.cast(y, tf.int32)

translations = tf.cast(
Expand Down
Loading

0 comments on commit e9b3d34

Please sign in to comment.