From c1d7955b011e9d142924b91d7fd63552f09cdf82 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 26 Sep 2024 14:53:12 -0700 Subject: [PATCH 01/35] Rebased phase 1 changes --- .../src/models/retinanet/anchor_generator.py | 44 ++- .../models/retinanet/anchor_generator_test.py | 29 +- .../src/models/retinanet/feature_pyramid.py | 373 ++++++++++++++++++ .../models/retinanet/feature_pyramid_test.py | 81 ++++ .../retinanet/retinanet_label_encoder.py | 270 +++++++++++++ .../retinanet/retinanet_label_encoder_test.py | 85 ++++ keras_hub/src/tests/test_case.py | 29 +- keras_hub/src/utils/tensor_utils.py | 106 +++++ keras_hub/src/utils/tensor_utils_test.py | 102 +++++ 9 files changed, 1096 insertions(+), 23 deletions(-) create mode 100644 keras_hub/src/models/retinanet/feature_pyramid.py create mode 100644 keras_hub/src/models/retinanet/feature_pyramid_test.py create mode 100644 keras_hub/src/models/retinanet/retinanet_label_encoder.py create mode 100644 keras_hub/src/models/retinanet/retinanet_label_encoder_test.py diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index 04d5f7dc9e..bb46988926 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -24,29 +24,31 @@ class AnchorGenerator(keras.layers.Layer): for larger objects. Args: - bounding_box_format (str): The format of the bounding boxes + bounding_box_format: str. The format of the bounding boxes to be generated. Expected to be a string like 'xyxy', 'xywh', etc. - min_level (int): Minimum level of the output feature pyramid. - max_level (int): Maximum level of the output feature pyramid. - num_scales (int): Number of intermediate scales added on each level. + min_level: int. Minimum level of the output feature pyramid. + max_level: int. Maximum level of the output feature pyramid. + num_scales: int. Number of intermediate scales added on each level. For example, num_scales=2 adds one additional intermediate anchor scale [2^0, 2^0.5] on each level. - aspect_ratios (list of float): Aspect ratios of anchors added on + aspect_ratios: List[float]. Aspect ratios of anchors added on each level. Each number indicates the ratio of width to height. - anchor_size (float): Scale of size of the base anchor relative to the + anchor_size: float. Scale of size of the base anchor relative to the feature stride 2^level. Call arguments: - images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]` or - `[H, W, C]`. If provided, its shape will be used to determine anchor + inputs: An image tensor with shape `[B, H, W, C]` or + `[H, W, C]`. Its shape will be used to determine anchor sizes. Returns: Dict: A dictionary mapping feature levels - (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor - of shape `(H/stride * W/stride * num_anchors_per_location, 4)`, - where H and W are the height and width of the image, stride is 2^level, - and num_anchors_per_location is `num_scales * len(aspect_ratios)`. + (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a + tensor of shape + `(H/stride * W/stride * num_anchors_per_location, 4)`, + where H and W are the height and width of the image, + stride is 2^level, and num_anchors_per_location is + `num_scales * len(aspect_ratios)`. 
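+
+        For example, with `num_scales=3` and `aspect_ratios=[0.5, 1.0, 2.0]`
+        there are 9 anchors per location, so a 128x128 image produces
+        `16 * 16 * 9 = 2304` boxes for level `'P3'`.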
Example: ```python @@ -81,8 +83,8 @@ def __init__( self.anchor_size = anchor_size self.built = True - def call(self, images): - images_shape = ops.shape(images) + def call(self, inputs): + images_shape = ops.shape(inputs) if len(images_shape) == 4: image_shape = images_shape[1:-1] else: @@ -147,8 +149,18 @@ def call(self, images): def compute_output_shape(self, input_shape): multilevel_boxes_shape = {} - for level in range(self.min_level, self.max_level + 1): - multilevel_boxes_shape[f"P{level}"] = (None, None, 4) + if len(input_shape) == 4: + image_height, image_width = input_shape[1:-1] + else: + image_height, image_width = input_shape[:-1] + + for i in range(self.min_level, self.max_level + 1): + multilevel_boxes_shape[f"P{i}"] = ( + (image_height // 2 ** (i)) + * (image_width // 2 ** (i)) + * self.anchors_per_location, + 4, + ) return multilevel_boxes_shape @property diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/models/retinanet/anchor_generator_test.py index 8b0669188a..c843c32f27 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/models/retinanet/anchor_generator_test.py @@ -1,3 +1,4 @@ +import numpy as np from absl.testing import parameterized from keras import ops @@ -7,6 +8,32 @@ class AnchorGeneratorTest(TestCase): + def test_layer_behaviors(self): + images_shape = (8, 128, 128, 3) + self.run_layer_test( + cls=AnchorGenerator, + init_kwargs={ + "bounding_box_format": "xyxy", + "min_level": 3, + "max_level": 7, + "num_scales": 3, + "aspect_ratios": [0.5, 1.0, 2.0], + "anchor_size": 8, + }, + input_data=np.random.uniform(size=images_shape), + expected_output_shape={ + "P3": (2304, 4), + "P4": (576, 4), + "P5": (144, 4), + "P6": (36, 4), + "P7": (9, 4), + }, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + run_training_check=False, + run_precision_checks=False, + ) + @parameterized.parameters( # Single scale anchor ("yxyx", 5, 5, 1, [1.0], 2.0, [64, 64]) @@ -86,7 +113,7 @@ def test_anchor_generator( anchor_size, ) images = ops.ones(shape=(1, image_shape[0], image_shape[1], 3)) - multilevel_boxes = anchor_generator(images=images) + multilevel_boxes = anchor_generator(images) for key in expected_boxes: expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key]) expected_boxes[key] = convert_format( diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py new file mode 100644 index 0000000000..5c0bbb906c --- /dev/null +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -0,0 +1,373 @@ +import keras + + +class FeaturePyramid(keras.layers.Layer): + """A Feature Pyramid Network (FPN) layer. + + This implements the paper: + Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, + and Serge Belongie. Feature Pyramid Networks for Object Detection. + (https://arxiv.org/pdf/1612.03144) + + Feature Pyramid Networks (FPNs) are basic components that are added to an + existing feature extractor (CNN) to combine features at different scales. + For the basic FPN, the inputs are features `Ci` from different levels of a + CNN, which is usually the last block for each level, where the feature is + scaled from the image by a factor of `1/2^i`. + + There is an output associated with each level in the basic FPN. 
The output
+    Pi at level `i` (corresponding to Ci) is given by performing a merge
+    operation on the outputs of:
+
+    1) a lateral operation on Ci (usually a conv2D layer with kernel = 1 and
+    strides = 1)
+    2) a top-down upsampling operation from Pi+1 (except for the topmost level)
+
+    The final output of each level will also have a conv2D operation
+    (typically with kernel = 3 and strides = 1).
+
+    The inputs to the layer should be a dict whose keys match the pyramid
+    levels, e.g. for levels `[3, 4, 5]`, the expected input dict is
+    `{"P3": c3, "P4": c4, "P5": c5}`.
+
+    The output of the layer has the same structure as the inputs: a dict
+    keyed by level, with extra coarser levels added up to the `max_level`
+    provided.
+
+    Args:
+        min_level: int. The minimum level of the feature pyramid.
+        max_level: int. The maximum level of the feature pyramid.
+        num_filters: int. The number of filters in each feature map.
+        activation: string or `keras.activations`. The activation function
+            to be used in the network.
+            Defaults to `"relu"`.
+        kernel_initializer: `str` or `keras.initializers` initializer.
+            The kernel initializer for the convolution layers.
+            Defaults to `"VarianceScaling"`.
+        bias_initializer: `str` or `keras.initializers` initializer.
+            The bias initializer for the convolution layers.
+            Defaults to `"zeros"`.
+        batch_norm_momentum: float.
+            The momentum for the batch normalization layers.
+            Defaults to `0.99`.
+        batch_norm_epsilon: float.
+            The epsilon for the batch normalization layers.
+            Defaults to `0.001`.
+        kernel_regularizer: `str` or `keras.regularizers` regularizer.
+            The kernel regularizer for the convolution layers.
+            Defaults to `None`.
+        bias_regularizer: `str` or `keras.regularizers` regularizer.
+            The bias regularizer for the convolution layers.
+            Defaults to `None`.
+        use_batch_norm: bool. Whether to use batch normalization.
+            Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        min_level,
+        max_level,
+        num_filters=256,
+        activation="relu",
+        kernel_initializer="VarianceScaling",
+        bias_initializer="zeros",
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=0.001,
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        use_batch_norm=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if min_level > max_level:
+            raise ValueError(
+                f"Minimum level ({min_level}) must be less than or equal to "
+                f"maximum level ({max_level})."
+ ) + self.min_level = min_level + self.max_level = max_level + self.num_filters = num_filters + self.activation = keras.activations.get(activation) + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.use_batch_norm = use_batch_norm + if kernel_regularizer is not None: + self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) + else: + self.kernel_regularizer = None + if bias_regularizer is not None: + self.bias_regularizer = keras.regularizers.get(bias_regularizer) + else: + self.bias_regularizer = None + self.data_format = keras.backend.image_data_format() + self.batch_norm_axis = -1 if self.data_format == "channels_last" else 1 + + def build(self, input_shapes): + input_shapes = { + ( + input_name.split("_")[0] + if "shape" in input_name + else input_name + ): input_shapes[input_name] + for input_name in input_shapes + } + input_levels = [int(level[1]) for level in input_shapes] + backbone_max_level = min(max(input_levels), self.max_level) + + # Build lateral layers + self.lateral_conv_layers = {} + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + self.lateral_conv_layers[level] = keras.layers.Conv2D( + filters=self.num_filters, + kernel_size=1, + padding="same", + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + dtype=self.dtype_policy, + name=f"lateral_conv_{level}", + ) + self.lateral_conv_layers[level].build(input_shapes[level]) + + self.lateral_batch_norm_layers = {} + if self.use_batch_norm: + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + self.lateral_batch_norm_layers[level] = ( + keras.layers.BatchNormalization( + axis=self.batch_norm_axis, + momentum=self.batch_norm_epsilon, + epsilon=self.batch_norm_epsilon, + name=f"lateral_norm_{level}", + ) + ) + self.lateral_batch_norm_layers[level].build( + (None, None, None, 256) + if self.data_format == "channels_last" + else (None, 256, None, None) + ) + + # Build output layers + self.output_conv_layers = {} + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + self.output_conv_layers[level] = keras.layers.Conv2D( + filters=self.num_filters, + kernel_size=3, + padding="same", + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + dtype=self.dtype_policy, + name=f"output_conv_{level}", + ) + self.output_conv_layers[level].build( + (None, None, None, 256) + if self.data_format == "channels_last" + else (None, 256, None, None) + ) + + # Build coarser layers + for i in range(backbone_max_level + 1, self.max_level + 1): + level = f"P{i}" + self.output_conv_layers[level] = keras.layers.Conv2D( + filters=self.num_filters, + strides=2, + kernel_size=3, + padding="same", + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + dtype=self.dtype_policy, + name=f"coarser_{level}", + ) + self.output_conv_layers[level].build( + (None, None, None, 256) + if self.data_format == "channels_last" + else (None, 256, 
None, None) + ) + + # Build batch norm layers + self.output_batch_norms = {} + if self.use_batch_norm: + for i in range(self.min_level, self.max_level + 1): + level = f"P{i}" + self.output_batch_norms[level] = ( + keras.layers.BatchNormalization( + axis=self.batch_norm_axis, + momentum=self.batch_norm_epsilon, + epsilon=self.batch_norm_epsilon, + name=f"output_norm_{level}", + ) + ) + self.output_batch_norms[level].build( + (None, None, None, 256) + if self.data_format == "channels_last" + else (None, 256, None, None) + ) + + # The same upsampling layer is used for all levels + self.top_down_op = keras.layers.UpSampling2D( + size=2, + data_format=self.data_format, + dtype=self.dtype_policy, + name="upsampling", + ) + # The same merge layer is used for all levels + self.merge_op = keras.layers.Add( + dtype=self.dtype_policy, name="merge_op" + ) + + self.built = True + + def call(self, inputs): + """ + Inputs: + The input to the model is expected to be an `Dict[Tensors]`, + containing the feature maps on top of which the FPN + will be added. + + Outputs: + A dictionary of feature maps and added coarser levels based + on minimum and maximum levels provided to the layer. + """ + + output_features = {} + + # Get the backbone max level + input_levels = [int(level[1]) for level in inputs] + backbone_max_level = min(max(input_levels), self.max_level) + + for i in range(backbone_max_level, self.min_level - 1, -1): + level = f"P{i}" + output = self.lateral_conv_layers[level](inputs[level]) + if i < backbone_max_level: + # for the top most output, it doesn't need to merge with any + # upper stream outputs + upstream_output = self.top_down_op(output_features[f"P{i+1}"]) + output = self.merge_op([output, upstream_output]) + output_features[level] = ( + self.lateral_batch_norm_layers[level](output) + if self.use_batch_norm + else output + ) + + # Post apply the output layers so that we don't leak them to the down + # stream level + for i in range(backbone_max_level, self.min_level - 1, -1): + level = f"P{i}" + output_features[level] = self.output_conv_layers[level]( + output_features[level] + ) + + for i in range(backbone_max_level + 1, self.max_level + 1): + level = f"P{i}" + feats_in = output_features[f"P{i-1}"] + if i > backbone_max_level + 1: + feats_in = self.activation(feats_in) + output_features[level] = ( + self.output_batch_norms[level]( + self.output_conv_layers[level](feats_in) + ) + if self.use_batch_norm + else self.output_conv_layers[level](feats_in) + ) + + return output_features + + def get_config(self): + config = super().get_config() + config.update( + { + "min_level": self.min_level, + "max_level": self.max_level, + "num_filters": self.num_filters, + "use_batch_norm": self.use_batch_norm, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "batch_norm_momentum": self.batch_norm_momentum, + "batch_norm_epsilon": self.batch_norm_epsilon, + "kernel_regularizer": ( + keras.regularizers.serialize(self.kernel_regularizer) + if self.kernel_regularizer is not None + else None + ), + "bias_regularizer": ( + keras.regularizers.serialize(self.bias_regularizer) + if self.bias_regularizer is not None + else None + ), + } + ) + + return config + + def compute_output_shape(self, input_shapes): + output_shape = {} + print(input_shapes) + input_levels = [int(level[1]) for level in input_shapes] + backbone_max_level = 
min(max(input_levels), self.max_level) + + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + if self.data_format == "channels_last": + output_shape[level] = input_shapes[level][:-1] + (256,) + else: + output_shape[level] = ( + input_shapes[level][0], + 256, + ) + input_shapes[level][1:3] + + intermediate_shape = input_shapes[f"P{backbone_max_level}"] + intermediate_shape = ( + ( + intermediate_shape[0], + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + 256, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + 256, + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + ) + ) + + for i in range(backbone_max_level + 1, self.max_level + 1): + level = f"P{i}" + output_shape[level] = intermediate_shape + intermediate_shape = ( + ( + intermediate_shape[0], + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + 256, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + 256, + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + ) + ) + + return output_shape diff --git a/keras_hub/src/models/retinanet/feature_pyramid_test.py b/keras_hub/src/models/retinanet/feature_pyramid_test.py new file mode 100644 index 0000000000..728233c6ae --- /dev/null +++ b/keras_hub/src/models/retinanet/feature_pyramid_test.py @@ -0,0 +1,81 @@ +from absl.testing import parameterized +from keras import ops +from keras import random + +from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid +from keras_hub.src.tests.test_case import TestCase + + +class FeaturePyramidTest(TestCase): + def test_layer_behaviors(self): + self.run_layer_test( + cls=FeaturePyramid, + init_kwargs={ + "min_level": 3, + "max_level": 7, + "activation": "relu", + "batch_norm_momentum": 0.99, + "batch_norm_epsilon": 0.0001, + "kernel_initializer": "HeNormal", + "bias_initializer": "Zeros", + }, + input_data={ + "P3": random.uniform(shape=(2, 64, 64, 4)), + "P4": random.uniform(shape=(2, 32, 32, 8)), + "P5": random.uniform(shape=(2, 16, 16, 16)), + }, + expected_output_shape={ + "P3": (2, 64, 64, 256), + "P4": (2, 32, 32, 256), + "P5": (2, 16, 16, 256), + "P6": (2, 8, 8, 256), + "P7": (2, 4, 4, 256), + }, + expected_num_trainable_weights=16, + expected_num_non_trainable_weights=0, + ) + + @parameterized.named_parameters( + ( + "equal_resolutions", + 3, + 7, + {"P3": (2, 16, 16, 3), "P4": (2, 8, 8, 3), "P5": (2, 4, 4, 3)}, + ), + ( + "different_resolutions", + 2, + 6, + { + "P2": (2, 64, 128, 4), + "P3": (2, 32, 64, 8), + "P4": (2, 16, 32, 16), + "P5": (2, 8, 16, 32), + }, + ), + ) + def test_layer_output_shapes(self, min_level, max_level, input_shapes): + layer = FeaturePyramid(min_level=min_level, max_level=max_level) + + inputs = { + level: ops.ones(input_shapes[level]) for level in input_shapes + } + if layer.data_format == "channels_first": + inputs = { + level: ops.transpose(inputs[level], (0, 3, 1, 2)) + for level in inputs + } + + output = layer(inputs) + + for level in inputs: + self.assertEqual( + output[level].shape, + ( + (input_shapes[level][0],) + + (layer.num_filters,) + + input_shapes[level][1:3] + if layer.data_format == "channels_first" + else input_shapes[level][:-1] + (layer.num_filters,) + ), + ) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py new file mode 100644 index 0000000000..1803f7062f --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -0,0 +1,270 @@ +import keras +from keras import ops 
+ +from keras_hub.src.bounding_box.converters import _encode_box_to_deltas +from keras_hub.src.bounding_box.iou import compute_iou +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.utils import tensor_utils + + +class RetinaNetLabelEncoder(keras.layers.Layer): + """Transforms the raw labels into targets for training. + + RetinaNet is a single-stage object detection network that uses a feature + pyramid network and focal loss. This class is crucial for preparing the + ground truth data to match the network's anchor-based detection approach. + + This class generates targets for a batch of samples which consists of input + images, bounding boxes for the objects present, and their class ids. It + matches ground truth boxes to anchor boxes based on IoU (Intersection over + Union) and encodes the box coordinates as offsets from the anchors. + + Targets are always represented in 'center_yxwh' format for numerical + consistency during training, regardless of the input format. + + Args: + bounding_box_format: str. The format of bounding boxes of input dataset. + Refer TODO: Add link to Keras Core Docs. + min_level: int. Minimum level of the output feature pyramid. + max_level: int. Maximum level of the output feature pyramid. + num_scales: int. Number of intermediate scales added on each level. + For example, num_scales=2 adds one additional intermediate anchor + scale [2^0, 2^0.5] on each level. + aspect_ratios: List[float]. Aspect ratios of anchors added on + each level. Each number indicates the ratio of width to height. + anchor_size: float. Scale of size of the base anchor relative to the + feature stride 2^level. + positive_threshold: float. the threshold to set an anchor to positive + match to gt box. Values above it are positive matches. + Defaults to `0.5` + negative_threshold: float. the threshold to set an anchor to negative + match to gt box. Values below it are negative matches. + Defaults to `0.4` + box_variance: List[float]. The scaling factors used to scale the + bounding box targets. + Defaults to `[0.1, 0.1, 0.2, 0.2]`. + background_class: int. The class ID used for the background class, + Defaults to `-1`. + ignore_class: int. The class ID used for the ignore class, + Defaults to `-2`. + box_matcher_match_values: List[int]. Representing + matched results (e.g. positive or negative or ignored match). + `len(match_values)` must equal to `len(thresholds) + 1`. + Defaults to `[-1, -2, -1]`. + box_matcher_force_match_for_each_col: bool. If True, each column + (ground truth box) will be matched to at least one row (anchor box). + This means some columns may be matched to multiple rows while others + may not be matched to any. + Defaults to `False`. + + Note: `tf.RaggedTensor` are not supported. 
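+
+    Example: a minimal usage sketch; the configuration and shapes below
+    mirror those exercised in the unit tests for this layer.
+    ```python
+    import numpy as np
+
+    encoder = RetinaNetLabelEncoder(
+        bounding_box_format="xyxy",
+        min_level=3,
+        max_level=7,
+        num_scales=3,
+        aspect_ratios=[0.5, 1.0, 2.0],
+        anchor_size=8,
+    )
+    images = np.random.uniform(size=(8, 128, 128, 3))
+    gt_boxes = np.random.uniform(size=(8, 10, 4))
+    gt_classes = np.random.uniform(size=(8, 10), low=0, high=5)
+    # Levels P3-P7 on a 128x128 input give 3069 anchors, so
+    # box_targets has shape (8, 3069, 4) and class_targets (8, 3069).
+    box_targets, class_targets = encoder(images, gt_boxes, gt_classes)
+    ```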
+ """ + + def __init__( + self, + bounding_box_format, + min_level, + max_level, + num_scales, + aspect_ratios, + anchor_size, + positive_threshold=0.5, + negative_threshold=0.4, + box_variance=[0.1, 0.1, 0.2, 0.2], + background_class=-1.0, + ignore_class=-2.0, + box_matcher_match_values=[-1, -2, 1], + box_matcher_force_match_for_each_col=False, + **kwargs, + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.min_level = min_level + self.max_level = max_level + self.num_scales = num_scales + self.aspect_ratios = aspect_ratios + self.anchor_size = anchor_size + self.positive_threshold = positive_threshold + self.box_variance = box_variance + self.negative_threshold = negative_threshold + self.background_class = background_class + self.ignore_class = ignore_class + + self.anchor_generator = AnchorGenerator( + bounding_box_format=bounding_box_format, + min_level=min_level, + max_level=max_level, + num_scales=num_scales, + aspect_ratios=aspect_ratios, + anchor_size=anchor_size, + ) + + self.box_matcher = BoxMatcher( + thresholds=[negative_threshold, positive_threshold], + match_values=box_matcher_match_values, + force_match_for_each_col=box_matcher_force_match_for_each_col, + ) + + def build(self, images_shape, gt_boxes_shape, gt_classes_shape): + self.built = True + + def call(self, images, gt_boxes, gt_classes): + """Creates box and classification targets for a batch. + + Args: + images: A Tensor. The input images argument should be + of shape `[B, H, W, C]` or `[B, C, H, W]`. + boxes: A Tensor with shape of `[B, num_boxes, 4]`. + labels: A Tensor with shape of `[B, num_boxes, num_classes]` + + Returns: + encoded_box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` + containing the encoded box targets. + class_targets: A Tensor of shape `[batch_size, num_anchors, 1]` + containing the class targets for each anchor. + """ + + images_shape = ops.shape(images) + if len(images_shape) != 4: + raise ValueError( + "`RetinaNetLabelEncoder`'s `call()` method does not " + "support unbatched inputs for the `images` argument. " + f"Received `shape(images)={images_shape}`." + ) + image_shape = images_shape[1:] + + if len(ops.shape(gt_classes)) == 2: + gt_classes = ops.expand_dims(gt_classes, axis=-1) + + anchor_boxes = self.anchor_generator(images) + anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) + + box_targets, class_targets = self._encode_sample( + gt_boxes, gt_classes, anchor_boxes, image_shape + ) + box_targets = ops.reshape( + box_targets, (-1, ops.shape(box_targets)[1], 4) + ) + return box_targets, class_targets + + def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): + """Creates box and classification targets for a batched sample. + + Matches ground truth boxes to anchor boxes based on IOU. + 1. Calculates the pairwise IOU for the M `anchor_boxes` and N `gt_boxes` + to get a `(M, N)` shaped matrix. + 2. The ground truth box with the maximum IOU in each row is assigned to + the anchor box provided the IOU is greater than `match_iou`. + 3. If the maximum IOU in a row is less than `ignore_iou`, the anchor + box is assigned with the background class. + 4. The remaining anchor boxes that do not have any class assigned are + ignored during training. + + Args: + gt_boxes: A Tensor of shape `[B, num_boxes, 4]`. Should be in + `bounding_box_format`. + gt_classes: A Tensor fo shape `[B, num_boxes, num_classes, 1]`. 
+ anchor_boxes: A Tensor with the shape `[total_anchors, 4]` + representing all the anchor boxes for a given input image shape, + where each anchor box is of the format `[x, y, width, height]`. + image_shape: Tuple indicating the image shape `[H, W, C]`. + + Returns: + Encoded boudning boxes in the format of `center_yxwh` and + corresponding labels for each encoded bounding box. + """ + + iou_matrix = compute_iou( + anchor_boxes, + gt_boxes, + bounding_box_format=self.bounding_box_format, + image_shape=image_shape, + ) + + matched_gt_idx, matched_vals = self.box_matcher(iou_matrix) + matched_vals = ops.expand_dims(matched_vals, axis=-1) + positive_mask = ops.cast(ops.equal(matched_vals, 1), self.dtype) + ignore_mask = ops.cast(ops.equal(matched_vals, -2), self.dtype) + + matched_gt_boxes = tensor_utils.target_gather(gt_boxes, matched_gt_idx) + + matched_gt_boxes = ops.reshape( + matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) + ) + + box_target = _encode_box_to_deltas( + anchors=anchor_boxes, + boxes=matched_gt_boxes, + anchor_format=self.bounding_box_format, + box_format=self.bounding_box_format, + variance=self.box_variance, + image_shape=image_shape, + ) + + matched_gt_cls_ids = tensor_utils.target_gather( + gt_classes, matched_gt_idx + ) + cls_target = ops.where( + ops.not_equal(positive_mask, 1.0), + self.background_class, + matched_gt_cls_ids, + ) + cls_target = ops.where( + ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target + ) + label = ops.concatenate( + [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1 + ) + + # In the case that a box in the corner of an image matches with an all + # -1 box that is outside the image, we should assign the box to the + # ignore class. There are rare cases where a -1 box can be matched, + # resulting in a NaN during training. The unit test passing all -1s to + # the label encoder ensures that we properly handle this edge-case. 
+ label = ops.where( + ops.expand_dims(ops.any(ops.isnan(label), axis=-1), axis=-1), + self.ignore_class, + label, + ) + + return label[:, :, :4], label[:, :, 4] + + def get_config(self): + config = super().get_config() + config.update( + { + "bounding_box_format": self.bounding_box_format, + "min_level": self.min_level, + "max_level": self.max_level, + "num_scales": self.num_scales, + "aspect_ratios": self.aspect_ratios, + "anchor_size": self.anchor_size, + "positive_threshold": self.positive_threshold, + "box_variance": self.box_variance, + "negative_threshold": self.negative_threshold, + "background_class": self.background_class, + "ignore_class": self.ignore_class, + } + ) + return config + + def compute_output_shape( + self, images_shape, gt_boxes_shape, gt_classes_shape + ): + min_level = self.anchor_generator.min_level + max_level = self.anchor_generator.max_level + batch_size, image_H, image_W = images_shape[:-1] + + total_num_anchors = 0 + for i in range(min_level, max_level + 1): + total_num_anchors += ( + (image_H // 2 ** (i)) + * (image_W // 2 ** (i)) + * self.anchor_generator.anchors_per_location + ) + + return (batch_size, total_num_anchors, 4), ( + batch_size, + total_num_anchors, + ) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py new file mode 100644 index 0000000000..de329685a8 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -0,0 +1,85 @@ +import numpy as np +from keras import ops + +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetLabelEncoderTest(TestCase): + def test_layer_behaviors(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + self.run_layer_test( + cls=RetinaNetLabelEncoder, + init_kwargs={ + "bounding_box_format": "xyxy", + "min_level": 3, + "max_level": 7, + "num_scales": 3, + "aspect_ratios": [0.5, 1.0, 2.0], + "anchor_size": 8, + }, + input_data={ + "images": np.random.uniform(size=images_shape), + "gt_boxes": np.random.uniform( + size=boxes_shape, low=0.0, high=1.0 + ), + "gt_classes": np.random.uniform( + size=classes_shape, low=0, high=5 + ), + }, + expected_output_shape=((8, 3069, 4), (8, 3069)), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + run_training_check=False, + run_precision_checks=False, + ) + + def test_label_encoder_output_shapes(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + + images = np.random.uniform(size=images_shape) + boxes = np.random.uniform(size=boxes_shape, low=0.0, high=1.0) + classes = np.random.uniform(size=classes_shape, low=0, high=5) + + encoder = RetinaNetLabelEncoder( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + + box_targets, class_targets = encoder(images, boxes, classes) + + self.assertEqual(box_targets.shape, (8, 3069, 4)) + self.assertEqual(class_targets.shape, (8, 3069)) + + def test_all_negative_1(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + + images = np.random.uniform(size=images_shape) + boxes = -np.ones(shape=boxes_shape, dtype="float32") + classes = -np.ones(shape=classes_shape, dtype="float32") + + encoder = RetinaNetLabelEncoder( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + 
aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + + box_targets, class_targets = encoder(images, boxes, classes) + + self.assertFalse(ops.any(ops.isnan(box_targets))) + self.assertFalse(ops.any(ops.isnan(class_targets))) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 310d8e8b44..6d06c7266c 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -13,6 +13,7 @@ from keras_hub.src.layers.modeling.reversible_embedding import ( ReversibleEmbedding, ) +from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.utils.keras_utils import has_quantization_support from keras_hub.src.utils.tensor_utils import is_float_dtype @@ -127,7 +128,10 @@ def __init__(self, layer): def call(self, x): if isinstance(x, dict): - return self.layer(**x) + if isinstance(layer, FeaturePyramid): + return self.layer(x) + else: + return self.layer(**x) else: return self.layer(x) @@ -147,7 +151,10 @@ def call(self, x): layer = cls(**init_kwargs) if isinstance(input_data, dict): shapes = {k + "_shape": v.shape for k, v in input_data.items()} - layer.build(**shapes) + if isinstance(layer, FeaturePyramid): + layer.build(shapes) + else: + layer.build(**shapes) else: layer.build(input_data.shape) run_build_asserts(layer) @@ -158,7 +165,10 @@ def call(self, x): ) layer = cls(**init_kwargs) if isinstance(keras_tensor_inputs, dict): - keras_tensor_outputs = layer(**keras_tensor_inputs) + if isinstance(layer, FeaturePyramid): + keras_tensor_outputs = layer(keras_tensor_inputs) + else: + keras_tensor_outputs = layer(**keras_tensor_inputs) else: keras_tensor_outputs = layer(keras_tensor_inputs) run_build_asserts(layer) @@ -167,7 +177,10 @@ def call(self, x): # Eager call test and compiled training test. layer = cls(**init_kwargs) if isinstance(input_data, dict): - output_data = layer(**input_data) + if isinstance(layer, FeaturePyramid): + output_data = layer(input_data) + else: + output_data = layer(**input_data) else: output_data = layer(input_data) run_output_asserts(layer, output_data, eager=True) @@ -305,8 +318,12 @@ def run_precision_test(self, cls, init_kwargs, input_data): output_data = layer(input_data) output_spec = layer.compute_output_spec(input_data) elif isinstance(input_data, dict): - output_data = layer(**input_data) - output_spec = layer.compute_output_spec(**input_data) + if isinstance(layer, FeaturePyramid): + output_data = layer(input_data) + output_spec = layer.compute_output_spec(input_data) + else: + output_data = layer(**input_data) + output_spec = layer.compute_output_spec(**input_data) else: output_data = layer(input_data) output_spec = layer.compute_output_spec(input_data) diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py index 3d18aae99d..36177c8943 100644 --- a/keras_hub/src/utils/tensor_utils.py +++ b/keras_hub/src/utils/tensor_utils.py @@ -308,3 +308,109 @@ def any_equal(inputs, values, padding_mask): output = ops.logical_or(output, value_equality) return ops.logical_and(output, padding_mask) + + +def target_gather( + targets, + indices, + mask=None, + mask_val=0.0, +): + """A utility function wrapping `ops.take`, which deals with: + 1) both batched and unbatched `targets`. + 2) when unbatched `targets` have empty rows, the result will be filled + with `mask_val`. + 3) target masking. + + Args: + targets: `[N, ...]` or `[batch_size, N, ...]` Tensor representing + targets such as boxes, keypoints, etc. 
+ indices: [M] or [batch_size, M] int32 Tensor representing indices within + `targets` to gather. + mask: `[M, ...]` or `[batch_size, M, ...]` boolean Tensor + representing the masking for each target. `True` means the + corresponding entity should be masked to `mask_val`, `False` + means the corresponding entity should be the target value. + Defaults to `None`. + mask_val: float. representing the masking value if `mask` is True + on the entity. + Defaults to `0.0` + + Returns: + targets: `[M, ...]` or `[batch_size, M, ...]` Tensor representing + selected targets. + + Raise: + ValueError: If `targets` is higher than rank 3. + """ + targets_shape = list(targets.shape) + if len(targets_shape) > 3: + raise ValueError( + f"`target_gather` does not support `targets` with rank " + f"larger than 3, got {len(targets.shape)}" + ) + + def gather_unbatched(labels, match_indices, mask, mask_val): + """Gather based on unbatched labels and boxes.""" + num_gt_boxes = labels.shape[0] + + def assign_when_rows_empty(): + if len(labels.shape) > 1: + mask_shape = [match_indices.shape[0], labels.shape[-1]] + else: + mask_shape = [match_indices.shape[0]] + return ops.cast(mask_val, labels.dtype) * ops.ones( + mask_shape, dtype=labels.dtype + ) + + def assign_when_rows_not_empty(): + targets = ops.take(labels, match_indices, axis=0) + if mask is None: + return targets + else: + masked_targets = ops.cast( + mask_val, labels.dtype + ) * ops.ones_like(mask, dtype=labels.dtype) + return ops.where(mask, masked_targets, targets) + + if num_gt_boxes > 0: + return assign_when_rows_not_empty() + else: + return assign_when_rows_empty() + + def _gather_batched(labels, match_indices, mask, mask_val): + """Gather based on batched labels.""" + batch_size = labels.shape[0] + if batch_size == 1: + if mask is not None: + result = gather_unbatched( + ops.squeeze(labels, axis=0), + ops.squeeze(match_indices, axis=0), + ops.squeeze(mask, axis=0), + mask_val, + ) + else: + result = gather_unbatched( + ops.squeeze(labels, axis=0), + ops.squeeze(match_indices, axis=0), + None, + mask_val, + ) + return ops.expand_dims(result, axis=0) + else: + targets = ops.take_along_axis( + labels, ops.expand_dims(match_indices, axis=-1), axis=1 + ) + + if mask is None: + return targets + else: + masked_targets = ops.cast( + mask_val, labels.dtype + ) * ops.ones_like(mask, dtype=labels.dtype) + return ops.where(mask, masked_targets, targets) + + if len(targets_shape) <= 2: + return gather_unbatched(targets, indices, mask, mask_val) + elif len(targets_shape) == 3: + return _gather_batched(targets, indices, mask, mask_val) diff --git a/keras_hub/src/utils/tensor_utils_test.py b/keras_hub/src/utils/tensor_utils_test.py index 42d04a029b..0b6ef1f346 100644 --- a/keras_hub/src/utils/tensor_utils_test.py +++ b/keras_hub/src/utils/tensor_utils_test.py @@ -10,6 +10,7 @@ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch from keras_hub.src.utils.tensor_utils import is_tensor_type from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.tensor_utils import target_gather from keras_hub.src.utils.tensor_utils import tensor_to_list @@ -202,3 +203,104 @@ def test_input_shaped_values(self): result = any_equal(inputs, values, padding_mask) result = ops.convert_to_numpy(result) self.assertAllEqual(result, expected_output) + + +class TargetGatherTest(TestCase): + def test_target_gather_boxes_batched(self): + target_boxes = np.array( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]] + ) + target_boxes = 
ops.expand_dims(target_boxes, axis=0) + indices = np.array([[0, 2]], dtype="int32") + expected_boxes = np.array([[0, 0, 5, 5], [5, 0, 10, 5]]) + expected_boxes = ops.expand_dims(expected_boxes, axis=0) + res = target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_boxes_unbatched(self): + target_boxes = np.array( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]], + "int32", + ) + indices = np.array([0, 2], dtype="int32") + expected_boxes = np.array([[0, 0, 5, 5], [5, 0, 10, 5]]) + res = target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_classes_batched(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + expected_classes = np.array([[1, 3]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched(self): + target_classes = np.array([1, 2, 3, 4]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + expected_classes = np.array([1, 3]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + masks = np.array(([[False, True]])) + masks = ops.expand_dims(masks, axis=-1) + # the second element is masked + expected_classes = np.array([[1, 0]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask_val(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + masks = np.array(([[False, True]])) + masks = ops.expand_dims(masks, axis=-1) + # the second element is masked + expected_classes = np.array([[1, -1]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks, -1) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched_with_mask(self): + target_classes = np.array([1, 2, 3, 4]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + masks = np.array([False, True]) + masks = ops.expand_dims(masks, axis=-1) + expected_classes = np.array([1, 0]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_with_empty_targets(self): + target_classes = np.array([]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + # return all 0s since input is empty + expected_classes = np.array([0, 0]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_multi_batch(self): + target_classes = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + target_classes = 
ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2], [1, 3]], dtype="int32") + expected_classes = np.array([[1, 3], [6, 8]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_invalid_rank(self): + targets = np.random.normal(size=[32, 2, 2, 2]) + indices = np.array([0, 1], dtype="int32") + with self.assertRaisesRegex(ValueError, "larger than 3"): + _ = target_gather(targets, indices) From deaeac4eed5b7ae42b4cdae84613422ccc1a0432 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 26 Sep 2024 14:53:12 -0700 Subject: [PATCH 02/35] Rebased phase 1 changes --- keras_hub/src/models/retinanet/retinanet_label_encoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index a5bf475b29..1803f7062f 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -115,11 +115,11 @@ def call(self, images, gt_boxes, gt_classes): Args: images: A Tensor. The input images argument should be of shape `[B, H, W, C]` or `[B, C, H, W]`. - gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`. - gt_labels: A Tensor with shape of `[B, num_boxes, num_classes]` + boxes: A Tensor with shape of `[B, num_boxes, 4]`. + labels: A Tensor with shape of `[B, num_boxes, num_classes]` Returns: - box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` + encoded_box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` containing the encoded box targets. class_targets: A Tensor of shape `[batch_size, num_anchors, 1]` containing the class targets for each anchor. From f90add8a1e16051e3f57e76f153e21ff9713afdd Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 27 Sep 2024 16:52:59 -0700 Subject: [PATCH 03/35] nit --- keras_hub/src/models/retinanet/retinanet_label_encoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 1803f7062f..51c0d188fb 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -115,11 +115,11 @@ def call(self, images, gt_boxes, gt_classes): Args: images: A Tensor. The input images argument should be of shape `[B, H, W, C]` or `[B, C, H, W]`. - boxes: A Tensor with shape of `[B, num_boxes, 4]`. - labels: A Tensor with shape of `[B, num_boxes, num_classes]` + gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`. + gt_classes: A Tensor with shape of `[B, num_boxes, num_classes]` Returns: - encoded_box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` + box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` containing the encoded box targets. class_targets: A Tensor of shape `[batch_size, num_anchors, 1]` containing the class targets for each anchor. 
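
A minimal sketch of driving the `FeaturePyramid` layer added in this patch,
assuming `channels_last` data format; the shapes follow
`feature_pyramid_test.py`, and the input channel counts are arbitrary
placeholders for backbone features:

    from keras import random

    from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid

    fpn = FeaturePyramid(min_level=3, max_level=7)
    inputs = {
        "P3": random.uniform(shape=(2, 64, 64, 4)),
        "P4": random.uniform(shape=(2, 32, 32, 8)),
        "P5": random.uniform(shape=(2, 16, 16, 16)),
    }
    # Lateral 1x1 convs and top-down merges produce P3-P5 at 256 filters;
    # strided 3x3 convs append the coarser P6 and P7 levels.
    outputs = fpn(inputs)
    # outputs["P3"].shape == (2, 64, 64, 256), ...,
    # outputs["P7"].shape == (2, 4, 4, 256)
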
From 6c26534c215747d38bd4b0b7a2102acebde9ec6a Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 3 Oct 2024 11:37:03 -0700 Subject: [PATCH 04/35] Retina Phase 2 --- keras_hub/api/layers/__init__.py | 3 + keras_hub/api/models/__init__.py | 11 + keras_hub/src/models/image_object_detector.py | 84 ++++++ .../image_object_detector_preprocessor.py | 117 ++++++++ .../src/models/retinanet/feature_pyramid.py | 63 ++-- .../src/models/retinanet/prediction_head.py | 148 ++++++++++ .../models/retinanet/prediction_head_test.py | 17 ++ .../models/retinanet/retinanet_backbone.py | 84 ++++++ .../retinanet/retinanet_backbone_test.py | 53 ++++ .../retinanet/retinanet_image_converter.py | 8 + .../retinanet/retinanet_label_encoder.py | 61 ++-- .../retinanet/retinanet_label_encoder_test.py | 29 +- .../retinanet/retinanet_object_detector.py | 270 ++++++++++++++++++ .../retinanet_object_detector_preprocessor.py | 14 + .../retinanet_object_detector_test.py | 101 +++++++ keras_hub/src/tests/test_case.py | 1 + 16 files changed, 991 insertions(+), 73 deletions(-) create mode 100644 keras_hub/src/models/image_object_detector.py create mode 100644 keras_hub/src/models/image_object_detector_preprocessor.py create mode 100644 keras_hub/src/models/retinanet/prediction_head.py create mode 100644 keras_hub/src/models/retinanet/prediction_head_test.py create mode 100644 keras_hub/src/models/retinanet/retinanet_backbone.py create mode 100644 keras_hub/src/models/retinanet/retinanet_backbone_test.py create mode 100644 keras_hub/src/models/retinanet/retinanet_image_converter.py create mode 100644 keras_hub/src/models/retinanet/retinanet_object_detector.py create mode 100644 keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py create mode 100644 keras_hub/src/models/retinanet/retinanet_object_detector_test.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 6b85148caf..0d3ed939bc 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -43,6 +43,9 @@ from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter, ) +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 371277465a..856e25ba92 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -167,6 +167,10 @@ from keras_hub.src.models.image_classifier_preprocessor import ( ImageClassifierPreprocessor, ) +from keras_hub.src.models.image_object_detector import ImageObjectDetector +from keras_hub.src.models.image_object_detector_preprocessor import ( + ImageObjectDetectorPreprocessor, +) from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, @@ -233,6 +237,13 @@ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( ResNetImageClassifierPreprocessor, ) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_object_detector import ( + RetinaNetObjectDetector, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) from 
keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( diff --git a/keras_hub/src/models/image_object_detector.py b/keras_hub/src/models/image_object_detector.py new file mode 100644 index 0000000000..f4723f5cdf --- /dev/null +++ b/keras_hub/src/models/image_object_detector.py @@ -0,0 +1,84 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task + + +@keras_hub_export("keras_hub.models.ImageObjectDetector") +class ImageObjectDetector(Task): + """Base class for all image classification tasks. + + `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + image classification. `ImageObjectDetector` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is dictionary with `boxes` and + `classes`. + + All `ImageObjectDetector` tasks include a `from_preset()` constructor which + can be used to load a pre-trained config and weights. + """ + + def compile( + self, + optimizer="auto", + box_loss="auto", + classification_loss="auto", + metrics=None, + **kwargs, + ): + """Configures the `ImageSegmenter` task for training. + + The `ImageSegmenter` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + box_loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.Huber` loss will be + applied for the object detector task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + classification_loss: `"auto"`, a loss name, or a `keras.losses.Loss` + instance. Defaults to `"auto"`, where a + `keras.losses.BinaryFocalCrossentropy` loss will be + applied for the object detector task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `a list of metrics to be evaluated by + the model during training and testing. Defaults to `None`. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
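+
+        Example: a sketch of overriding the `"auto"` defaults explicitly;
+        `detector` stands in for any `ImageObjectDetector` instance, and the
+        values shown mirror the defaults applied by this method when the
+        model predicts logits.
+        ```python
+        detector.compile(
+            optimizer=keras.optimizers.Adam(5e-5),
+            box_loss=keras.losses.Huber(reduction="sum"),
+            classification_loss=keras.losses.BinaryFocalCrossentropy(
+                from_logits=True, reduction="sum"
+            ),
+        )
+        ```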
+ """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if box_loss == "auto": + box_loss = keras.losses.Huber(reduction="sum") + if classification_loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.sigmoid + classification_loss = keras.losses.BinaryFocalCrossentropy( + from_logits=from_logits, reduction="sum" + ) + if metrics is not None: + raise ValueError("User metrics not yet supported") + + losses = {"box": box_loss, "classification": classification_loss} + + super().compile( + optimizer=optimizer, + loss=losses, + metrics=metrics, + **kwargs, + ) diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py new file mode 100644 index 0000000000..0c8eb85e12 --- /dev/null +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -0,0 +1,117 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor") +class ImageObjectDetectorPreprocessor(Preprocessor): + """Base class for object detector preprocessing layers. + + `ImageObjectDetectorPreprocessor` tasks wraps a + `keras_hub.layers.Preprocessor` to create a preprocessing layer for + object detection tasks. It is intended to be paired with a + `keras_hub.models.ImageObjectDetector` task. + + All `ImageObjectDetectorPreprocessor` take inputs three inputs, `x`, `y`, and + `sample_weight`. `x`, the first input, should always be included. It can + be a image or batch of images. See examples below. `y` and `sample_weight` + are optional inputs that will be passed through unaltered. Usually, `y` will + be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4), + "classes": (batch_size, num_boxes)}. + + The layer will output either `x`, an `(x, y)` tuple if labels were provided, + or an `(x, y, sample_weight)` tuple if labels and sample weight were + provided. `x` will be the input images after all model preprocessing has + been applied. + + All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()` + constructor which can be used to load a pre-trained config and vocabularies. + You can call the `from_preset()` constructor directly on this base class, in + which case the correct class for your model will be automatically + instantiated. + + Args: + image_converter: Preprocessing pipeline for images. + source_bounding_box_format: str. The format of the source bounding boxes. + supported formats include: + - `"rel_yxyx"` + - `"rel_xyxy"` + - `"rel_xywh" + Defaults to `"rel_yxyx"`. + target_bounding_box_format: str. TODO Add link to keras-core bounding + box formats page. + + + Examples. + ```python + preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset( + "retinanet_resnet50", + ) + + # Resize a single image for resnet 50. + x = np.ones((512, 512, 3)) + x = preprocessor(x) + + # Resize a labeled image. + x, y = np.ones((512, 512, 3)), 1 + x, y = preprocessor(x, y) + + # Resize a batch of labeled images. + x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] + x, y = preprocessor(x, y) + + # Use a `tf.data.Dataset`. 
+ ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + target_bounding_box_format, + source_bounding_box_format="rel_yxyx", + image_converter=None, + **kwargs, + ): + super().__init__(**kwargs) + if "rel" not in source_bounding_box_format: + raise ValueError( + f"Only relative bounding box formats are supported " + f"but received source_bounding_box_format=" + f"`{source_bounding_box_format}` " + f"please provide source bounding box format from one of these " + f"`rel_xyxy` or `rel_yxyx` or `rel_xywh`. Make sure provided " + f"ground truth bounding boxes are normalized/relative to image." + ) + self.source_bounding_box_format = source_bounding_box_format + self.target_bounding_box_format = target_bounding_box_format + self.image_converter = image_converter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + + if y is not None: + y = convert_format( + y, + source=self.source_bounding_box_format, + target=self.target_bounding_box_format, + images=x, + ) + + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "source_bounding_box_format": self.source_bounding_box_format, + "target_bounding_box_format": self.target_bounding_box_format, + } + ) + + return config diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 5c0bbb906c..1533ad9773 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -149,9 +149,9 @@ def build(self, input_shapes): ) ) self.lateral_batch_norm_layers[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # Build output layers @@ -171,9 +171,9 @@ def build(self, input_shapes): name=f"output_conv_{level}", ) self.output_conv_layers[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # Build coarser layers @@ -193,9 +193,9 @@ def build(self, input_shapes): name=f"coarser_{level}", ) self.output_conv_layers[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # Build batch norm layers @@ -212,9 +212,9 @@ def build(self, input_shapes): ) ) self.output_batch_norms[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # The same upsampling layer is used for all levels @@ -320,34 +320,35 @@ def get_config(self): def compute_output_shape(self, input_shapes): output_shape = {} - print(input_shapes) input_levels = [int(level[1]) for level in input_shapes] backbone_max_level = min(max(input_levels), self.max_level) for i in range(self.min_level, backbone_max_level + 1): level = f"P{i}" if self.data_format == "channels_last": - output_shape[level] = input_shapes[level][:-1] + (256,) + output_shape[level] = input_shapes[level][:-1] + ( + self.num_filters, + ) else: output_shape[level] = ( input_shapes[level][0], - 256, 
+ self.num_filters, ) + input_shapes[level][1:3] intermediate_shape = input_shapes[f"P{backbone_max_level}"] intermediate_shape = ( ( intermediate_shape[0], - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, - 256, + intermediate_shape[1] // 2 if intermediate_shape[1] else None, + intermediate_shape[2] // 2 if intermediate_shape[1] else None, + self.num_filters, ) if self.data_format == "channels_last" else ( intermediate_shape[0], - 256, - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, + self.num_filters, + intermediate_shape[1] // 2 if intermediate_shape[1] else None, + intermediate_shape[2] // 2 if intermediate_shape[1] else None, ) ) @@ -357,16 +358,32 @@ def compute_output_shape(self, input_shapes): intermediate_shape = ( ( intermediate_shape[0], - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, - 256, + ( + intermediate_shape[1] // 2 + if intermediate_shape[1] + else None + ), + ( + intermediate_shape[2] // 2 + if intermediate_shape[1] + else None + ), + self.num_filters, ) if self.data_format == "channels_last" else ( intermediate_shape[0], - 256, - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, + self.num_filters, + ( + intermediate_shape[1] // 2 + if intermediate_shape[1] + else None + ), + ( + intermediate_shape[2] // 2 + if intermediate_shape[1] + else None + ), ) ) diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py new file mode 100644 index 0000000000..35be2e9551 --- /dev/null +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -0,0 +1,148 @@ +import keras + + +class PredictionHead(keras.layers.Layer): + """The classification/box predictions head. + + Arguments: + output_filters: int. Number of convolution filters in the final layer. + num_filters: int. Number of convolution filters used in base layers. + Defaults to `256`. + num_conv_layers: int. Number of convolution layers before final layer. + Defaults to `4`. + kernel_initializer: `str` or `keras.initializers` initializer. + The kernel initializer for the convolution layers. + Defaults to `"random_normal"`. + bias_initializer: `str` or `keras.initializers` initializer. + The bias initializer for the convolution layers. + Defaults to `"zeros"`. + kernel_regularizer: `str` or `keras.regularizers` regularizer. + The kernel regularizer for the convolution layers. + Defaults to `None`. + bias_regularizer: `str` or `keras.regularizers` regularizer. + The bias regularizer for the convolution layers. + Defaults to `None`. + + Returns: + A function representing either the classification + or the box regression head depending on `output_filters`. 
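As a usage sketch (not taken from this patch), the snippet below shows roughly how the `PredictionHead` layer documented in this hunk is driven, assuming the module path used elsewhere in this series and the channels-last default; the anchor count and feature shapes are illustrative only.

```python
import numpy as np

from keras_hub.src.models.retinanet.prediction_head import PredictionHead

# Hypothetical box-regression head: 9 anchors per location x 4 box coordinates.
box_head = PredictionHead(output_filters=9 * 4, num_filters=256, num_conv_layers=4)

# One FPN level, shaped (batch, height, width, channels) in channels-last layout.
features = np.random.uniform(size=(2, 32, 32, 256)).astype("float32")

# Spatial size is preserved; only the channel count changes to `output_filters`.
box_outputs = box_head(features)  # shape: (2, 32, 32, 36)
```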
+ """ + + def __init__( + self, + output_filters, + num_filters=256, + num_conv_layers=4, + activation="relu", + kernel_initializer="random_normal", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.output_filters = output_filters + self.num_filters = num_filters + self.num_conv_layers = num_conv_layers + self.activation = keras.activations.get(activation) + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + if kernel_regularizer is not None: + self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) + else: + self.kernel_regularizer = None + if bias_regularizer is not None: + self.bias_regularizer = keras.regularizers.get(bias_regularizer) + else: + self.bias_regularizer = None + + self.data_format = keras.backend.image_data_format() + + def build(self, input_shape): + self.conv_layers = [ + keras.layers.Conv2D( + self.num_filters, + kernel_size=3, + padding="same", + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + activation=self.activation, + data_format=self.data_format, + dtype=self.dtype_policy, + ) + for _ in range(self.num_conv_layers) + ] + + intermediate_shape = input_shape + for conv in self.conv_layers: + conv.build(intermediate_shape) + intermediate_shape = ( + input_shape[:-1] + (self.num_filters,) + if self.data_format == "channels_last" + else (input_shape[0], self.num_filters) + (input_shape[1:-1]) + ) + + self.prediction_layer = keras.layers.Conv2D( + self.output_filters, + kernel_size=3, + strides=1, + padding="same", + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + dtype=self.dtype_policy, + ) + + self.prediction_layer.build( + (None, None, None, self.num_filters) + if self.data_format == "channels_last" + else (None, self.num_filters, None, None) + ) + + self.built = True + + def call(self, input): + x = input + for conv in self.conv_layers: + x = conv(x) + output = self.prediction_layer(x) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "output_filters": self.output_filters, + "num_filters": self.num_filters, + "num_conv_layers": self.num_conv_layers, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "kernel_regularizer": ( + keras.regularizers.serialize(self.kernel_regularizer) + if self.kernel_regularizer is not None + else None + ), + "bias_regularizer": ( + keras.regularizers.serialize(self.bias_regularizer) + if self.bias_regularizer is not None + else None + ), + } + ) + return config + + def compute_output_shape(self, input_shape): + return ( + input_shape[:-1] + (self.output_filters,) + if self.data_format == "channels_last" + else (input_shape[0],) + (self.output_filters,) + input_shape[1:-1] + ) diff --git a/keras_hub/src/models/retinanet/prediction_head_test.py b/keras_hub/src/models/retinanet/prediction_head_test.py new file mode 100644 index 0000000000..ca1b949fed --- /dev/null +++ b/keras_hub/src/models/retinanet/prediction_head_test.py @@ -0,0 +1,17 @@ +from keras import random + +from 
keras_hub.src.models.retinanet.prediction_head import PredictionHead +from keras_hub.src.tests.test_case import TestCase + + +class FeaturePyramidTest(TestCase): + def test_layer_behaviors(self): + self.run_layer_test( + cls=PredictionHead, + init_kwargs={ + "output_filters": 9 * 4, # anchors_per_location * box length(4) + }, + input_data=random.uniform(shape=(2, 64, 64, 256)), + expected_output_shape=(2, 64, 64, 36), + expected_num_trainable_weights=10, + ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py new file mode 100644 index 0000000000..23567cb3d9 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -0,0 +1,84 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.RetinaNetBackbone") +class RetinaNetBackbone(FeaturePyramidBackbone): + def __init__( + self, + backbone, + min_level, + max_level, + image_shape=(None, None, 3), + data_format=None, + dtype=None, + **kwargs, + ): + + if min_level > max_level: + raise ValueError( + f"Minimum level ({min_level}) must be less than or equal to " + f"maximum level ({max_level})." + ) + + data_format = standardize_data_format(data_format) + input_levels = [int(level[1]) for level in backbone.pyramid_outputs] + backbone_max_level = min(max(input_levels), max_level) + image_encoder = keras.Model( + inputs=backbone.input, + outputs={ + f"P{i}": backbone.pyramid_outputs[f"P{i}"] + for i in range(min_level, backbone_max_level + 1) + }, + name="backbone", + ) + + feature_pyramid = FeaturePyramid( + min_level=min_level, max_level=max_level, name="fpn", dtype=dtype + ) + + # === Functional model === + image_input = keras.layers.Input(image_shape, name="images") + + image_encoder_outputs = image_encoder(image_input) + feature_pyramid_outputs = feature_pyramid(image_encoder_outputs) + + # === config === + self.min_level = min_level + self.max_level = max_level + self.backbone = backbone + self.feature_pyramid = feature_pyramid + self.image_shape = image_shape + + super().__init__( + inputs=image_input, + outputs=feature_pyramid_outputs, + dtype=dtype, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "backbone": keras.layers.serialize(self.backbone), + "min_level": self.min_level, + "max_level": self.max_level, + "image_shape": self.image_shape, + } + ) + return config + + @classmethod + def from_config(cls, config): + config.update( + { + "backbone": keras.layers.deserialize(config["backbone"]), + } + ) + + return super().from_config(config) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py new file mode 100644 index 0000000000..f2e5db3eda --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -0,0 +1,53 @@ +import pytest +from keras import ops + +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetBackboneTest(TestCase): + def setUp(self): + resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 
64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "block_type": "bottleneck_block", + "use_pre_activation": False, + } + backbone = ResNetBackbone(**resnet_kwargs) + + self.init_kwargs = { + "backbone": backbone, + "min_level": 3, + "max_level": 7, + } + + self.input_size = 256 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=RetinaNetBackbone, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + expected_output_shape={ + "P3": (2, 32, 32, 256), + "P4": (2, 16, 16, 256), + "P5": (2, 8, 8, 256), + "P6": (2, 4, 4, 256), + "P7": (2, 2, 2, 256), + }, + expected_pyramid_output_keys=False, + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=RetinaNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py new file mode 100644 index 0000000000..b37091fd6f --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone + + +@keras_hub_export("keras_hub.layers.RetinaNetImageConverter") +class RetinaNetImageConverter(ImageConverter): + backbone_cls = RetinaNetBackbone diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 51c0d188fb..b50ad958dc 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -1,9 +1,10 @@ import keras +import keras.src from keras import ops from keras_hub.src.bounding_box.converters import _encode_box_to_deltas +from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.iou import compute_iou -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils @@ -24,17 +25,9 @@ class RetinaNetLabelEncoder(keras.layers.Layer): consistency during training, regardless of the input format. Args: + anchor_generator: TODO: Add anchor_generator exposed layer details. bounding_box_format: str. The format of bounding boxes of input dataset. Refer TODO: Add link to Keras Core Docs. - min_level: int. Minimum level of the output feature pyramid. - max_level: int. Maximum level of the output feature pyramid. - num_scales: int. Number of intermediate scales added on each level. - For example, num_scales=2 adds one additional intermediate anchor - scale [2^0, 2^0.5] on each level. - aspect_ratios: List[float]. Aspect ratios of anchors added on - each level. Each number indicates the ratio of width to height. - anchor_size: float. Scale of size of the base anchor relative to the - feature stride 2^level. positive_threshold: float. the threshold to set an anchor to positive match to gt box. Values above it are positive matches. 
Defaults to `0.5` @@ -63,12 +56,8 @@ class RetinaNetLabelEncoder(keras.layers.Layer): def __init__( self, + anchor_generator, bounding_box_format, - min_level, - max_level, - num_scales, - aspect_ratios, - anchor_size, positive_threshold=0.5, negative_threshold=0.4, box_variance=[0.1, 0.1, 0.2, 0.2], @@ -79,27 +68,14 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.anchor_generator = anchor_generator self.bounding_box_format = bounding_box_format - self.min_level = min_level - self.max_level = max_level - self.num_scales = num_scales - self.aspect_ratios = aspect_ratios - self.anchor_size = anchor_size self.positive_threshold = positive_threshold self.box_variance = box_variance self.negative_threshold = negative_threshold self.background_class = background_class self.ignore_class = ignore_class - self.anchor_generator = AnchorGenerator( - bounding_box_format=bounding_box_format, - min_level=min_level, - max_level=max_level, - num_scales=num_scales, - aspect_ratios=aspect_ratios, - anchor_size=anchor_size, - ) - self.box_matcher = BoxMatcher( thresholds=[negative_threshold, positive_threshold], match_values=box_matcher_match_values, @@ -174,7 +150,12 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): Encoded boudning boxes in the format of `center_yxwh` and corresponding labels for each encoded bounding box. """ - + anchor_boxes = convert_format( + anchor_boxes, + source=self.anchor_generator.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) iou_matrix = compute_iou( anchor_boxes, gt_boxes, @@ -234,12 +215,10 @@ def get_config(self): config = super().get_config() config.update( { + "anchor_generator": keras.layers.serialize( + self.anchor_generator + ), "bounding_box_format": self.bounding_box_format, - "min_level": self.min_level, - "max_level": self.max_level, - "num_scales": self.num_scales, - "aspect_ratios": self.aspect_ratios, - "anchor_size": self.anchor_size, "positive_threshold": self.positive_threshold, "box_variance": self.box_variance, "negative_threshold": self.negative_threshold, @@ -249,6 +228,18 @@ def get_config(self): ) return config + @classmethod + def from_config(cls, config): + config.update( + { + "anchor_generator": keras.layers.deserialize( + config["anchor_generator"] + ), + } + ) + + return super().from_config(config) + def compute_output_shape( self, images_shape, gt_boxes_shape, gt_classes_shape ): diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py index de329685a8..f5097342e7 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -1,6 +1,7 @@ import numpy as np from keras import ops +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) @@ -8,6 +9,16 @@ class RetinaNetLabelEncoderTest(TestCase): + def setUp(self): + self.anchor_generator = AnchorGenerator( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + def test_layer_behaviors(self): images_shape = (8, 128, 128, 3) boxes_shape = (8, 10, 4) @@ -15,12 +26,8 @@ def test_layer_behaviors(self): self.run_layer_test( cls=RetinaNetLabelEncoder, init_kwargs={ + "anchor_generator": self.anchor_generator, "bounding_box_format": "xyxy", - "min_level": 3, - 
"max_level": 7, - "num_scales": 3, - "aspect_ratios": [0.5, 1.0, 2.0], - "anchor_size": 8, }, input_data={ "images": np.random.uniform(size=images_shape), @@ -48,12 +55,8 @@ def test_label_encoder_output_shapes(self): classes = np.random.uniform(size=classes_shape, low=0, high=5) encoder = RetinaNetLabelEncoder( + anchor_generator=self.anchor_generator, bounding_box_format="xyxy", - min_level=3, - max_level=7, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=8, ) box_targets, class_targets = encoder(images, boxes, classes) @@ -71,12 +74,8 @@ def test_all_negative_1(self): classes = -np.ones(shape=classes_shape, dtype="float32") encoder = RetinaNetLabelEncoder( + anchor_generator=self.anchor_generator, bounding_box_format="xyxy", - min_level=3, - max_level=7, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=8, ) box_targets, class_targets = encoder(images, boxes, classes) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py new file mode 100644 index 0000000000..e3146b088d --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -0,0 +1,270 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.bounding_box.converters import _decode_deltas_to_boxes +from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.models.image_object_detector import ImageObjectDetector +from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression +from keras_hub.src.models.retinanet.prediction_head import PredictionHead +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +@keras_hub_export("keras_hub.models.RetinaNetObjectDetector") +class RetinaNetObjectDetector(ImageObjectDetector): + + backbone_cls = RetinaNetBackbone + preprocessor_cls = RetinaNetObjectDetectorPreprocessor + + def __init__( + self, + backbone, + label_encoder, + anchor_generator, + num_classes, + bounding_box_format, + preprocessor=None, + activation=None, + head_dtype=None, + prediction_decoder=None, + **kwargs, + ): + # === Layers === + head_dtype = head_dtype or backbone.dtype_policy + prior_probability = keras.initializers.Constant( + -1 * keras.ops.log((1 - 0.01) / 0.01) + ) + box_head = PredictionHead( + anchor_generator.anchors_per_location * 4, + bias_initializer=prior_probability, + dtype=head_dtype, + ) + + classification_head = PredictionHead( + anchor_generator.anchors_per_location * num_classes, + dtype=head_dtype, + ) + + # === Functional Model === + image_input = keras.layers.Input(backbone.image_shape, name="images") + + feature_map = backbone(image_input) + + cls_pred = [] + box_pred = [] + for level in feature_map: + box_pred.append( + keras.layers.Reshape((-1, 4))(box_head(feature_map[level])) + ) + cls_pred.append( + keras.layers.Reshape((-1, num_classes))( + classification_head(feature_map[level]) + ) + ) + + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + cls_pred + ) + # box_pred is always in "center_yxhw" delta-encoded no matter what + # format you pass in. 
+ box_pred = keras.layers.Concatenate(axis=1, name="box")(box_pred) + + outputs = {"box": box_pred, "classification": cls_pred} + + # === Config === + self.bounding_box_format = bounding_box_format + self.num_classes = num_classes + self.backbone = backbone + self.preprocessor = preprocessor + self.label_encoder = label_encoder + self.anchor_generator = anchor_generator + self.activation = activation + self.box_head = box_head + self.classification_head = classification_head + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + from_logits=(activation != keras.activations.sigmoid), + bounding_box_format=bounding_box_format, + ) + + super().__init__( + inputs=image_input, + outputs=outputs, + **kwargs, + ) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + y_for_label_encoder = convert_format( + y, + source=self.bounding_box_format, + target=self.label_encoder.bounding_box_format, + images=x, + ) + + boxes, classes = self.label_encoder( + images=x, + gt_boxes=y_for_label_encoder["boxes"], + gt_classes=y_for_label_encoder["classes"], + ) + + box_pred = y_pred["box"] + cls_pred = y_pred["classification"] + + if boxes.shape[-1] != 4: + raise ValueError( + "boxes should have shape (None, None, 4). Got " + f"boxes.shape={tuple(boxes.shape)}" + ) + + if box_pred.shape[-1] != 4: + raise ValueError( + "box_pred should have shape (None, None, 4). Got " + f"box_pred.shape={tuple(box_pred.shape)}. Does your model's " + "`num_classes` parameter match your losses `num_classes` " + "parameter?" + ) + if cls_pred.shape[-1] != self.num_classes: + raise ValueError( + "cls_pred should have shape (None, None, 4). Got " + f"cls_pred.shape={tuple(cls_pred.shape)}. Does your model's " + "`num_classes` parameter match your losses `num_classes` " + "parameter?" + ) + + cls_labels = ops.one_hot( + ops.cast(classes, "int32"), self.num_classes, dtype="float32" + ) + positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32") + normalizer = ops.sum(positive_mask) + cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32") + cls_weights /= normalizer + box_weights = positive_mask / normalizer + + y_true = { + "box": boxes, + "classification": cls_labels, + } + sample_weights = { + "box": box_weights, + "classification": cls_weights, + } + zero_weight = { + "box": ops.zeros_like(box_weights), + "classification": ops.zeros_like(cls_weights), + } + + sample_weight = ops.cond( + normalizer == 0, + lambda: zero_weight, + lambda: sample_weights, + ) + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weight, **kwargs + ) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + return self.decode_predictions(outputs, args[-1]) + + def decode_predictions(self, predictions, data): + if isinstance(data, tuple): + images = data[0] + else: + images = data + box_pred, cls_pred = predictions["box"], predictions["classification"] + # box_pred is on "center_yxhw" format, convert to target format. 
+ image_shape = ops.shape(images)[1:] + anchor_boxes = self.anchor_generator(images) + anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) + box_pred = _decode_deltas_to_boxes( + anchors=anchor_boxes, + boxes_delta=box_pred, + anchor_format=self.anchor_generator.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + image_shape=image_shape, + ) + # box_pred is now in "self.bounding_box_format" format + box_pred = convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + image_shape=image_shape, + ) + y_pred = self.prediction_decoder( + box_pred, cls_pred, image_shape=image_shape + ) + y_pred["boxes"] = convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) + return y_pred + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and RetinaNet to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "anchor_generator": keras.layers.serialize( + self.anchor_generator + ), + "label_encoder": keras.layers.serialize(self.label_encoder), + "prediction_decoder": keras.layers.serialize( + self._prediction_decoder + ), + } + ) + + return config + + @classmethod + def from_config(cls, config): + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + + if "anchor_generator" in config and isinstance( + config["anchor_generator"], dict + ): + config["anchor_generator"] = keras.layers.deserialize( + config["anchor_generator"] + ) + + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + + return super().from_config(config) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py b/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py new file mode 100644 index 0000000000..8bc6d1f796 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_object_detector_preprocessor import ( + ImageObjectDetectorPreprocessor, +) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) + + +@keras_hub_export("keras_hub.models.RetinaNetObjectDetectorPreprocessor") +class RetinaNetObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor): + backbone_cls = RetinaNetBackbone + image_converter_cls = RetinaNetImageConverter diff --git 
a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py new file mode 100644 index 0000000000..29ee4de9bc --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -0,0 +1,101 @@ +import pytest +from keras import ops + +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) +from keras_hub.src.models.retinanet.retinanet_object_detector import ( + RetinaNetObjectDetector, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetObjectDetectorTest(TestCase): + def setUp(self): + resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "image_shape": (None, None, 3), + "block_type": "bottleneck_block", + "use_pre_activation": False, + } + backbone = ResNetBackbone(**resnet_kwargs) + + retinanet_backbone_kwargs = { + "backbone": backbone, + "min_level": 3, + "max_level": 7, + } + + feature_extractor = RetinaNetBackbone(**retinanet_backbone_kwargs) + anchor_generator = AnchorGenerator( + bounding_box_format="yxyx", + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + label_encoder = RetinaNetLabelEncoder( + bounding_box_format="yxyx", anchor_generator=anchor_generator + ) + + image_converter = ImageConverter( + image_size=(256, 256), + ) + + preprocessor = RetinaNetObjectDetectorPreprocessor( + image_converter=image_converter, target_bounding_box_format="xyxy" + ) + + self.init_kwargs = { + "backbone": feature_extractor, + "anchor_generator": anchor_generator, + "label_encoder": label_encoder, + "num_classes": 10, + "bounding_box_format": "yxyx", + "preprocessor": preprocessor, + } + + self.input_size = 512 + self.images = ops.ones((1, self.input_size, self.input_size, 3)) + self.labels = { + "boxes": ops.convert_to_tensor( + [[[20, 10, 120, 110], [30, 20, 130, 120]]] + ), + "classes": ops.convert_to_tensor([[0, 2]]), + } + + self.train_data = (self.images, self.labels) + + @pytest.mark.large + def test_detection_basics(self): + self.run_task_test( + cls=RetinaNetObjectDetector, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape={ + "boxes": (1, 100, 4), + "classes": (1, 100), + "confidence": (1, 100), + "num_detections": (1,), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=RetinaNetObjectDetector, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 6d06c7266c..edd12f6986 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -341,6 +341,7 @@ def run_precision_test(self, cls, init_kwargs, input_data): continue if isinstance(sublayer, keras.layers.InputLayer): continue + print(sublayer) self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, 
sublayer.variable_dtype) From baee6e2a298e4d3a2d29b6ce5d3c5040f5d82847 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 3 Oct 2024 12:12:26 -0700 Subject: [PATCH 05/35] nit --- keras_hub/src/models/retinanet/retinanet_backbone_test.py | 3 ++- keras_hub/src/tests/test_case.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py index f2e5db3eda..b4e99b823f 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone_test.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -14,6 +14,7 @@ def setUp(self): "stackwise_num_filters": [64, 64, 64], "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], + "image_shape": (None, None, 3), "block_type": "bottleneck_block", "use_pre_activation": False, } @@ -31,7 +32,7 @@ def setUp(self): def test_backbone_basics(self): self.run_vision_backbone_test( cls=RetinaNetBackbone, - init_kwargs={**self.init_kwargs}, + init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape={ "P3": (2, 32, 32, 256), diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index edd12f6986..6d06c7266c 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -341,7 +341,6 @@ def run_precision_test(self, cls, init_kwargs, input_data): continue if isinstance(sublayer, keras.layers.InputLayer): continue - print(sublayer) self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, sublayer.variable_dtype) From 5ee905ec537117b5d3d9035b8c1951e0fb9e6191 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 4 Oct 2024 10:07:14 -0700 Subject: [PATCH 06/35] Expose Anchor Generator as layer, docstring correction and test correction --- keras_hub/api/layers/__init__.py | 1 + keras_hub/src/models/image_object_detector.py | 8 ++-- .../image_object_detector_preprocessor.py | 19 ++++---- .../src/models/retinanet/anchor_generator.py | 2 + .../src/models/retinanet/feature_pyramid.py | 47 +++++++++++++------ .../src/models/retinanet/prediction_head.py | 12 ++--- .../models/retinanet/prediction_head_test.py | 2 + .../models/retinanet/retinanet_backbone.py | 41 +++++++++------- .../retinanet/retinanet_backbone_test.py | 10 ++-- .../retinanet/retinanet_label_encoder.py | 6 +-- .../retinanet_object_detector_test.py | 8 ++-- 11 files changed, 94 insertions(+), 62 deletions(-) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 0d3ed939bc..17cd5b77bc 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -43,6 +43,7 @@ from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter, ) +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_image_converter import ( RetinaNetImageConverter, ) diff --git a/keras_hub/src/models/image_object_detector.py b/keras_hub/src/models/image_object_detector.py index f4723f5cdf..60f8593042 100644 --- a/keras_hub/src/models/image_object_detector.py +++ b/keras_hub/src/models/image_object_detector.py @@ -6,9 +6,9 @@ @keras_hub_export("keras_hub.models.ImageObjectDetector") class ImageObjectDetector(Task): - """Base class for all image classification tasks. + """Base class for all image object detections tasks. 
- `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and + The `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and a `keras_hub.models.Preprocessor` to create a model that can be used for image classification. `ImageObjectDetector` tasks take an additional `num_classes` argument, controlling the number of predicted output classes. @@ -29,9 +29,9 @@ def compile( metrics=None, **kwargs, ): - """Configures the `ImageSegmenter` task for training. + """Configures the `ImageObjectDetector` task for training. - The `ImageSegmenter` task extends the default compilation signature of + The `ImageObjectDetector` task extends the default compilation signature of `keras.Model.compile` with defaults for `optimizer`, `loss`, and `metrics`. To override these defaults, pass any value to these arguments during compilation. diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 0c8eb85e12..6a61831a9e 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -15,14 +15,14 @@ class ImageObjectDetectorPreprocessor(Preprocessor): object detection tasks. It is intended to be paired with a `keras_hub.models.ImageObjectDetector` task. - All `ImageObjectDetectorPreprocessor` take inputs three inputs, `x`, `y`, and + All `ImageObjectDetectorPreprocessor` take three inputs, `x`, `y`, and `sample_weight`. `x`, the first input, should always be included. It can be a image or batch of images. See examples below. `y` and `sample_weight` are optional inputs that will be passed through unaltered. Usually, `y` will be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4), "classes": (batch_size, num_boxes)}. - The layer will output either `x`, an `(x, y)` tuple if labels were provided, + The layer will returns either `x`, an `(x, y)` tuple if labels were provided, or an `(x, y, sample_weight)` tuple if labels and sample weight were provided. `x` will be the input images after all model preprocessing has been applied. @@ -41,8 +41,8 @@ class ImageObjectDetectorPreprocessor(Preprocessor): - `"rel_xyxy"` - `"rel_xywh" Defaults to `"rel_yxyx"`. - target_bounding_box_format: str. TODO Add link to keras-core bounding - box formats page. + target_bounding_box_format: str. TODO + https://github.com/keras-team/keras-hub/issues/1907 Examples. @@ -80,11 +80,12 @@ def __init__( if "rel" not in source_bounding_box_format: raise ValueError( f"Only relative bounding box formats are supported " - f"but received source_bounding_box_format=" - f"`{source_bounding_box_format}` " - f"please provide source bounding box format from one of these " - f"`rel_xyxy` or `rel_yxyx` or `rel_xywh`. Make sure provided " - f"ground truth bounding boxes are normalized/relative to image." + f"Received source_bounding_box_format=" + f"`{source_bounding_box_format}`. " + f"Please provide a source bounding box format from one of " + f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. " + f"Ensure that the provided ground truth bounding boxes are " + f"normalized and relative to the image size. 
" ) self.source_bounding_box_format = source_bounding_box_format self.target_bounding_box_format = target_bounding_box_format diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index bb46988926..faf9708b61 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -3,9 +3,11 @@ import keras from keras import ops +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.bounding_box.converters import convert_format +@keras_hub_export("keras_hub.layers.AnchorGenerator") class AnchorGenerator(keras.layers.Layer): """Generates anchor boxes for object detection tasks. diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 1533ad9773..386742439c 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -41,10 +41,10 @@ class FeaturePyramid(keras.layers.Layer): activation: string or `keras.activations`. The activation function to be used in network. Defaults to `"relu"`. - kernel_initializer: `str` or `keras.initializers` initializer. + kernel_initializer: `str` or `keras.initializers`. The kernel initializer for the convolution layers. Defaults to `"VarianceScaling"`. - bias_initializer: `str` or `keras.initializers` initializer. + bias_initializer: `str` or `keras.initializers`. The bias initializer for the convolution layers. Defaults to `"zeros"`. batch_norm_momentum: float. @@ -53,10 +53,10 @@ class FeaturePyramid(keras.layers.Layer): batch_norm_epsilon: float. The epsilon for the batch normalization layers. Defaults to `0.001`. - kernel_regularizer: `str` or `keras.regularizers` regularizer. + kernel_regularizer: `str` or `keras.regularizers`. The kernel regularizer for the convolution layers. Defaults to `None`. - bias_regularizer: `str` or `keras.regularizers` regularizer. + bias_regularizer: `str` or `keras.regularizers`. The bias regularizer for the convolution layers. Defaults to `None`. use_batch_norm: bool. Whether to use batch normalization. 
@@ -117,7 +117,6 @@ def build(self, input_shapes): } input_levels = [int(level[1]) for level in input_shapes] backbone_max_level = min(max(input_levels), self.max_level) - # Build lateral layers self.lateral_conv_layers = {} for i in range(self.min_level, backbone_max_level + 1): @@ -134,7 +133,11 @@ def build(self, input_shapes): dtype=self.dtype_policy, name=f"lateral_conv_{level}", ) - self.lateral_conv_layers[level].build(input_shapes[level]) + self.lateral_conv_layers[level].build( + (None, None, None, input_shapes[level][-1]) + if self.data_format == "channels_last" + else (None, input_shapes[level][-1], None, None) + ) self.lateral_batch_norm_layers = {} if self.use_batch_norm: @@ -339,16 +342,32 @@ def compute_output_shape(self, input_shapes): intermediate_shape = ( ( intermediate_shape[0], - intermediate_shape[1] // 2 if intermediate_shape[1] else None, - intermediate_shape[2] // 2 if intermediate_shape[1] else None, + ( + intermediate_shape[1] // 2 + if intermediate_shape[1] is not None + else None + ), + ( + intermediate_shape[2] // 2 + if intermediate_shape[1] is not None + else None + ), self.num_filters, ) if self.data_format == "channels_last" else ( intermediate_shape[0], self.num_filters, - intermediate_shape[1] // 2 if intermediate_shape[1] else None, - intermediate_shape[2] // 2 if intermediate_shape[1] else None, + ( + intermediate_shape[1] // 2 + if intermediate_shape[1] is not None + else None + ), + ( + intermediate_shape[2] // 2 + if intermediate_shape[1] is not None + else None + ), ) ) @@ -360,12 +379,12 @@ def compute_output_shape(self, input_shapes): intermediate_shape[0], ( intermediate_shape[1] // 2 - if intermediate_shape[1] + if intermediate_shape[1] is not None else None ), ( intermediate_shape[2] // 2 - if intermediate_shape[1] + if intermediate_shape[1] is not None else None ), self.num_filters, @@ -376,12 +395,12 @@ def compute_output_shape(self, input_shapes): self.num_filters, ( intermediate_shape[1] // 2 - if intermediate_shape[1] + if intermediate_shape[1] is not None else None ), ( intermediate_shape[2] // 2 - if intermediate_shape[1] + if intermediate_shape[1] is not None else None ), ) diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py index 35be2e9551..f12e4d3ee2 100644 --- a/keras_hub/src/models/retinanet/prediction_head.py +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -10,16 +10,16 @@ class PredictionHead(keras.layers.Layer): Defaults to `256`. num_conv_layers: int. Number of convolution layers before final layer. Defaults to `4`. - kernel_initializer: `str` or `keras.initializers` initializer. + kernel_initializer: `str` or `keras.initializers`. The kernel initializer for the convolution layers. Defaults to `"random_normal"`. - bias_initializer: `str` or `keras.initializers` initializer. + bias_initializer: `str` or `keras.initializers`. The bias initializer for the convolution layers. Defaults to `"zeros"`. - kernel_regularizer: `str` or `keras.regularizers` regularizer. + kernel_regularizer: `str` or `keras.regularizers`. The kernel regularizer for the convolution layers. Defaults to `None`. - bias_regularizer: `str` or `keras.regularizers` regularizer. + bias_regularizer: `str` or `keras.regularizers`. The bias regularizer for the convolution layers. Defaults to `None`. 
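As an aside to the `bias_initializer` argument documented above (a sketch, not taken from this patch): RetinaNet seeds the bias of a head's final layer so that anchors start out predicting a low foreground probability, and the object detector later in this series builds exactly this `Constant(-log((1 - 0.01) / 0.01))` initializer. The class and anchor counts below are hypothetical.

```python
import keras

from keras_hub.src.models.retinanet.prediction_head import PredictionHead

# Bias prior: at initialization each logit corresponds to a sigmoid probability
# of roughly 0.01, which keeps the focal loss stable early in training.
prior_probability = keras.initializers.Constant(
    -1 * keras.ops.log((1 - 0.01) / 0.01)
)

head = PredictionHead(
    output_filters=9 * 20,  # hypothetical: 9 anchors per location x 20 classes
    num_filters=256,
    num_conv_layers=4,
    bias_initializer=prior_probability,
)
```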
@@ -31,8 +31,8 @@ class PredictionHead(keras.layers.Layer): def __init__( self, output_filters, - num_filters=256, - num_conv_layers=4, + num_filters, + num_conv_layers, activation="relu", kernel_initializer="random_normal", bias_initializer="zeros", diff --git a/keras_hub/src/models/retinanet/prediction_head_test.py b/keras_hub/src/models/retinanet/prediction_head_test.py index ca1b949fed..7416565d12 100644 --- a/keras_hub/src/models/retinanet/prediction_head_test.py +++ b/keras_hub/src/models/retinanet/prediction_head_test.py @@ -10,6 +10,8 @@ def test_layer_behaviors(self): cls=PredictionHead, init_kwargs={ "output_filters": 9 * 4, # anchors_per_location * box length(4) + "num_filters": 256, + "num_conv_layers": 4, }, input_data=random.uniform(shape=(2, 64, 64, 256)), expected_output_shape=(2, 64, 64, 36), diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index 23567cb3d9..e4eb3eeb1b 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -1,16 +1,16 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid from keras_hub.src.utils.keras_utils import standardize_data_format @keras_hub_export("keras_hub.models.RetinaNetBackbone") -class RetinaNetBackbone(FeaturePyramidBackbone): +class RetinaNetBackbone(Backbone): def __init__( self, - backbone, + image_encoder, min_level, max_level, image_shape=(None, None, 3), @@ -26,14 +26,21 @@ def __init__( ) data_format = standardize_data_format(data_format) - input_levels = [int(level[1]) for level in backbone.pyramid_outputs] + input_levels = [ + int(level[1]) for level in image_encoder.pyramid_outputs + ] backbone_max_level = min(max(input_levels), max_level) - image_encoder = keras.Model( - inputs=backbone.input, - outputs={ - f"P{i}": backbone.pyramid_outputs[f"P{i}"] - for i in range(min_level, backbone_max_level + 1) - }, + + if backbone_max_level < 5 and max_level >= 5: + raise ValueError( + f"Backbone maximum level ({backbone_max_level}) is less than " + f"the desired maximum level ({max_level}). " + f"Please ensure that the backbone can generate features up to " + f"the specified maximum level." 
+ ) + feature_extractor = keras.Model( + inputs=image_encoder.inputs, + outputs=image_encoder.pyramid_outputs, name="backbone", ) @@ -42,15 +49,15 @@ def __init__( ) # === Functional model === - image_input = keras.layers.Input(image_shape, name="images") + image_input = keras.layers.Input(image_shape, name="inputs") - image_encoder_outputs = image_encoder(image_input) - feature_pyramid_outputs = feature_pyramid(image_encoder_outputs) + feature_extractor_outputs = feature_extractor(image_input) + feature_pyramid_outputs = feature_pyramid(feature_extractor_outputs) # === config === self.min_level = min_level self.max_level = max_level - self.backbone = backbone + self.image_encoder = image_encoder self.feature_pyramid = feature_pyramid self.image_shape = image_shape @@ -65,7 +72,7 @@ def get_config(self): config = super().get_config() config.update( { - "backbone": keras.layers.serialize(self.backbone), + "image_encoder": keras.layers.serialize(self.image_encoder), "min_level": self.min_level, "max_level": self.max_level, "image_shape": self.image_shape, @@ -77,7 +84,9 @@ def get_config(self): def from_config(cls, config): config.update( { - "backbone": keras.layers.deserialize(config["backbone"]), + "image_encoder": keras.layers.deserialize( + config["image_encoder"] + ), } ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py index b4e99b823f..becd479ce0 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone_test.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -18,12 +18,12 @@ def setUp(self): "block_type": "bottleneck_block", "use_pre_activation": False, } - backbone = ResNetBackbone(**resnet_kwargs) + image_encoder = ResNetBackbone(**resnet_kwargs) self.init_kwargs = { - "backbone": backbone, + "image_encoder": image_encoder, "min_level": 3, - "max_level": 7, + "max_level": 4, } self.input_size = 256 @@ -37,12 +37,10 @@ def test_backbone_basics(self): expected_output_shape={ "P3": (2, 32, 32, 256), "P4": (2, 16, 16, 256), - "P5": (2, 8, 8, 256), - "P6": (2, 4, 4, 256), - "P7": (2, 2, 2, 256), }, expected_pyramid_output_keys=False, run_mixed_precision_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index b50ad958dc..0b58acea01 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -25,9 +25,9 @@ class RetinaNetLabelEncoder(keras.layers.Layer): consistency during training, regardless of the input format. Args: - anchor_generator: TODO: Add anchor_generator exposed layer details. - bounding_box_format: str. The format of bounding boxes of input dataset. - Refer TODO: Add link to Keras Core Docs. + anchor_generator: A `keras_hub.layers.AnchorGenerator`. + bounding_box_format: str. TODO: + https://github.com/keras-team/keras-hub/issues/1907 positive_threshold: float. the threshold to set an anchor to positive match to gt box. Values above it are positive matches. 
Defaults to `0.5` diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 29ee4de9bc..4f0600761a 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -29,19 +29,19 @@ def setUp(self): "block_type": "bottleneck_block", "use_pre_activation": False, } - backbone = ResNetBackbone(**resnet_kwargs) + image_encoder = ResNetBackbone(**resnet_kwargs) retinanet_backbone_kwargs = { - "backbone": backbone, + "image_encoder": image_encoder, "min_level": 3, - "max_level": 7, + "max_level": 4, } feature_extractor = RetinaNetBackbone(**retinanet_backbone_kwargs) anchor_generator = AnchorGenerator( bounding_box_format="yxyx", min_level=3, - max_level=7, + max_level=4, num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=8, From 84533d4814b474fb75eee31a1849ef89b3f32dfb Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 4 Oct 2024 11:25:03 -0700 Subject: [PATCH 07/35] nit --- .../models/retinanet/retinanet_backbone.py | 23 +++++++++++++++- .../retinanet/retinanet_backbone_test.py | 1 + .../retinanet/retinanet_object_detector.py | 27 +++++++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index e4eb3eeb1b..2c79d5ea90 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -8,6 +8,24 @@ @keras_hub_export("keras_hub.models.RetinaNetBackbone") class RetinaNetBackbone(Backbone): + """RetinaNet Backbone. + + Args: + image_encoder (keras.Model): The backbone model used to extract features + from the input image. + It should have pyramid outputs. + min_level (int): The minimum feature pyramid level. + max_level (int): The maximum feature pyramid level. + image_shape (tuple): The shape of the input image. + data_format (str): The data format of the input image (channels_first or channels_last). + dtype (str): The data type of the input image. + **kwargs: Additional arguments passed to the base class. + + Raises: + ValueError: If `min_level` is greater than `max_level`. + ValueError: If `backbone_max_level` is less than 5 and `max_level` is greater than or equal to 5. 
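A construction sketch (not part of the diff) for the backbone documented above, mirroring the ResNet configuration used in this series' tests; it assumes an image encoder whose `pyramid_outputs` reach `P5`, so the `max_level=7` check passes, and the exact stack sizes are illustrative.

```python
import numpy as np

from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone

image_encoder = ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[3, 4, 6, 3],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="bottleneck_block",
    use_pre_activation=False,
)

backbone = RetinaNetBackbone(image_encoder=image_encoder, min_level=3, max_level=7)

# The backbone maps an image batch to a dict of multi-scale feature maps.
features = backbone(np.ones((1, 256, 256, 3), dtype="float32"))
# {"P3": (1, 32, 32, 256), ..., "P6": (1, 4, 4, 256), "P7": (1, 2, 2, 256)}
```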
+ """ + def __init__( self, image_encoder, @@ -40,7 +58,10 @@ def __init__( ) feature_extractor = keras.Model( inputs=image_encoder.inputs, - outputs=image_encoder.pyramid_outputs, + outputs={ + f"P{level}": image_encoder.pyramid_outputs[f"P{level}"] + for level in input_levels + }, name="backbone", ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py index becd479ce0..dedfad398d 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone_test.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -41,6 +41,7 @@ def test_backbone_basics(self): expected_pyramid_output_keys=False, run_mixed_precision_check=False, run_data_format_check=False, + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index e3146b088d..bd5d075122 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -17,6 +17,33 @@ @keras_hub_export("keras_hub.models.RetinaNetObjectDetector") class RetinaNetObjectDetector(ImageObjectDetector): + """RetinaNet object detector model. + + This class implements the RetinaNet object detection architecture. + It consists of a feature extractor backbone, a feature pyramid network(FPN), + and two prediction heads for classification and regression. + + Args: + backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class, defining the + backbone network architecture. + label_encoder: `keras.layers.Layer`. A `RetinaNetLabelEncoder` class + that accepts an image Tensor, a bounding box Tensor and a bounding + box class Tensor to its `call()` method, and returns + `RetinaNetObjectDetector` training targets. + anchor_generator: A `keras_Hub.layers.AnchorGenerator`. + num_classes: The number of object classes to be detected. + bounding_box_format: The format of bounding boxes of input dataset. + TODO: https://github.com/keras-team/keras-hub/issues/1907 + preprocessor: Optional. An instance of the + `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. + activation: Optional. The activation function to be used in the + classification head. + head_dtype: Optional. The data type for the prediction heads. + prediction_decoder: Optional. A `keras.layers.Layer` that is + responsible for transforming RetinaNet predictions into usable + bounding box Tensors. + Defaults to `NonMaxSuppression` class instance. 
+ """ backbone_cls = RetinaNetBackbone preprocessor_cls = RetinaNetObjectDetectorPreprocessor From b6ceb8f0df587632de342296aacf8ce946dccc90 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 4 Oct 2024 11:52:42 -0700 Subject: [PATCH 08/35] Add missing args for prediction heads --- .../src/models/retinanet/retinanet_object_detector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index bd5d075122..2eef9e84ec 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -67,13 +67,17 @@ def __init__( -1 * keras.ops.log((1 - 0.01) / 0.01) ) box_head = PredictionHead( - anchor_generator.anchors_per_location * 4, + output_filters=anchor_generator.anchors_per_location * 4, + num_conv_layers=4, + num_filters=256, bias_initializer=prior_probability, dtype=head_dtype, ) classification_head = PredictionHead( - anchor_generator.anchors_per_location * num_classes, + output_filters=anchor_generator.anchors_per_location * num_classes, + num_conv_layers=4, + num_filters=256, dtype=head_dtype, ) From 4c7a28bdded593f6b826c477987c7894afb16fba Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 10:45:19 -0700 Subject: [PATCH 09/35] - Use FeaturePyramidBackbone cls for RetinaNet backbone. - Correct test cases. --- keras_hub/src/models/image_object_detector.py | 4 +- .../image_object_detector_preprocessor.py | 3 +- .../src/models/retinanet/anchor_generator.py | 4 +- .../src/models/retinanet/feature_pyramid.py | 13 +++- .../src/models/retinanet/prediction_head.py | 40 +++++----- .../models/retinanet/retinanet_backbone.py | 29 ++++--- .../retinanet/retinanet_backbone_test.py | 25 ++++-- .../retinanet/retinanet_label_encoder.py | 4 +- .../retinanet/retinanet_label_encoder_test.py | 17 ++--- .../retinanet/retinanet_object_detector.py | 76 ++++++++++--------- .../retinanet_object_detector_test.py | 7 +- 11 files changed, 124 insertions(+), 98 deletions(-) diff --git a/keras_hub/src/models/image_object_detector.py b/keras_hub/src/models/image_object_detector.py index 60f8593042..4016d7dff2 100644 --- a/keras_hub/src/models/image_object_detector.py +++ b/keras_hub/src/models/image_object_detector.py @@ -6,11 +6,11 @@ @keras_hub_export("keras_hub.models.ImageObjectDetector") class ImageObjectDetector(Task): - """Base class for all image object detections tasks. + """Base class for all image object detection tasks. The `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and a `keras_hub.models.Preprocessor` to create a model that can be used for - image classification. `ImageObjectDetector` tasks take an additional + object detection. `ImageObjectDetector` tasks take an additional `num_classes` argument, controlling the number of predicted output classes. 
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 6a61831a9e..72a3c23153 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -96,14 +96,13 @@ def call(self, x, y=None, sample_weight=None): if self.image_converter: x = self.image_converter(x) - if y is not None: + if y is not None and keras.ops.is_tensor(y): y = convert_format( y, source=self.source_bounding_box_format, target=self.target_bounding_box_format, images=x, ) - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def get_config(self): diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index faf9708b61..e04779abb6 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -102,8 +102,8 @@ def call(self, inputs): feat_size_x = math.ceil(image_shape[1] / 2**level) # Calculate the stride (step size) for this level - stride_y = ops.cast(image_shape[0] / feat_size_y, "float32") - stride_x = ops.cast(image_shape[1] / feat_size_x, "float32") + stride_y = image_shape[0] / feat_size_y + stride_x = image_shape[1] / feat_size_x # Generate anchor center points # Start from stride/2 to center anchors on pixels diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 386742439c..2322063eea 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -1,5 +1,7 @@ import keras +from keras_hub.src.utils.keras_utils import standardize_data_format + class FeaturePyramid(keras.layers.Layer): """A Feature Pyramid Network (FPN) layer. @@ -78,6 +80,7 @@ def __init__( kernel_regularizer=None, bias_regularizer=None, use_batch_norm=False, + data_format=None, **kwargs, ): super().__init__(**kwargs) @@ -103,8 +106,8 @@ def __init__( self.bias_regularizer = keras.regularizers.get(bias_regularizer) else: self.bias_regularizer = None - self.data_format = keras.backend.image_data_format() - self.batch_norm_axis = -1 if self.data_format == "channels_last" else 1 + self.data_format = standardize_data_format(data_format) + self.batch_norm_axis = -1 if data_format == "channels_last" else 1 def build(self, input_shapes): input_shapes = { @@ -286,7 +289,10 @@ def call(self, inputs): if self.use_batch_norm else self.output_conv_layers[level](feats_in) ) - + output_features = { + f"P{i}": output_features[f"P{i}"] + for i in range(self.min_level, self.max_level + 1) + } return output_features def get_config(self): @@ -297,6 +303,7 @@ def get_config(self): "max_level": self.max_level, "num_filters": self.num_filters, "use_batch_norm": self.use_batch_norm, + "data_format": self.data_format, "activation": keras.activations.serialize(self.activation), "kernel_initializer": keras.initializers.serialize( self.kernel_initializer diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py index f12e4d3ee2..95395e864b 100644 --- a/keras_hub/src/models/retinanet/prediction_head.py +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -4,28 +4,34 @@ class PredictionHead(keras.layers.Layer): """The classification/box predictions head. - Arguments: + + + Args: output_filters: int. 
Number of convolution filters in the final layer. + The number of output channels determines the prediction type: + - **Classification**: + `output_filters = num_anchors * num_classes` + Predicts class probabilities for each anchor. + - **Bounding Box Regression**: + `output_filters = num_anchors * 4` + Predicts bounding box offsets (x1, y1, x2, y2) for each anchor. num_filters: int. Number of convolution filters used in base layers. Defaults to `256`. num_conv_layers: int. Number of convolution layers before final layer. Defaults to `4`. - kernel_initializer: `str` or `keras.initializers`. - The kernel initializer for the convolution layers. - Defaults to `"random_normal"`. - bias_initializer: `str` or `keras.initializers`. - The bias initializer for the convolution layers. - Defaults to `"zeros"`. - kernel_regularizer: `str` or `keras.regularizers`. - The kernel regularizer for the convolution layers. - Defaults to `None`. - bias_regularizer: `str` or `keras.regularizers`. - The bias regularizer for the convolution layers. - Defaults to `None`. + kernel_initializer: `str` or `keras.initializers`. The kernel + initializer for the convolution layers. Defaults to + `"random_normal"`. + bias_initializer: `str` or `keras.initializers`. The bias initializer + for the convolution layers. Defaults to `"zeros"`. + kernel_regularizer: `str` or `keras.regularizers`. The kernel + regularizer for the convolution layers. Defaults to `None`. + bias_regularizer: `str` or `keras.regularizers`. The bias regularizer + for the convolution layers. Defaults to `None`. Returns: - A function representing either the classification - or the box regression head depending on `output_filters`. + A function representing either the classification + or the box regression head depending on `output_filters`. """ def __init__( @@ -75,7 +81,6 @@ def build(self, input_shape): ) for _ in range(self.num_conv_layers) ] - intermediate_shape = input_shape for conv in self.conv_layers: conv.build(intermediate_shape) @@ -84,7 +89,6 @@ def build(self, input_shape): if self.data_format == "channels_last" else (input_shape[0], self.num_filters) + (input_shape[1:-1]) ) - self.prediction_layer = keras.layers.Conv2D( self.output_filters, kernel_size=3, @@ -96,13 +100,11 @@ def build(self, input_shape): bias_regularizer=self.bias_regularizer, dtype=self.dtype_policy, ) - self.prediction_layer.build( (None, None, None, self.num_filters) if self.data_format == "channels_last" else (None, self.num_filters, None, None) ) - self.built = True def call(self, input): diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index 2c79d5ea90..b50e87fc2e 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -1,15 +1,18 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid from keras_hub.src.utils.keras_utils import standardize_data_format @keras_hub_export("keras_hub.models.RetinaNetBackbone") -class RetinaNetBackbone(Backbone): +class RetinaNetBackbone(FeaturePyramidBackbone): """RetinaNet Backbone. + Combines a CNN backbone (e.g., ResNet, MobileNet) with a feature pyramid + network (FPN)to extract multi-scale features for object detection. 
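For orientation, the sketch below wires a `RetinaNetBackbone` around a small ResNet image encoder, mirroring the configuration used in `retinanet_backbone_test.py` in this series. It assumes the modules are importable as laid out in this patch and reflects the constructor as it stands at this point in the commit history; treat it as illustrative rather than part of the diff.

```python
import keras

from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone

# Image encoder configured as in retinanet_backbone_test.py.
image_encoder = ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[3, 4, 6, 3],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="bottleneck_block",
    use_pre_activation=False,
)

backbone = RetinaNetBackbone(
    image_encoder=image_encoder,
    min_level=3,
    max_level=7,
)

# The backbone maps an image batch to a dict of FPN features
# "P3"..."P7", each with 256 channels.
features = backbone(keras.ops.ones((2, 256, 256, 3)))
```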
+ Args: image_encoder (keras.Model): The backbone model used to extract features from the input image. @@ -66,22 +69,18 @@ def __init__( ) feature_pyramid = FeaturePyramid( - min_level=min_level, max_level=max_level, name="fpn", dtype=dtype + min_level=min_level, + max_level=max_level, + name="fpn", + dtype=dtype, + data_format=data_format, ) # === Functional model === image_input = keras.layers.Input(image_shape, name="inputs") - feature_extractor_outputs = feature_extractor(image_input) feature_pyramid_outputs = feature_pyramid(feature_extractor_outputs) - # === config === - self.min_level = min_level - self.max_level = max_level - self.image_encoder = image_encoder - self.feature_pyramid = feature_pyramid - self.image_shape = image_shape - super().__init__( inputs=image_input, outputs=feature_pyramid_outputs, @@ -89,6 +88,14 @@ def __init__( **kwargs, ) + # === config === + self.min_level = min_level + self.max_level = max_level + self.image_encoder = image_encoder + self.feature_pyramid = feature_pyramid + self.image_shape = image_shape + self.pyramid_outputs = feature_pyramid_outputs + def get_config(self): config = super().get_config() config.update( diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py index dedfad398d..176544c236 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone_test.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -11,10 +11,9 @@ def setUp(self): resnet_kwargs = { "input_conv_filters": [64], "input_conv_kernel_sizes": [7], - "stackwise_num_filters": [64, 64, 64], - "stackwise_num_blocks": [2, 2, 2], - "stackwise_num_strides": [1, 2, 2], - "image_shape": (None, None, 3), + "stackwise_num_filters": [64, 128, 256, 512], + "stackwise_num_blocks": [3, 4, 6, 3], + "stackwise_num_strides": [1, 2, 2, 2], "block_type": "bottleneck_block", "use_pre_activation": False, } @@ -23,13 +22,13 @@ def setUp(self): self.init_kwargs = { "image_encoder": image_encoder, "min_level": 3, - "max_level": 4, + "max_level": 7, } self.input_size = 256 self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) - def test_backbone_basics(self): + def test_backbone_basics_channels_first(self): self.run_vision_backbone_test( cls=RetinaNetBackbone, init_kwargs=self.init_kwargs, @@ -37,11 +36,21 @@ def test_backbone_basics(self): expected_output_shape={ "P3": (2, 32, 32, 256), "P4": (2, 16, 16, 256), + "P5": (2, 8, 8, 256), + "P6": (2, 4, 4, 256), + "P7": (2, 2, 2, 256), }, - expected_pyramid_output_keys=False, + expected_pyramid_output_keys=["P3", "P4", "P5", "P6", "P7"], + expected_pyramid_image_sizes=[ + (32, 32), + (16, 16), + (8, 8), + (4, 4), + (2, 2), + ], run_mixed_precision_check=False, - run_data_format_check=False, run_quantization_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 0b58acea01..5e5cd6b23b 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -26,8 +26,8 @@ class RetinaNetLabelEncoder(keras.layers.Layer): Args: anchor_generator: A `keras_hub.layers.AnchorGenerator`. - bounding_box_format: str. TODO: - https://github.com/keras-team/keras-hub/issues/1907 + bounding_box_format: str. Ground truth format of bounding boxes. + TODO: https://github.com/keras-team/keras-hub/issues/1907 positive_threshold: float. 
the threshold to set an anchor to positive match to gt box. Values above it are positive matches. Defaults to `0.5` diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py index f5097342e7..d05bf5a99a 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -10,7 +10,7 @@ class RetinaNetLabelEncoderTest(TestCase): def setUp(self): - self.anchor_generator = AnchorGenerator( + anchor_generator = AnchorGenerator( bounding_box_format="xyxy", min_level=3, max_level=7, @@ -18,6 +18,10 @@ def setUp(self): aspect_ratios=[0.5, 1.0, 2.0], anchor_size=8, ) + self.init_kwargs = { + "anchor_generator": anchor_generator, + "bounding_box_format": "xyxy", + } def test_layer_behaviors(self): images_shape = (8, 128, 128, 3) @@ -25,10 +29,7 @@ def test_layer_behaviors(self): classes_shape = (8, 10) self.run_layer_test( cls=RetinaNetLabelEncoder, - init_kwargs={ - "anchor_generator": self.anchor_generator, - "bounding_box_format": "xyxy", - }, + init_kwargs=self.init_kwargs, input_data={ "images": np.random.uniform(size=images_shape), "gt_boxes": np.random.uniform( @@ -55,8 +56,7 @@ def test_label_encoder_output_shapes(self): classes = np.random.uniform(size=classes_shape, low=0, high=5) encoder = RetinaNetLabelEncoder( - anchor_generator=self.anchor_generator, - bounding_box_format="xyxy", + **self.init_kwargs, ) box_targets, class_targets = encoder(images, boxes, classes) @@ -74,8 +74,7 @@ def test_all_negative_1(self): classes = -np.ones(shape=classes_shape, dtype="float32") encoder = RetinaNetLabelEncoder( - anchor_generator=self.anchor_generator, - bounding_box_format="xyxy", + **self.init_kwargs, ) box_targets, class_targets = encoder(images, boxes, classes) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 2eef9e84ec..6886c27a5f 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -1,3 +1,5 @@ +from typing import Union + import keras from keras import ops @@ -38,7 +40,7 @@ class RetinaNetObjectDetector(ImageObjectDetector): `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. activation: Optional. The activation function to be used in the classification head. - head_dtype: Optional. The data type for the prediction heads. + dtype: Optional. The data type for the prediction heads. prediction_decoder: Optional. A `keras.layers.Layer` that is responsible for transforming RetinaNet predictions into usable bounding box Tensors. 
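To make the relationship between these constructor arguments concrete, a minimal wiring sketch follows. It is based on the test setup that appears later in this patch and reflects the signature at this point in the series; the `backbone` is assumed to be a `RetinaNetBackbone` such as the one sketched earlier, import paths are inferred from the file layout in the patch, and the anchor settings are illustrative.

```python
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
    RetinaNetLabelEncoder,
)
from keras_hub.src.models.retinanet.retinanet_object_detector import (
    RetinaNetObjectDetector,
)
from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
    RetinaNetObjectDetectorPreprocessor,
)

# Anchor layout shared by the label encoder and the detector heads.
anchor_generator = AnchorGenerator(
    bounding_box_format="yxyx",
    min_level=3,
    max_level=4,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=8,
)
label_encoder = RetinaNetLabelEncoder(
    bounding_box_format="yxyx",
    anchor_generator=anchor_generator,
)
preprocessor = RetinaNetObjectDetectorPreprocessor(
    image_converter=ImageConverter(image_size=(256, 256)),
)

model = RetinaNetObjectDetector(
    backbone=backbone,  # a RetinaNetBackbone instance
    anchor_generator=anchor_generator,
    label_encoder=label_encoder,
    num_classes=10,
    bounding_box_format="yxyx",
    preprocessor=preprocessor,
)
# When `prediction_decoder` is not passed, a NonMaxSuppression decoder is
# created internally with the same `bounding_box_format`.
```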
@@ -57,12 +59,13 @@ def __init__( bounding_box_format, preprocessor=None, activation=None, - head_dtype=None, + dtype=None, prediction_decoder=None, **kwargs, ): # === Layers === - head_dtype = head_dtype or backbone.dtype_policy + image_input = keras.layers.Input(backbone.image_shape, name="images") + head_dtype = dtype or backbone.dtype_policy prior_probability = keras.initializers.Constant( -1 * keras.ops.log((1 - 0.01) / 0.01) ) @@ -73,7 +76,6 @@ def __init__( bias_initializer=prior_probability, dtype=head_dtype, ) - classification_head = PredictionHead( output_filters=anchor_generator.anchors_per_location * num_classes, num_conv_layers=4, @@ -82,8 +84,6 @@ def __init__( ) # === Functional Model === - image_input = keras.layers.Input(backbone.image_shape, name="images") - feature_map = backbone(image_input) cls_pred = [] @@ -107,6 +107,12 @@ def __init__( outputs = {"box": box_pred, "classification": cls_pred} + super().__init__( + inputs=image_input, + outputs=outputs, + **kwargs, + ) + # === Config === self.bounding_box_format = bounding_box_format self.num_classes = num_classes @@ -122,12 +128,6 @@ def __init__( bounding_box_format=bounding_box_format, ) - super().__init__( - inputs=image_input, - outputs=outputs, - **kwargs, - ) - def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): y_for_label_encoder = convert_format( y, @@ -199,15 +199,37 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): def predict_step(self, *args): outputs = super().predict_step(*args) - return self.decode_predictions(outputs, args[-1]) + if isinstance(outputs, tuple): + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + return self.decode_predictions(outputs, *args) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and `RetinaNet` to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) def decode_predictions(self, predictions, data): - if isinstance(data, tuple): - images = data[0] - else: - images = data box_pred, cls_pred = predictions["box"], predictions["classification"] # box_pred is on "center_yxhw" format, convert to target format. + if isinstance(data, Union[tuple, list]): + images, _ = data + else: + images = data image_shape = ops.shape(images)[1:] anchor_boxes = self.anchor_generator(images) anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) @@ -237,26 +259,6 @@ def decode_predictions(self, predictions, data): ) return y_pred - @property - def prediction_decoder(self): - return self._prediction_decoder - - @prediction_decoder.setter - def prediction_decoder(self, prediction_decoder): - if prediction_decoder.bounding_box_format != self.bounding_box_format: - raise ValueError( - "Expected `prediction_decoder` and RetinaNet to " - "use the same `bounding_box_format`, but got " - "`prediction_decoder.bounding_box_format=" - f"{prediction_decoder.bounding_box_format}`, and " - "`self.bounding_box_format=" - f"{self.bounding_box_format}`." 
- ) - self._prediction_decoder = prediction_decoder - self.make_predict_function(force=True) - self.make_train_function(force=True) - self.make_test_function(force=True) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 4f0600761a..f98dc9baba 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,5 +1,6 @@ import pytest from keras import ops +from keras import random from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @@ -68,12 +69,12 @@ def setUp(self): } self.input_size = 512 - self.images = ops.ones((1, self.input_size, self.input_size, 3)) + self.images = random.uniform((1, self.input_size, self.input_size, 3)) self.labels = { - "boxes": ops.convert_to_tensor( + "boxes": ops.convert_to_numpy( [[[20, 10, 120, 110], [30, 20, 130, 120]]] ), - "classes": ops.convert_to_tensor([[0, 2]]), + "classes": ops.convert_to_numpy([[0, 2]]), } self.train_data = (self.images, self.labels) From 3f915dc65bbd8cb554587bc7cb35b2dfa026ca97 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 12:57:13 -0700 Subject: [PATCH 10/35] fix decoding error --- keras_hub/src/models/retinanet/retinanet_object_detector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 6886c27a5f..3dc09cad4e 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -1,5 +1,3 @@ -from typing import Union - import keras from keras import ops @@ -226,7 +224,7 @@ def prediction_decoder(self, prediction_decoder): def decode_predictions(self, predictions, data): box_pred, cls_pred = predictions["box"], predictions["classification"] # box_pred is on "center_yxhw" format, convert to target format. - if isinstance(data, Union[tuple, list]): + if isinstance(data, list) or isinstance(data, tuple): images, _ = data else: images = data From f0da549811da30f34b6dd13b67eb8c53c59a81a1 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 14:17:00 -0700 Subject: [PATCH 11/35] - Add ground truth arg for RetinaNet model and remove source and target format from preprocessor --- .../image_object_detector_preprocessor.py | 43 ------------------- .../retinanet/retinanet_object_detector.py | 41 +++++++++++++----- .../retinanet_object_detector_test.py | 18 ++++---- 3 files changed, 40 insertions(+), 62 deletions(-) diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 72a3c23153..066aea9aae 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -1,7 +1,6 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.utils.tensor_utils import preprocessing_function @@ -35,15 +34,6 @@ class ImageObjectDetectorPreprocessor(Preprocessor): Args: image_converter: Preprocessing pipeline for images. - source_bounding_box_format: str. 
The format of the source bounding boxes. - supported formats include: - - `"rel_yxyx"` - - `"rel_xyxy"` - - `"rel_xywh" - Defaults to `"rel_yxyx"`. - target_bounding_box_format: str. TODO - https://github.com/keras-team/keras-hub/issues/1907 - Examples. ```python @@ -71,47 +61,14 @@ class ImageObjectDetectorPreprocessor(Preprocessor): def __init__( self, - target_bounding_box_format, - source_bounding_box_format="rel_yxyx", image_converter=None, **kwargs, ): super().__init__(**kwargs) - if "rel" not in source_bounding_box_format: - raise ValueError( - f"Only relative bounding box formats are supported " - f"Received source_bounding_box_format=" - f"`{source_bounding_box_format}`. " - f"Please provide a source bounding box format from one of " - f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. " - f"Ensure that the provided ground truth bounding boxes are " - f"normalized and relative to the image size. " - ) - self.source_bounding_box_format = source_bounding_box_format - self.target_bounding_box_format = target_bounding_box_format - self.image_converter = image_converter @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: x = self.image_converter(x) - if y is not None and keras.ops.is_tensor(y): - y = convert_format( - y, - source=self.source_bounding_box_format, - target=self.target_bounding_box_format, - images=x, - ) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "source_bounding_box_format": self.source_bounding_box_format, - "target_bounding_box_format": self.target_bounding_box_format, - } - ) - - return config diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 3dc09cad4e..0001436c87 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -32,8 +32,14 @@ class RetinaNetObjectDetector(ImageObjectDetector): `RetinaNetObjectDetector` training targets. anchor_generator: A `keras_Hub.layers.AnchorGenerator`. num_classes: The number of object classes to be detected. - bounding_box_format: The format of bounding boxes of input dataset. - TODO: https://github.com/keras-team/keras-hub/issues/1907 + ground_truth_bounding_box_format: Ground truth bounding box format. + Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 + Ensure that ground truth boxes follow one of the following formats. + - `rel_xyxy` + - `rel_yxyx` + - `rel_xywh` + target_bounding_box_format: Target bounding box format. + Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 preprocessor: Optional. An instance of the `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. activation: Optional. The activation function to be used in the @@ -54,13 +60,24 @@ def __init__( label_encoder, anchor_generator, num_classes, - bounding_box_format, + ground_truth_bounding_box_format, + target_bounding_box_format, preprocessor=None, activation=None, dtype=None, prediction_decoder=None, **kwargs, ): + if "rel" not in ground_truth_bounding_box_format: + raise ValueError( + f"Only relative bounding box formats are supported " + f"Received ground_truth_bounding_box_format=" + f"`{ground_truth_bounding_box_format}`. " + f"Please provide a `ground_truth_bounding_box_format` from one of " + f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. 
" + f"Ensure that the provided ground truth bounding boxes are " + f"normalized and relative to the image size. " + ) # === Layers === image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy @@ -112,7 +129,8 @@ def __init__( ) # === Config === - self.bounding_box_format = bounding_box_format + self.ground_truth_bounding_box_format = ground_truth_bounding_box_format + self.target_bounding_box_format = target_bounding_box_format self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor @@ -123,13 +141,13 @@ def __init__( self.classification_head = classification_head self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), - bounding_box_format=bounding_box_format, + bounding_box_format=self.target_bounding_box_format, ) def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): y_for_label_encoder = convert_format( y, - source=self.bounding_box_format, + source=self.ground_truth_bounding_box_format, target=self.label_encoder.bounding_box_format, images=x, ) @@ -235,14 +253,14 @@ def decode_predictions(self, predictions, data): anchors=anchor_boxes, boxes_delta=box_pred, anchor_format=self.anchor_generator.bounding_box_format, - box_format=self.bounding_box_format, + box_format=self.target_bounding_box_format, variance=BOX_VARIANCE, image_shape=image_shape, ) - # box_pred is now in "self.bounding_box_format" format + # box_pred is now in "self.target_bounding_box_format" format box_pred = convert_format( box_pred, - source=self.bounding_box_format, + source=self.target_bounding_box_format, target=self.prediction_decoder.bounding_box_format, image_shape=image_shape, ) @@ -252,7 +270,7 @@ def decode_predictions(self, predictions, data): y_pred["boxes"] = convert_format( y_pred["boxes"], source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, + target=self.target_bounding_box_format, image_shape=image_shape, ) return y_pred @@ -262,7 +280,8 @@ def get_config(self): config.update( { "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, + "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, + "target_bounding_box_format": self.target_bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator ), diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index f98dc9baba..be96a4b84d 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,6 +1,5 @@ +import numpy as np import pytest -from keras import ops -from keras import random from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @@ -56,7 +55,7 @@ def setUp(self): ) preprocessor = RetinaNetObjectDetectorPreprocessor( - image_converter=image_converter, target_bounding_box_format="xyxy" + image_converter=image_converter ) self.init_kwargs = { @@ -64,17 +63,20 @@ def setUp(self): "anchor_generator": anchor_generator, "label_encoder": label_encoder, "num_classes": 10, - "bounding_box_format": "yxyx", + "ground_truth_bounding_box_format": "rel_yxyx", + "target_bounding_box_format": "xywh", "preprocessor": preprocessor, } self.input_size = 512 - self.images = random.uniform((1, self.input_size, self.input_size, 
3)) + self.images = np.random.uniform( + low=0, high=255, size=(1, self.input_size, self.input_size, 3) + ) self.labels = { - "boxes": ops.convert_to_numpy( - [[[20, 10, 120, 110], [30, 20, 130, 120]]] + "boxes": np.array( + [[[0.2, 0.0, 0.12, 0.11], [0.3, 0.2, 0.4, 0.12]]] ), - "classes": ops.convert_to_numpy([[0, 2]]), + "classes": np.array([[0, 2]]), } self.train_data = (self.images, self.labels) From 05fdefef7a9ac49e59b4a39e38d0836fc6d1bf5c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 20:29:23 -0700 Subject: [PATCH 12/35] nit --- keras_hub/src/models/image_object_detector_preprocessor.py | 2 +- keras_hub/src/models/retinanet/retinanet_object_detector.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 066aea9aae..a4eab8389c 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -65,10 +65,10 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.image_converter = image_converter @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: x = self.image_converter(x) - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 0001436c87..d6a30c670d 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -90,12 +90,14 @@ def __init__( num_filters=256, bias_initializer=prior_probability, dtype=head_dtype, + name="box_head", ) classification_head = PredictionHead( output_filters=anchor_generator.anchors_per_location * num_classes, num_conv_layers=4, num_filters=256, dtype=head_dtype, + name="classification_head", ) # === Functional Model === From 3b26d3ade1fd07d157ef8730caa06f044b291e5f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 21:42:31 -0700 Subject: [PATCH 13/35] Subclass Imageconverter and overload call method for object detection method --- .../image_object_detector_preprocessor.py | 2 +- .../retinanet/retinanet_image_converter.py | 55 +++++++++++++++++++ .../retinanet/retinanet_object_detector.py | 39 ++++--------- .../retinanet_object_detector_test.py | 11 ++-- 4 files changed, 73 insertions(+), 34 deletions(-) diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index a4eab8389c..35d74741f9 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -70,5 +70,5 @@ def __init__( @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: - x = self.image_converter(x) + x, y = self.image_converter(x, y) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py index b37091fd6f..0da56a5c62 100644 --- a/keras_hub/src/models/retinanet/retinanet_image_converter.py +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -1,8 +1,63 @@ +from keras import ops + from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.bounding_box.converters import convert_format from 
keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.utils.keras_utils import standardize_data_format +from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.RetinaNetImageConverter") class RetinaNetImageConverter(ImageConverter): backbone_cls = RetinaNetBackbone + + def __init__( + self, + ground_truth_bounding_box_format, + target_bounding_box_format, + image_size=None, + scale=None, + offset=None, + crop_to_aspect_ratio=True, + interpolation="bilinear", + data_format=None, + **kwargs + ): + super().__init__(**kwargs) + self.ground_truth_bounding_box_format = ground_truth_bounding_box_format + self.target_bounding_box_format = target_bounding_box_format + self.image_size = image_size + self.scale = scale + self.offset = offset + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.interpolation = interpolation + self.data_format = standardize_data_format(data_format) + + @preprocessing_function + def call(self, x, y=None, sample_weight=None, **kwargs): + if self.image_size is not None: + x = self.resizing(x) + if self.offset is not None: + x -= self._expand_non_channel_dims(self.offset, x) + if self.scale is not None: + x /= self._expand_non_channel_dims(self.scale, x) + if y is not None and ops.is_tensor(y): + y = convert_format( + y, + source=self.ground_truth_bounding_box_format, + target=self.target_bounding_box_format, + images=x, + ) + return x, y + + def get_config(self): + config = super().get_config() + config.update( + { + "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, + "target_bounding_box_format": self.target_bounding_box_format, + } + ) + + return config diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index d6a30c670d..822bc69123 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -32,13 +32,7 @@ class RetinaNetObjectDetector(ImageObjectDetector): `RetinaNetObjectDetector` training targets. anchor_generator: A `keras_Hub.layers.AnchorGenerator`. num_classes: The number of object classes to be detected. - ground_truth_bounding_box_format: Ground truth bounding box format. - Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 - Ensure that ground truth boxes follow one of the following formats. - - `rel_xyxy` - - `rel_yxyx` - - `rel_xywh` - target_bounding_box_format: Target bounding box format. + bounding_box_format: The format of bounding boxes of input dataset. Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 preprocessor: Optional. An instance of the `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. @@ -60,24 +54,13 @@ def __init__( label_encoder, anchor_generator, num_classes, - ground_truth_bounding_box_format, - target_bounding_box_format, + bounding_box_format, preprocessor=None, activation=None, dtype=None, prediction_decoder=None, **kwargs, ): - if "rel" not in ground_truth_bounding_box_format: - raise ValueError( - f"Only relative bounding box formats are supported " - f"Received ground_truth_bounding_box_format=" - f"`{ground_truth_bounding_box_format}`. " - f"Please provide a `ground_truth_bounding_box_format` from one of " - f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. 
" - f"Ensure that the provided ground truth bounding boxes are " - f"normalized and relative to the image size. " - ) # === Layers === image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy @@ -131,8 +114,7 @@ def __init__( ) # === Config === - self.ground_truth_bounding_box_format = ground_truth_bounding_box_format - self.target_bounding_box_format = target_bounding_box_format + self.bounding_box_format = bounding_box_format self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor @@ -143,13 +125,13 @@ def __init__( self.classification_head = classification_head self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), - bounding_box_format=self.target_bounding_box_format, + bounding_box_format=self.bounding_box_format, ) def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): y_for_label_encoder = convert_format( y, - source=self.ground_truth_bounding_box_format, + source=self.bounding_box_format, target=self.label_encoder.bounding_box_format, images=x, ) @@ -255,14 +237,14 @@ def decode_predictions(self, predictions, data): anchors=anchor_boxes, boxes_delta=box_pred, anchor_format=self.anchor_generator.bounding_box_format, - box_format=self.target_bounding_box_format, + box_format=self.bounding_box_format, variance=BOX_VARIANCE, image_shape=image_shape, ) - # box_pred is now in "self.target_bounding_box_format" format + # box_pred is now in "self.bounding_box_format" format box_pred = convert_format( box_pred, - source=self.target_bounding_box_format, + source=self.bounding_box_format, target=self.prediction_decoder.bounding_box_format, image_shape=image_shape, ) @@ -272,7 +254,7 @@ def decode_predictions(self, predictions, data): y_pred["boxes"] = convert_format( y_pred["boxes"], source=self.prediction_decoder.bounding_box_format, - target=self.target_bounding_box_format, + target=self.bounding_box_format, image_shape=image_shape, ) return y_pred @@ -282,8 +264,7 @@ def get_config(self): config.update( { "num_classes": self.num_classes, - "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, - "target_bounding_box_format": self.target_bounding_box_format, + "bounding_box_format": self.bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator ), diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index be96a4b84d..8ad1fe3a27 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,10 +1,12 @@ import numpy as np import pytest -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) @@ -50,8 +52,10 @@ def setUp(self): bounding_box_format="yxyx", anchor_generator=anchor_generator ) - image_converter = ImageConverter( + image_converter = RetinaNetImageConverter( image_size=(256, 256), + ground_truth_bounding_box_format="rel_yxyx", + 
target_bounding_box_format="yxyx", ) preprocessor = RetinaNetObjectDetectorPreprocessor( @@ -63,8 +67,7 @@ def setUp(self): "anchor_generator": anchor_generator, "label_encoder": label_encoder, "num_classes": 10, - "ground_truth_bounding_box_format": "rel_yxyx", - "target_bounding_box_format": "xywh", + "bounding_box_format": "yxyx", "preprocessor": preprocessor, } From 0df121a1be428c65facf9af3ef30d3a750e110a2 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 22:50:13 -0700 Subject: [PATCH 14/35] Revert "Subclass Imageconverter and overload call method for object detection method" This reverts commit 3b26d3ade1fd07d157ef8730caa06f044b291e5f. --- .../image_object_detector_preprocessor.py | 2 +- .../retinanet/retinanet_image_converter.py | 55 ------------------- .../retinanet/retinanet_object_detector.py | 39 +++++++++---- .../retinanet_object_detector_test.py | 11 ++-- 4 files changed, 34 insertions(+), 73 deletions(-) diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 35d74741f9..a4eab8389c 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -70,5 +70,5 @@ def __init__( @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: - x, y = self.image_converter(x, y) + x = self.image_converter(x) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py index 0da56a5c62..b37091fd6f 100644 --- a/keras_hub/src/models/retinanet/retinanet_image_converter.py +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -1,63 +1,8 @@ -from keras import ops - from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.utils.keras_utils import standardize_data_format -from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.RetinaNetImageConverter") class RetinaNetImageConverter(ImageConverter): backbone_cls = RetinaNetBackbone - - def __init__( - self, - ground_truth_bounding_box_format, - target_bounding_box_format, - image_size=None, - scale=None, - offset=None, - crop_to_aspect_ratio=True, - interpolation="bilinear", - data_format=None, - **kwargs - ): - super().__init__(**kwargs) - self.ground_truth_bounding_box_format = ground_truth_bounding_box_format - self.target_bounding_box_format = target_bounding_box_format - self.image_size = image_size - self.scale = scale - self.offset = offset - self.crop_to_aspect_ratio = crop_to_aspect_ratio - self.interpolation = interpolation - self.data_format = standardize_data_format(data_format) - - @preprocessing_function - def call(self, x, y=None, sample_weight=None, **kwargs): - if self.image_size is not None: - x = self.resizing(x) - if self.offset is not None: - x -= self._expand_non_channel_dims(self.offset, x) - if self.scale is not None: - x /= self._expand_non_channel_dims(self.scale, x) - if y is not None and ops.is_tensor(y): - y = convert_format( - y, - source=self.ground_truth_bounding_box_format, - target=self.target_bounding_box_format, - images=x, - ) - return x, y - - def get_config(self): - 
config = super().get_config() - config.update( - { - "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, - "target_bounding_box_format": self.target_bounding_box_format, - } - ) - - return config diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 822bc69123..d6a30c670d 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -32,7 +32,13 @@ class RetinaNetObjectDetector(ImageObjectDetector): `RetinaNetObjectDetector` training targets. anchor_generator: A `keras_Hub.layers.AnchorGenerator`. num_classes: The number of object classes to be detected. - bounding_box_format: The format of bounding boxes of input dataset. + ground_truth_bounding_box_format: Ground truth bounding box format. + Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 + Ensure that ground truth boxes follow one of the following formats. + - `rel_xyxy` + - `rel_yxyx` + - `rel_xywh` + target_bounding_box_format: Target bounding box format. Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 preprocessor: Optional. An instance of the `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. @@ -54,13 +60,24 @@ def __init__( label_encoder, anchor_generator, num_classes, - bounding_box_format, + ground_truth_bounding_box_format, + target_bounding_box_format, preprocessor=None, activation=None, dtype=None, prediction_decoder=None, **kwargs, ): + if "rel" not in ground_truth_bounding_box_format: + raise ValueError( + f"Only relative bounding box formats are supported " + f"Received ground_truth_bounding_box_format=" + f"`{ground_truth_bounding_box_format}`. " + f"Please provide a `ground_truth_bounding_box_format` from one of " + f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. " + f"Ensure that the provided ground truth bounding boxes are " + f"normalized and relative to the image size. 
" + ) # === Layers === image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy @@ -114,7 +131,8 @@ def __init__( ) # === Config === - self.bounding_box_format = bounding_box_format + self.ground_truth_bounding_box_format = ground_truth_bounding_box_format + self.target_bounding_box_format = target_bounding_box_format self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor @@ -125,13 +143,13 @@ def __init__( self.classification_head = classification_head self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), - bounding_box_format=self.bounding_box_format, + bounding_box_format=self.target_bounding_box_format, ) def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): y_for_label_encoder = convert_format( y, - source=self.bounding_box_format, + source=self.ground_truth_bounding_box_format, target=self.label_encoder.bounding_box_format, images=x, ) @@ -237,14 +255,14 @@ def decode_predictions(self, predictions, data): anchors=anchor_boxes, boxes_delta=box_pred, anchor_format=self.anchor_generator.bounding_box_format, - box_format=self.bounding_box_format, + box_format=self.target_bounding_box_format, variance=BOX_VARIANCE, image_shape=image_shape, ) - # box_pred is now in "self.bounding_box_format" format + # box_pred is now in "self.target_bounding_box_format" format box_pred = convert_format( box_pred, - source=self.bounding_box_format, + source=self.target_bounding_box_format, target=self.prediction_decoder.bounding_box_format, image_shape=image_shape, ) @@ -254,7 +272,7 @@ def decode_predictions(self, predictions, data): y_pred["boxes"] = convert_format( y_pred["boxes"], source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, + target=self.target_bounding_box_format, image_shape=image_shape, ) return y_pred @@ -264,7 +282,8 @@ def get_config(self): config.update( { "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, + "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, + "target_bounding_box_format": self.target_bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator ), diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 8ad1fe3a27..be96a4b84d 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,12 +1,10 @@ import numpy as np import pytest +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter, -) from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) @@ -52,10 +50,8 @@ def setUp(self): bounding_box_format="yxyx", anchor_generator=anchor_generator ) - image_converter = RetinaNetImageConverter( + image_converter = ImageConverter( image_size=(256, 256), - ground_truth_bounding_box_format="rel_yxyx", - target_bounding_box_format="yxyx", ) preprocessor = RetinaNetObjectDetectorPreprocessor( @@ -67,7 +63,8 @@ def setUp(self): 
"anchor_generator": anchor_generator, "label_encoder": label_encoder, "num_classes": 10, - "bounding_box_format": "yxyx", + "ground_truth_bounding_box_format": "rel_yxyx", + "target_bounding_box_format": "xywh", "preprocessor": preprocessor, } From 8697240628afabf826c233180f7024761d1a3594 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 8 Oct 2024 23:13:25 -0700 Subject: [PATCH 15/35] add names to layers --- .../src/models/retinanet/retinanet_object_detector.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index d6a30c670d..34fe13f7d8 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -107,12 +107,14 @@ def __init__( box_pred = [] for level in feature_map: box_pred.append( - keras.layers.Reshape((-1, 4))(box_head(feature_map[level])) + keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")( + box_head(feature_map[level]) + ) ) cls_pred.append( - keras.layers.Reshape((-1, num_classes))( - classification_head(feature_map[level]) - ) + keras.layers.Reshape( + (-1, num_classes), name=f"cls_pred_{level}" + )(classification_head(feature_map[level])) ) cls_pred = keras.layers.Concatenate(axis=1, name="classification")( From 394faf0c26ee42ecd4bd1c356c79bd03099ef439 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 00:27:30 -0700 Subject: [PATCH 16/35] correct fpn coarser level as per torch retinanet model --- .../src/models/retinanet/feature_pyramid.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 2322063eea..efc18895e7 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -198,11 +198,18 @@ def build(self, input_shapes): dtype=self.dtype_policy, name=f"coarser_{level}", ) - self.output_conv_layers[level].build( - (None, None, None, self.num_filters) - if self.data_format == "channels_last" - else (None, self.num_filters, None, None) - ) + if i == backbone_max_level + 1: + self.output_conv_layers[level].build( + (None, None, None, input_shapes[f"P{i-1}"][-1]) + if self.data_format == "channels_last" + else (None, input_shapes[f"P{i-1}"][-1], None, None) + ) + else: + self.output_conv_layers[level].build( + (None, None, None, self.num_filters) + if self.data_format == "channels_last" + else (None, self.num_filters, None, None) + ) # Build batch norm layers self.output_batch_norms = {} @@ -279,7 +286,12 @@ def call(self, inputs): for i in range(backbone_max_level + 1, self.max_level + 1): level = f"P{i}" - feats_in = output_features[f"P{i-1}"] + feats_in = ( + inputs[f"P{i-1}"] + if i == backbone_max_level + 1 + else output_features[f"P{i-1}"] + ) + print(feats_in.shape) if i > backbone_max_level + 1: feats_in = self.activation(feats_in) output_features[level] = ( From 33d81e9baaeba3ad5965077f796fdc53d6dd0ae7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 00:32:48 -0700 Subject: [PATCH 17/35] nit --- keras_hub/src/models/retinanet/feature_pyramid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index efc18895e7..4b66f70a40 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ 
b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -291,7 +291,6 @@ def call(self, inputs): if i == backbone_max_level + 1 else output_features[f"P{i-1}"] ) - print(feats_in.shape) if i > backbone_max_level + 1: feats_in = self.activation(feats_in) output_features[level] = ( From 79502d996441de7a367a688bef6a3e920372b40b Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 12:48:59 -0700 Subject: [PATCH 18/35] Polish Prediction head and fpn layers to include flags and norm layers --- .../src/models/retinanet/feature_pyramid.py | 11 ++++- .../models/retinanet/feature_pyramid_test.py | 13 +++++- .../src/models/retinanet/prediction_head.py | 44 +++++++++++++------ .../models/retinanet/prediction_head_test.py | 12 ++++- .../models/retinanet/retinanet_backbone.py | 24 +++++++--- .../retinanet/retinanet_backbone_test.py | 1 + .../retinanet/retinanet_object_detector.py | 7 +++ .../retinanet_object_detector_test.py | 1 + 8 files changed, 88 insertions(+), 25 deletions(-) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 4b66f70a40..3c4aaa760c 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -39,6 +39,10 @@ class FeaturePyramid(keras.layers.Layer): Args: min_level: int. The minimum level of the feature pyramid. max_level: int. The maximum level of the feature pyramid. + use_p5: bool. If True, uses the output of the last layer (`P5` from + Feature Pyramid Network) as input for creating coarser convolution + layers (`P6`, `P7`). If False, uses the direct input `P5` + for creating coarser convolution layers. num_filters: int. The number of filters in each feature map. activation: string or `keras.activations`. The activation function to be used in network. 
@@ -71,6 +75,7 @@ def __init__( self, min_level, max_level, + use_p5, num_filters=256, activation="relu", kernel_initializer="VarianceScaling", @@ -92,6 +97,7 @@ def __init__( self.min_level = min_level self.max_level = max_level self.num_filters = num_filters + self.use_p5 = use_p5 self.activation = keras.activations.get(activation) self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -198,7 +204,7 @@ def build(self, input_shapes): dtype=self.dtype_policy, name=f"coarser_{level}", ) - if i == backbone_max_level + 1: + if i == backbone_max_level + 1 and self.use_p5: self.output_conv_layers[level].build( (None, None, None, input_shapes[f"P{i-1}"][-1]) if self.data_format == "channels_last" @@ -288,7 +294,7 @@ def call(self, inputs): level = f"P{i}" feats_in = ( inputs[f"P{i-1}"] - if i == backbone_max_level + 1 + if i == backbone_max_level + 1 and self.use_p5 else output_features[f"P{i-1}"] ) if i > backbone_max_level + 1: @@ -313,6 +319,7 @@ def get_config(self): "min_level": self.min_level, "max_level": self.max_level, "num_filters": self.num_filters, + "use_p5": self.use_p5, "use_batch_norm": self.use_batch_norm, "data_format": self.data_format, "activation": keras.activations.serialize(self.activation), diff --git a/keras_hub/src/models/retinanet/feature_pyramid_test.py b/keras_hub/src/models/retinanet/feature_pyramid_test.py index 728233c6ae..b9b62e62ac 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid_test.py +++ b/keras_hub/src/models/retinanet/feature_pyramid_test.py @@ -18,6 +18,7 @@ def test_layer_behaviors(self): "batch_norm_epsilon": 0.0001, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "use_p5": False, }, input_data={ "P3": random.uniform(shape=(2, 64, 64, 4)), @@ -40,12 +41,14 @@ def test_layer_behaviors(self): "equal_resolutions", 3, 7, + False, {"P3": (2, 16, 16, 3), "P4": (2, 8, 8, 3), "P5": (2, 4, 4, 3)}, ), ( "different_resolutions", 2, 6, + True, { "P2": (2, 64, 128, 4), "P3": (2, 32, 64, 8), @@ -54,8 +57,14 @@ def test_layer_behaviors(self): }, ), ) - def test_layer_output_shapes(self, min_level, max_level, input_shapes): - layer = FeaturePyramid(min_level=min_level, max_level=max_level) + def test_layer_output_shapes( + self, min_level, max_level, use_p5, input_shapes + ): + layer = FeaturePyramid( + min_level=min_level, + max_level=max_level, + use_p5=use_p5, + ) inputs = { level: ops.ones(input_shapes[level]) for level in input_shapes diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py index 95395e864b..62446bf7ba 100644 --- a/keras_hub/src/models/retinanet/prediction_head.py +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -1,11 +1,11 @@ import keras +from keras_hub.src.utils.keras_utils import standardize_data_format + class PredictionHead(keras.layers.Layer): """The classification/box predictions head. - - Args: output_filters: int. Number of convolution filters in the final layer. The number of output channels determines the prediction type: @@ -28,6 +28,8 @@ class PredictionHead(keras.layers.Layer): regularizer for the convolution layers. Defaults to `None`. bias_regularizer: `str` or `keras.regularizers`. The bias regularizer for the convolution layers. Defaults to `None`. + use_group_norm: bool. Whether to use Group Normalization after + the convolution layers. Defaults to `False`. 
Returns: A function representing either the classification @@ -44,6 +46,8 @@ def __init__( bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None, + use_group_norm=False, + data_format=None, **kwargs, ): super().__init__(**kwargs) @@ -62,33 +66,42 @@ def __init__( self.bias_regularizer = keras.regularizers.get(bias_regularizer) else: self.bias_regularizer = None - - self.data_format = keras.backend.image_data_format() + self.use_group_norm = use_group_norm + self.data_format = standardize_data_format(data_format) def build(self, input_shape): - self.conv_layers = [ - keras.layers.Conv2D( + intermediate_shape = input_shape + self.conv_layers = [] + self.group_norm_layers = [] + for _ in range(self.num_conv_layers): + conv = keras.layers.Conv2D( self.num_filters, kernel_size=3, padding="same", kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer, + use_bias=not self.use_group_norm, kernel_regularizer=self.kernel_regularizer, bias_regularizer=self.bias_regularizer, - activation=self.activation, data_format=self.data_format, dtype=self.dtype_policy, ) - for _ in range(self.num_conv_layers) - ] - intermediate_shape = input_shape - for conv in self.conv_layers: conv.build(intermediate_shape) + self.conv_layers.append(conv) intermediate_shape = ( input_shape[:-1] + (self.num_filters,) if self.data_format == "channels_last" else (input_shape[0], self.num_filters) + (input_shape[1:-1]) ) + if self.use_group_norm: + group_norm = keras.layers.GroupNormalization( + groups=32, + axis=-1 if self.data_format == "channels_last" else 1, + dtype=self.dtype_policy, + ) + group_norm.build(intermediate_shape) + self.group_norm_layers.append(group_norm) + self.prediction_layer = keras.layers.Conv2D( self.output_filters, kernel_size=3, @@ -109,8 +122,12 @@ def build(self, input_shape): def call(self, input): x = input - for conv in self.conv_layers: - x = conv(x) + for idx in range(self.num_conv_layers): + x = self.conv_layers[idx](x) + if self.use_group_norm: + x = self.group_norm_layers[idx](x) + x = self.activation(x) + output = self.prediction_layer(x) return output @@ -121,6 +138,7 @@ def get_config(self): "output_filters": self.output_filters, "num_filters": self.num_filters, "num_conv_layers": self.num_conv_layers, + "use_group_norm": self.use_group_norm, "activation": keras.activations.serialize(self.activation), "kernel_initializer": keras.initializers.serialize( self.kernel_initializer diff --git a/keras_hub/src/models/retinanet/prediction_head_test.py b/keras_hub/src/models/retinanet/prediction_head_test.py index 7416565d12..004cea0ec2 100644 --- a/keras_hub/src/models/retinanet/prediction_head_test.py +++ b/keras_hub/src/models/retinanet/prediction_head_test.py @@ -1,3 +1,4 @@ +from absl.testing import parameterized from keras import random from keras_hub.src.models.retinanet.prediction_head import PredictionHead @@ -5,15 +6,22 @@ class FeaturePyramidTest(TestCase): - def test_layer_behaviors(self): + @parameterized.named_parameters( + ("without_group_normalization", False, 10), + ("with_group_normalization", True, 14), + ) + def test_layer_behaviors( + self, use_group_norm, expected_num_trainable_weights + ): self.run_layer_test( cls=PredictionHead, init_kwargs={ "output_filters": 9 * 4, # anchors_per_location * box length(4) "num_filters": 256, "num_conv_layers": 4, + "use_group_norm": use_group_norm, }, input_data=random.uniform(shape=(2, 64, 64, 256)), expected_output_shape=(2, 64, 64, 36), - expected_num_trainable_weights=10, + 
expected_num_trainable_weights=expected_num_trainable_weights, ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index b50e87fc2e..72888c5099 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -14,14 +14,18 @@ class RetinaNetBackbone(FeaturePyramidBackbone): network (FPN)to extract multi-scale features for object detection. Args: - image_encoder (keras.Model): The backbone model used to extract features + image_encoder: `keras.Model`. The backbone model used to extract features from the input image. It should have pyramid outputs. - min_level (int): The minimum feature pyramid level. - max_level (int): The maximum feature pyramid level. - image_shape (tuple): The shape of the input image. - data_format (str): The data format of the input image (channels_first or channels_last). - dtype (str): The data type of the input image. + min_level: int. The minimum feature pyramid level. + max_level: int. The maximum feature pyramid level. + use_p5: bool. If True, uses the output of the last layer (`P5` from + Feature Pyramid Network) as input for creating coarser convolution + layers (`P6`, `P7`). If False, uses the direct input `P5` + for creating coarser convolution layers. + image_shape: tuple. The shape of the input image. + data_format: str. The data format of the input image (channels_first or channels_last). + dtype: str. The data type of the input image. **kwargs: Additional arguments passed to the base class. Raises: @@ -34,6 +38,8 @@ def __init__( image_encoder, min_level, max_level, + use_p5, + use_fpn_batch_norm=False, image_shape=(None, None, 3), data_format=None, dtype=None, @@ -71,9 +77,11 @@ def __init__( feature_pyramid = FeaturePyramid( min_level=min_level, max_level=max_level, + use_p5=use_p5, name="fpn", dtype=dtype, data_format=data_format, + use_batch_norm=use_fpn_batch_norm, ) # === Functional model === @@ -91,6 +99,8 @@ def __init__( # === config === self.min_level = min_level self.max_level = max_level + self.use_p5 = use_p5 + self.use_fpn_batch_norm = use_fpn_batch_norm self.image_encoder = image_encoder self.feature_pyramid = feature_pyramid self.image_shape = image_shape @@ -103,6 +113,8 @@ def get_config(self): "image_encoder": keras.layers.serialize(self.image_encoder), "min_level": self.min_level, "max_level": self.max_level, + "use_p5": self.use_p5, + "use_fpn_batch_norm": self.use_fpn_batch_norm, "image_shape": self.image_shape, } ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py index 176544c236..524374447b 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone_test.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -23,6 +23,7 @@ def setUp(self): "image_encoder": image_encoder, "min_level": 3, "max_level": 7, + "use_p5": True, } self.input_size = 256 diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 34fe13f7d8..f667e9e484 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -40,6 +40,8 @@ class RetinaNetObjectDetector(ImageObjectDetector): - `rel_xywh` target_bounding_box_format: Target bounding box format. Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 + use_prediction_head_norm: bool. 
Whether to use Group Normalization after + the convolution layers in prediction head. Defaults to `False`. preprocessor: Optional. An instance of the `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. activation: Optional. The activation function to be used in the @@ -62,6 +64,7 @@ def __init__( num_classes, ground_truth_bounding_box_format, target_bounding_box_format, + use_prediction_head_norm=False, preprocessor=None, activation=None, dtype=None, @@ -89,6 +92,7 @@ def __init__( num_conv_layers=4, num_filters=256, bias_initializer=prior_probability, + use_group_norm=use_prediction_head_norm, dtype=head_dtype, name="box_head", ) @@ -96,6 +100,7 @@ def __init__( output_filters=anchor_generator.anchors_per_location * num_classes, num_conv_layers=4, num_filters=256, + use_group_norm=use_prediction_head_norm, dtype=head_dtype, name="classification_head", ) @@ -135,6 +140,7 @@ def __init__( # === Config === self.ground_truth_bounding_box_format = ground_truth_bounding_box_format self.target_bounding_box_format = target_bounding_box_format + self.use_prediction_head_norm = use_prediction_head_norm self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor @@ -285,6 +291,7 @@ def get_config(self): { "num_classes": self.num_classes, "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, + "use_prediction_head_norm": self.use_prediction_head_norm, "target_bounding_box_format": self.target_bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index be96a4b84d..b321f6eef7 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -35,6 +35,7 @@ def setUp(self): "image_encoder": image_encoder, "min_level": 3, "max_level": 4, + "use_p5": False, } feature_extractor = RetinaNetBackbone(**retinanet_backbone_kwargs) From 72a02c42087882d2e9b3a7b5c047af2950172c01 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 12:57:58 -0700 Subject: [PATCH 19/35] nit --- keras_hub/src/models/retinanet/prediction_head.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py index 62446bf7ba..20c092c24e 100644 --- a/keras_hub/src/models/retinanet/prediction_head.py +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -73,7 +73,7 @@ def build(self, input_shape): intermediate_shape = input_shape self.conv_layers = [] self.group_norm_layers = [] - for _ in range(self.num_conv_layers): + for idx in range(self.num_conv_layers): conv = keras.layers.Conv2D( self.num_filters, kernel_size=3, @@ -85,6 +85,7 @@ def build(self, input_shape): bias_regularizer=self.bias_regularizer, data_format=self.data_format, dtype=self.dtype_policy, + name=f"conv2d_{idx}", ) conv.build(intermediate_shape) self.conv_layers.append(conv) @@ -98,6 +99,7 @@ def build(self, input_shape): groups=32, axis=-1 if self.data_format == "channels_last" else 1, dtype=self.dtype_policy, + name=f"group_norm_{idx}", ) group_norm.build(intermediate_shape) self.group_norm_layers.append(group_norm) @@ -112,6 +114,7 @@ def build(self, input_shape): kernel_regularizer=self.kernel_regularizer, bias_regularizer=self.bias_regularizer, dtype=self.dtype_policy, + name="logits_layer", ) 
self.prediction_layer.build( (None, None, None, self.num_filters) From a28a0332cbc1a9a4d0da324caa5d89767271c31c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 13:52:20 -0700 Subject: [PATCH 20/35] nit --- keras_hub/src/models/retinanet/prediction_head_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/prediction_head_test.py b/keras_hub/src/models/retinanet/prediction_head_test.py index 004cea0ec2..111c92ee7a 100644 --- a/keras_hub/src/models/retinanet/prediction_head_test.py +++ b/keras_hub/src/models/retinanet/prediction_head_test.py @@ -5,7 +5,7 @@ from keras_hub.src.tests.test_case import TestCase -class FeaturePyramidTest(TestCase): +class PredictionHeadTest(TestCase): @parameterized.named_parameters( ("without_group_normalization", False, 10), ("with_group_normalization", True, 14), From 50686e032319be582a5683a0f2431a1f46d65c53 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 14:55:38 -0700 Subject: [PATCH 21/35] add prior probability flag for prediction head to use it for classification head and user friendly --- .../src/models/retinanet/prediction_head.py | 25 +++++++++++++++++-- .../retinanet/retinanet_object_detector.py | 8 +++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py index 20c092c24e..1581d901cb 100644 --- a/keras_hub/src/models/retinanet/prediction_head.py +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -19,6 +19,12 @@ class PredictionHead(keras.layers.Layer): Defaults to `256`. num_conv_layers: int. Number of convolution layers before final layer. Defaults to `4`. + use_prior_probability: bool. Whether to use prior probability in the + bias initializer for the final convolution layer. Defaults to + `False`. + prior_probability: float. The prior probability value to use for + initializing the bias. Only used if `use_prior_probability` is + `True`. Defaults to `0.01`. kernel_initializer: `str` or `keras.initializers`. The kernel initializer for the convolution layers. Defaults to `"random_normal"`. 
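Reviewer note: the prior-probability option documented above amounts to seeding the bias of the final conv layer so that its initial sigmoid output equals a small foreground prior (the focal-loss initialization trick). A minimal sketch of that computation, using the documented default of 0.01 as an assumed prior rather than the layer itself:

```python
import math

import keras

# Assumed prior matching the documented default.
prior_probability = 0.01

# Bias value that makes sigmoid(bias) == prior_probability at initialization.
bias_value = -math.log((1.0 - prior_probability) / prior_probability)
bias_initializer = keras.initializers.Constant(bias_value)

# Sanity check: the sigmoid of the bias recovers the prior.
assert abs(1.0 / (1.0 + math.exp(-bias_value)) - prior_probability) < 1e-6
```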
@@ -41,6 +47,8 @@ def __init__( output_filters, num_filters, num_conv_layers, + use_prior_probability=False, + prior_probability=0.01, activation="relu", kernel_initializer="random_normal", bias_initializer="zeros", @@ -55,6 +63,8 @@ def __init__( self.output_filters = output_filters self.num_filters = num_filters self.num_conv_layers = num_conv_layers + self.use_prior_probability = use_prior_probability + self.prior_probability = prior_probability self.activation = keras.activations.get(activation) self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -103,14 +113,23 @@ def build(self, input_shape): ) group_norm.build(intermediate_shape) self.group_norm_layers.append(group_norm) - + prior_probability = keras.initializers.Constant( + -1 + * keras.ops.log( + (1 - self.prior_probability) / self.prior_probability + ) + ) self.prediction_layer = keras.layers.Conv2D( self.output_filters, kernel_size=3, strides=1, padding="same", kernel_initializer=self.kernel_initializer, - bias_initializer=self.bias_initializer, + bias_initializer=( + prior_probability + if self.use_prior_probability + else self.bias_initializer + ), kernel_regularizer=self.kernel_regularizer, bias_regularizer=self.bias_regularizer, dtype=self.dtype_policy, @@ -142,6 +161,8 @@ def get_config(self): "num_filters": self.num_filters, "num_conv_layers": self.num_conv_layers, "use_group_norm": self.use_group_norm, + "use_prior_probability": self.use_prior_probability, + "prior_probability": self.prior_probability, "activation": keras.activations.serialize(self.activation), "kernel_initializer": keras.initializers.serialize( self.kernel_initializer diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index f667e9e484..5bf0690487 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -65,6 +65,7 @@ def __init__( ground_truth_bounding_box_format, target_bounding_box_format, use_prediction_head_norm=False, + classification_head_prior_probability=0.01, preprocessor=None, activation=None, dtype=None, @@ -84,15 +85,14 @@ def __init__( # === Layers === image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy - prior_probability = keras.initializers.Constant( - -1 * keras.ops.log((1 - 0.01) / 0.01) - ) + box_head = PredictionHead( output_filters=anchor_generator.anchors_per_location * 4, num_conv_layers=4, num_filters=256, - bias_initializer=prior_probability, use_group_norm=use_prediction_head_norm, + use_prior_probability=True, + prior_probability=classification_head_prior_probability, dtype=head_dtype, name="box_head", ) From 8dc54838027415c4c653a58b9fdc71dc5ceb4600 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 15:47:43 -0700 Subject: [PATCH 22/35] compute_shape seems redudant here and correct layers for channels_first --- .../src/models/retinanet/feature_pyramid.py | 91 +------------------ 1 file changed, 2 insertions(+), 89 deletions(-) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 3c4aaa760c..2aefee916a 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -145,7 +145,7 @@ def build(self, input_shapes): self.lateral_conv_layers[level].build( (None, None, None, 
input_shapes[level][-1]) if self.data_format == "channels_last" - else (None, input_shapes[level][-1], None, None) + else (None, input_shapes[level][1], None, None) ) self.lateral_batch_norm_layers = {} @@ -208,7 +208,7 @@ def build(self, input_shapes): self.output_conv_layers[level].build( (None, None, None, input_shapes[f"P{i-1}"][-1]) if self.data_format == "channels_last" - else (None, input_shapes[f"P{i-1}"][-1], None, None) + else (None, input_shapes[f"P{i-1}"][1], None, None) ) else: self.output_conv_layers[level].build( @@ -345,90 +345,3 @@ def get_config(self): ) return config - - def compute_output_shape(self, input_shapes): - output_shape = {} - input_levels = [int(level[1]) for level in input_shapes] - backbone_max_level = min(max(input_levels), self.max_level) - - for i in range(self.min_level, backbone_max_level + 1): - level = f"P{i}" - if self.data_format == "channels_last": - output_shape[level] = input_shapes[level][:-1] + ( - self.num_filters, - ) - else: - output_shape[level] = ( - input_shapes[level][0], - self.num_filters, - ) + input_shapes[level][1:3] - - intermediate_shape = input_shapes[f"P{backbone_max_level}"] - intermediate_shape = ( - ( - intermediate_shape[0], - ( - intermediate_shape[1] // 2 - if intermediate_shape[1] is not None - else None - ), - ( - intermediate_shape[2] // 2 - if intermediate_shape[1] is not None - else None - ), - self.num_filters, - ) - if self.data_format == "channels_last" - else ( - intermediate_shape[0], - self.num_filters, - ( - intermediate_shape[1] // 2 - if intermediate_shape[1] is not None - else None - ), - ( - intermediate_shape[2] // 2 - if intermediate_shape[1] is not None - else None - ), - ) - ) - - for i in range(backbone_max_level + 1, self.max_level + 1): - level = f"P{i}" - output_shape[level] = intermediate_shape - intermediate_shape = ( - ( - intermediate_shape[0], - ( - intermediate_shape[1] // 2 - if intermediate_shape[1] is not None - else None - ), - ( - intermediate_shape[2] // 2 - if intermediate_shape[1] is not None - else None - ), - self.num_filters, - ) - if self.data_format == "channels_last" - else ( - intermediate_shape[0], - self.num_filters, - ( - intermediate_shape[1] // 2 - if intermediate_shape[1] is not None - else None - ), - ( - intermediate_shape[2] // 2 - if intermediate_shape[1] is not None - else None - ), - ) - ) - - return output_shape From 9f7d8ef34ab859687f0f220368fedaee64a68ae9 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 16:01:40 -0700 Subject: [PATCH 23/35] keep compute_output_shape for fpn --- .../src/models/retinanet/feature_pyramid.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 2aefee916a..e6bb18bde7 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -1,3 +1,4 @@ +import math import keras from keras_hub.src.utils.keras_utils import standardize_data_format @@ -345,3 +346,90 @@ def get_config(self): ) return config + + def compute_output_shape(self, input_shapes): + output_shape = {} + input_levels = [int(level[1]) for level in input_shapes] + backbone_max_level = min(max(input_levels), self.max_level) + + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + if self.data_format == "channels_last": + output_shape[level] = input_shapes[level][:-1] + ( + self.num_filters, + ) + else: + output_shape[level] = ( + input_shapes[level][0], + 
self.num_filters, + ) + input_shapes[level][1:3] + + intermediate_shape = input_shapes[f"P{backbone_max_level}"] + intermediate_shape = ( + ( + intermediate_shape[0], + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + self.num_filters, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + self.num_filters, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ) + ) + + for i in range(backbone_max_level + 1, self.max_level + 1): + level = f"P{i}" + output_shape[level] = intermediate_shape + intermediate_shape = ( + ( + intermediate_shape[0], + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + self.num_filters, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + self.num_filters, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ) + ) + + return output_shape From 68017891dcf1479a6a56ef4161cc53aac29d275f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 9 Oct 2024 17:42:32 -0700 Subject: [PATCH 24/35] nit --- keras_hub/src/models/retinanet/feature_pyramid.py | 1 + keras_hub/src/models/retinanet/retinanet_object_detector.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index e6bb18bde7..ea8b13af75 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -1,4 +1,5 @@ import math + import keras from keras_hub.src.utils.keras_utils import standardize_data_format diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 5bf0690487..3839d791a1 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -12,7 +12,7 @@ RetinaNetObjectDetectorPreprocessor, ) -BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] +BOX_VARIANCE = [1.0, 1.0, 1.0, 1.0] @keras_hub_export("keras_hub.models.RetinaNetObjectDetector") From 7e57cf18bce746415217b485873237fb2849e622 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 10 Oct 2024 13:52:28 -0700 Subject: [PATCH 25/35] Change AnchorGen Implementation as per torch --- .../src/models/retinanet/anchor_generator.py | 99 +++++++++---------- .../models/retinanet/anchor_generator_test.py | 61 +----------- .../retinanet/retinanet_label_encoder.py | 15 +-- .../retinanet/retinanet_object_detector.py | 4 +- .../retinanet_object_detector_test.py | 2 +- 5 files changed, 63 insertions(+), 118 deletions(-) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index e04779abb6..38fc44826e 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -83,6 +83,7 @@ def __init__( self.num_scales = num_scales self.aspect_ratios = aspect_ratios 
self.anchor_size = anchor_size + self.num_base_anchors = num_scales * len(aspect_ratios) self.built = True def call(self, inputs): @@ -94,60 +95,61 @@ def call(self, inputs): image_shape = tuple(image_shape) - multilevel_boxes = {} + multilevel_anchors = {} for level in range(self.min_level, self.max_level + 1): - boxes_l = [] # Calculate the feature map size for this level feat_size_y = math.ceil(image_shape[0] / 2**level) feat_size_x = math.ceil(image_shape[1] / 2**level) # Calculate the stride (step size) for this level - stride_y = image_shape[0] / feat_size_y - stride_x = image_shape[1] / feat_size_x + stride_y = image_shape[0] // feat_size_y + stride_x = image_shape[1] // feat_size_x # Generate anchor center points # Start from stride/2 to center anchors on pixels - cx = ops.arange(stride_x / 2, image_shape[1], stride_x) - cy = ops.arange(stride_y / 2, image_shape[0], stride_y) + cx = ops.arange(0, feat_size_x, dtype="float32") * stride_x + cy = ops.arange(0, feat_size_y, dtype="float32") * stride_y # Create a grid of anchor centers - cx_grid, cy_grid = ops.meshgrid(cx, cy) - - for scale in range(self.num_scales): - for aspect_ratio in self.aspect_ratios: - # Calculate the intermediate scale factor - intermidate_scale = 2 ** (scale / self.num_scales) - # Calculate the base anchor size for this level and scale - base_anchor_size = ( - self.anchor_size * 2**level * intermidate_scale - ) - # Adjust anchor dimensions based on aspect ratio - aspect_x = aspect_ratio**0.5 - aspect_y = aspect_ratio**-0.5 - half_anchor_size_x = base_anchor_size * aspect_x / 2.0 - half_anchor_size_y = base_anchor_size * aspect_y / 2.0 - - # Generate anchor boxes (y1, x1, y2, x2 format) - boxes = ops.stack( - [ - cy_grid - half_anchor_size_y, - cx_grid - half_anchor_size_x, - cy_grid + half_anchor_size_y, - cx_grid + half_anchor_size_x, - ], - axis=-1, - ) - boxes_l.append(boxes) - # Concat anchors on the same level to tensor shape HxWx(Ax4) - boxes_l = ops.concatenate(boxes_l, axis=-1) - boxes_l = ops.reshape(boxes_l, (-1, 4)) - # Convert to user defined - multilevel_boxes[f"P{level}"] = convert_format( - boxes_l, + cy_grid, cx_grid = ops.meshgrid(cy, cx, indexing="ij") + cy_grid = ops.reshape(cy_grid, (-1,)) + cx_grid = ops.reshape(cx_grid, (-1,)) + + shifts = ops.stack((cx_grid, cy_grid, cx_grid, cy_grid), axis=1) + sizes = [ + int( + 2**level * self.anchor_size * 2 ** (scale / self.num_scales) + ) + for scale in range(self.num_scales) + ] + + base_anchors = self.generate_base_anchors( + sizes=sizes, aspect_ratios=self.aspect_ratios + ) + shifts = ops.reshape(shifts, (-1, 1, 4)) + base_anchors = ops.reshape(base_anchors, (1, -1, 4)) + + anchors = shifts + base_anchors + anchors = ops.reshape(anchors, (-1, 4)) + multilevel_anchors[f"P{level}"] = convert_format( + anchors, source="yxyx", target=self.bounding_box_format, ) - return multilevel_boxes + return multilevel_anchors + + def generate_base_anchors(self, sizes, aspect_ratios): + sizes = ops.convert_to_tensor(sizes, dtype="float32") + aspect_ratios = ops.convert_to_tensor(aspect_ratios) + h_ratios = ops.sqrt(aspect_ratios) + w_ratios = 1 / h_ratios + + ws = ops.reshape(w_ratios[:, None] * sizes[None, :], (-1,)) + hs = ops.reshape(h_ratios[:, None] * sizes[None, :], (-1,)) + + base_anchors = ops.stack([-1 * ws, -1 * hs, ws, hs], axis=1) / 2 + base_anchors = ops.round(base_anchors) + return base_anchors def compute_output_shape(self, input_shape): multilevel_boxes_shape = {} @@ -158,18 +160,11 @@ def compute_output_shape(self, input_shape): for i in 
range(self.min_level, self.max_level + 1): multilevel_boxes_shape[f"P{i}"] = ( - (image_height // 2 ** (i)) - * (image_width // 2 ** (i)) - * self.anchors_per_location, + int( + math.ceil(image_height / 2 ** (i)) + * math.ceil(image_width // 2 ** (i)) + * self.num_base_anchors + ), 4, ) return multilevel_boxes_shape - - @property - def anchors_per_location(self): - """ - The `anchors_per_location` property returns the number of anchors - generated per pixel location, which is equal to - `num_scales * len(aspect_ratios)`. - """ - return self.num_scales * len(self.aspect_ratios) diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/models/retinanet/anchor_generator_test.py index c843c32f27..dd76b7f2f4 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/models/retinanet/anchor_generator_test.py @@ -2,7 +2,6 @@ from absl.testing import parameterized from keras import ops -from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.tests.test_case import TestCase @@ -18,7 +17,7 @@ def test_layer_behaviors(self): "max_level": 7, "num_scales": 3, "aspect_ratios": [0.5, 1.0, 2.0], - "anchor_size": 8, + "anchor_size": 4, }, input_data=np.random.uniform(size=images_shape), expected_output_shape={ @@ -40,58 +39,13 @@ def test_layer_behaviors(self): + ( { "P5": [ - [-16.0, -16.0, 48.0, 48.0], - [-16.0, 16.0, 48.0, 80.0], - [16.0, -16.0, 80.0, 48.0], - [16.0, 16.0, 80.0, 80.0], + [-32.0, -32.0, 32.0, 32.0], + [0.0, -32.0, 64.0, 32.0], + [-32.0, 0.0, 32.0, 64.0], + [0.0, 0.0, 64.0, 64.0], ] }, ), - # Multi scale anchor - ("xywh", 5, 6, 1, [1.0], 2.0, [64, 64]) - + ( - { - "P5": [ - [-16.0, -16.0, 48.0, 48.0], - [-16.0, 16.0, 48.0, 80.0], - [16.0, -16.0, 80.0, 48.0], - [16.0, 16.0, 80.0, 80.0], - ], - "P6": [[-32, -32, 96, 96]], - }, - ), - # Multi aspect ratio anchor - ("xyxy", 6, 6, 1, [1.0, 4.0, 0.25], 2.0, [64, 64]) - + ( - { - "P6": [ - [-32.0, -32.0, 96.0, 96.0], - [0.0, -96.0, 64.0, 160.0], - [-96.0, 0.0, 160.0, 64.0], - ] - }, - ), - # Intermidate scales - ("yxyx", 5, 5, 2, [1.0], 1.0, [32, 32]) - + ( - { - "P5": [ - [0.0, 0.0, 32.0, 32.0], - [ - 16 - 16 * 2**0.5, - 16 - 16 * 2**0.5, - 16 + 16 * 2**0.5, - 16 + 16 * 2**0.5, - ], - ] - }, - ), - # Non-square - ("xywh", 5, 5, 1, [1.0], 1.0, [64, 32]) - + ({"P5": [[0, 0, 32, 32], [32, 0, 64, 32]]},), - # Indivisible by 2^level - ("xyxy", 5, 5, 1, [1.0], 1.0, [40, 32]) - + ({"P5": [[-6, 0, 26, 32], [14, 0, 46, 32]]},), ) def test_anchor_generator( self, @@ -116,9 +70,4 @@ def test_anchor_generator( multilevel_boxes = anchor_generator(images) for key in expected_boxes: expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key]) - expected_boxes[key] = convert_format( - expected_boxes[key], - source="yxyx", - target=bounding_box_format, - ) self.assertAllClose(expected_boxes, multilevel_boxes) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 5e5cd6b23b..06efca3b61 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -1,5 +1,6 @@ +import math + import keras -import keras.src from keras import ops from keras_hub.src.bounding_box.converters import _encode_box_to_deltas @@ -36,7 +37,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer): Defaults to `0.4` box_variance: List[float]. 
The scaling factors used to scale the bounding box targets. - Defaults to `[0.1, 0.1, 0.2, 0.2]`. + Defaults to `[1.0, 1.0, 1.0, 1.0]`. background_class: int. The class ID used for the background class, Defaults to `-1`. ignore_class: int. The class ID used for the ignore class, @@ -60,7 +61,7 @@ def __init__( bounding_box_format, positive_threshold=0.5, negative_threshold=0.4, - box_variance=[0.1, 0.1, 0.2, 0.2], + box_variance=[1.0, 1.0, 1.0, 1.0], background_class=-1.0, ignore_class=-2.0, box_matcher_match_values=[-1, -2, 1], @@ -249,10 +250,10 @@ def compute_output_shape( total_num_anchors = 0 for i in range(min_level, max_level + 1): - total_num_anchors += ( - (image_H // 2 ** (i)) - * (image_W // 2 ** (i)) - * self.anchor_generator.anchors_per_location + total_num_anchors += int( + math.ceil(image_H / 2 ** (i)) + * math.ceil(image_W / 2 ** (i)) + * self.anchor_generator.num_base_anchors ) return (batch_size, total_num_anchors, 4), ( diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 3839d791a1..21083d3381 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -87,7 +87,7 @@ def __init__( head_dtype = dtype or backbone.dtype_policy box_head = PredictionHead( - output_filters=anchor_generator.anchors_per_location * 4, + output_filters=anchor_generator.num_base_anchors * 4, num_conv_layers=4, num_filters=256, use_group_norm=use_prediction_head_norm, @@ -97,7 +97,7 @@ def __init__( name="box_head", ) classification_head = PredictionHead( - output_filters=anchor_generator.anchors_per_location * num_classes, + output_filters=anchor_generator.num_base_anchors * num_classes, num_conv_layers=4, num_filters=256, use_group_norm=use_prediction_head_norm, diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index b321f6eef7..a75f7b69c5 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -45,7 +45,7 @@ def setUp(self): max_level=4, num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=8, + anchor_size=4, ) label_encoder = RetinaNetLabelEncoder( bounding_box_format="yxyx", anchor_generator=anchor_generator From 8ac617c9d7ba1995d19af009aaa40c5d217eaa16 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 10 Oct 2024 14:29:43 -0700 Subject: [PATCH 26/35] correct the source format of anchors format --- keras_hub/src/models/retinanet/anchor_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index 38fc44826e..dd8046c949 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -133,7 +133,7 @@ def call(self, inputs): anchors = ops.reshape(anchors, (-1, 4)) multilevel_anchors[f"P{level}"] = convert_format( anchors, - source="yxyx", + source="xyxy", target=self.bounding_box_format, ) return multilevel_anchors From 03efed50f60c87658a6a185e3f7ac9be46fd9223 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 00:04:50 -0700 Subject: [PATCH 27/35] use plain rescaling and normalization no resizing for od models as it can effect the bounding boxes and the ops i backend framework dependent --- 
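Note for reviewers: the intent of this change is easier to see with a small NumPy sketch of the preprocessing path — rescale, then normalize with ImageNet statistics, with no resize so box coordinates stay valid. The constants below are assumed illustrative defaults, not part of this diff:

```python
import numpy as np

# Assumed defaults: rescale to [0, 1], then ImageNet mean/std normalization.
scale = 1.0 / 255.0
norm_mean = np.array([0.485, 0.456, 0.406], dtype="float32")
norm_std = np.array([0.229, 0.224, 0.225], dtype="float32")

image = np.random.randint(0, 256, size=(512, 512, 3)).astype("float32")
x = image * scale               # rescaling
x = (x - norm_mean) / norm_std  # normalization, broadcast over channels
print(x.shape)                  # (512, 512, 3): spatial size is untouched
```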
.../models/retinanet/anchor_generator_test.py | 2 +- .../retinanet/retinanet_image_converter.py | 44 +++++++++++++++++++ .../retinanet_object_detector_test.py | 8 ++-- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/models/retinanet/anchor_generator_test.py index dd76b7f2f4..0b71630843 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/models/retinanet/anchor_generator_test.py @@ -40,8 +40,8 @@ def test_layer_behaviors(self): { "P5": [ [-32.0, -32.0, 32.0, 32.0], + [-32.0, 0, 32.0, 64.0], [0.0, -32.0, 64.0, 32.0], - [-32.0, 0.0, 32.0, 64.0], [0.0, 0.0, 64.0, 64.0], ] }, diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py index b37091fd6f..b067419922 100644 --- a/keras_hub/src/models/retinanet/retinanet_image_converter.py +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -1,8 +1,52 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.RetinaNetImageConverter") class RetinaNetImageConverter(ImageConverter): backbone_cls = RetinaNetBackbone + + def __init__( + self, + scale=None, + offset=None, + norm_mean=[0.485, 0.456, 0.406], + norm_std=[0.229, 0.224, 0.225], + **kwargs + ): + super().__init__(**kwargs) + self.scale = scale + self.offset = offset + self.norm_mean = norm_mean + self.norm_std = norm_std + self.built = True + + @preprocessing_function + def call(self, inputs): + x = inputs + # Rescaling Image + if self.scale is not None: + x = x * self._expand_non_channel_dims(self.scale, x) + if self.offset is not None: + x = x + self._expand_non_channel_dims(self.offset, x) + + # By default normalize using imagenet mean and std + if self.norm_mean: + x = x - self._expand_non_channel_dims(self.norm_mean, x) + + if self.norm_std: + x = x / self._expand_non_channel_dims(self.norm_std, x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "norm_mean": self.norm_mean, + "norm_std": self.norm_std, + } + ) + return config diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index a75f7b69c5..20f256e79e 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -1,10 +1,12 @@ import numpy as np import pytest -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) @@ -51,9 +53,7 @@ def setUp(self): bounding_box_format="yxyx", anchor_generator=anchor_generator ) - image_converter = ImageConverter( - image_size=(256, 256), - ) + image_converter = RetinaNetImageConverter(scale=1 / 255.0) preprocessor = RetinaNetObjectDetectorPreprocessor( 
image_converter=image_converter From 5704950e79fa484abb4f724c47ec79181837072a Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 00:55:00 -0700 Subject: [PATCH 28/35] use single bbox format for model --- .../retinanet/retinanet_object_detector.py | 39 +++++-------------- .../retinanet_object_detector_test.py | 7 +--- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 21083d3381..35f05c9df7 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -32,13 +32,7 @@ class RetinaNetObjectDetector(ImageObjectDetector): `RetinaNetObjectDetector` training targets. anchor_generator: A `keras_Hub.layers.AnchorGenerator`. num_classes: The number of object classes to be detected. - ground_truth_bounding_box_format: Ground truth bounding box format. - Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 - Ensure that ground truth boxes follow one of the following formats. - - `rel_xyxy` - - `rel_yxyx` - - `rel_xywh` - target_bounding_box_format: Target bounding box format. + bounding_box_format: Dataset bounding box format. Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 use_prediction_head_norm: bool. Whether to use Group Normalization after the convolution layers in prediction head. Defaults to `False`. @@ -62,8 +56,7 @@ def __init__( label_encoder, anchor_generator, num_classes, - ground_truth_bounding_box_format, - target_bounding_box_format, + bounding_box_format, use_prediction_head_norm=False, classification_head_prior_probability=0.01, preprocessor=None, @@ -72,16 +65,6 @@ def __init__( prediction_decoder=None, **kwargs, ): - if "rel" not in ground_truth_bounding_box_format: - raise ValueError( - f"Only relative bounding box formats are supported " - f"Received ground_truth_bounding_box_format=" - f"`{ground_truth_bounding_box_format}`. " - f"Please provide a `ground_truth_bounding_box_format` from one of " - f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. " - f"Ensure that the provided ground truth bounding boxes are " - f"normalized and relative to the image size. 
" - ) # === Layers === image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy @@ -138,8 +121,7 @@ def __init__( ) # === Config === - self.ground_truth_bounding_box_format = ground_truth_bounding_box_format - self.target_bounding_box_format = target_bounding_box_format + self.bounding_box_format = bounding_box_format self.use_prediction_head_norm = use_prediction_head_norm self.num_classes = num_classes self.backbone = backbone @@ -151,13 +133,13 @@ def __init__( self.classification_head = classification_head self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), - bounding_box_format=self.target_bounding_box_format, + bounding_box_format=bounding_box_format, ) def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): y_for_label_encoder = convert_format( y, - source=self.ground_truth_bounding_box_format, + source=self.bounding_box_format, target=self.label_encoder.bounding_box_format, images=x, ) @@ -263,14 +245,14 @@ def decode_predictions(self, predictions, data): anchors=anchor_boxes, boxes_delta=box_pred, anchor_format=self.anchor_generator.bounding_box_format, - box_format=self.target_bounding_box_format, + box_format=self.bounding_box_format, variance=BOX_VARIANCE, image_shape=image_shape, ) - # box_pred is now in "self.target_bounding_box_format" format + # box_pred is now in "self.bounding_box_format" format box_pred = convert_format( box_pred, - source=self.target_bounding_box_format, + source=self.bounding_box_format, target=self.prediction_decoder.bounding_box_format, image_shape=image_shape, ) @@ -280,7 +262,7 @@ def decode_predictions(self, predictions, data): y_pred["boxes"] = convert_format( y_pred["boxes"], source=self.prediction_decoder.bounding_box_format, - target=self.target_bounding_box_format, + target=self.bounding_box_format, image_shape=image_shape, ) return y_pred @@ -290,9 +272,8 @@ def get_config(self): config.update( { "num_classes": self.num_classes, - "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format, "use_prediction_head_norm": self.use_prediction_head_norm, - "target_bounding_box_format": self.target_bounding_box_format, + "bounding_box_format": self.bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator ), diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 20f256e79e..079e027331 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -64,8 +64,7 @@ def setUp(self): "anchor_generator": anchor_generator, "label_encoder": label_encoder, "num_classes": 10, - "ground_truth_bounding_box_format": "rel_yxyx", - "target_bounding_box_format": "xywh", + "bounding_box_format": "yxyx", "preprocessor": preprocessor, } @@ -74,9 +73,7 @@ def setUp(self): low=0, high=255, size=(1, self.input_size, self.input_size, 3) ) self.labels = { - "boxes": np.array( - [[[0.2, 0.0, 0.12, 0.11], [0.3, 0.2, 0.4, 0.12]]] - ), + "boxes": np.array([[[20, 10, 12, 11], [30, 20, 40, 12]]]), "classes": np.array([[0, 2]]), } From 7c1d1de98dba5d6ae4265a4299b1712ae65016e7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 14:45:57 -0700 Subject: [PATCH 29/35] - Add arg for encoding format - Add required docstrings - Use `center_xywh` encoding for retinanet as per torch weights --- 
keras_hub/src/bounding_box/converters.py | 114 ++++++++++++++++-- keras_hub/src/models/image_object_detector.py | 5 +- .../image_object_detector_preprocessor.py | 17 --- .../retinanet/retinanet_label_encoder.py | 9 +- .../retinanet/retinanet_object_detector.py | 101 +++++++++------- .../retinanet_object_detector_test.py | 4 +- 6 files changed, 175 insertions(+), 75 deletions(-) diff --git a/keras_hub/src/bounding_box/converters.py b/keras_hub/src/bounding_box/converters.py index 263cd6df33..6371673520 100644 --- a/keras_hub/src/bounding_box/converters.py +++ b/keras_hub/src/bounding_box/converters.py @@ -20,29 +20,74 @@ class RequiresImagesException(Exception): ALL_AXES = 4 -def _encode_box_to_deltas( +def encode_box_to_deltas( anchors, boxes, - anchor_format: str, - box_format: str, + anchor_format, + box_format, + encoding_format="center_yxhw", variance=None, image_shape=None, ): - """Converts bounding_boxes from `center_yxhw` to delta format.""" + """Encodes bounding boxes to deltas relative to anchors. + + This function converts bounding boxes to delta format, representing the + difference between the boxes and the provided anchors. The boxes and + anchors are first converted to the specified `encoding_format` + (`center_yxhw` by default) before the delta calculation. This allows for + consistent delta representation regardless of the original box format. + + Args: + anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the + number of anchors. + boxes: `Tensors` Bounding boxes to encode. Boxes can be of be shape + `(B, N, 4)` or `(N, 4)`. + anchor_format: str. The format of the input `anchors` + (e.g., "xyxy", "xywh", etc.). + box_format: str. The format of the input `boxes` + (e.g., "xyxy", "xywh", etc.). + encoding_format: str. The intermediate format to which boxes and anchors + are converted before delta calculation. Defaults to "center_yxhw". + variance: `List[float]`. A 4-element array/tensor representing variance + factors to scale the box deltas. If provided, the calculated deltas + are divided by the variance. Defaults to None. + image_shape: `Tuple[int]`. The shape of the image (height, width, 3). + This might be needed for normalization during format conversion. + Defaults to None. + + Returns: + Encoded box deltas. The return type matches the `encode_format`. + + Raises: + ValueError: If `variance` is not None and its length is not 4. + ValueError: If `encoding_format` is not `"center_xywh"` or + `"center_yxhw"`. 
+ + """ if variance is not None: variance = ops.convert_to_tensor(variance, "float32") var_len = variance.shape[-1] if var_len != 4: raise ValueError(f"`variance` must be length 4, got {variance}") + + if encoding_format not in ["center_xywh", "center_yxhw"]: + raise ValueError( + "`encoding_format` should be one of 'center_xywh' or 'center_yxhw', " + f"got {encoding_format}" + ) + encoded_anchors = convert_format( anchors, source=anchor_format, - target="center_yxhw", + target=encoding_format, image_shape=image_shape, ) boxes = convert_format( - boxes, source=box_format, target="center_yxhw", image_shape=image_shape + boxes, + source=box_format, + target=encoding_format, + image_shape=image_shape, ) anchor_dimensions = ops.maximum( encoded_anchors[..., 2:], keras.backend.epsilon() @@ -61,15 +106,54 @@ def _encode_box_to_deltas( return boxes_delta -def _decode_deltas_to_boxes( +def decode_deltas_to_boxes( anchors, boxes_delta, - anchor_format: str, - box_format: str, + anchor_format, + box_format, + encoded_format="center_yxhw", variance=None, image_shape=None, ): - """Converts bounding_boxes from delta format to `center_yxhw`.""" + """Converts bounding boxes from delta format to the specified `box_format`. + + This function decodes bounding box deltas relative to anchors to obtain the + final bounding box coordinates. The boxes are encoded in a specific + `encoded_format` (center_yxhw by default) during the decoding process. + This allows flexibility in how the deltas are applied to the anchors. + + Args: + anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level + indicies and values are corresponding anchor boxes. + The shape of the array/tensor should be `(N, 4)` where N is the + number of anchors. + boxes_delta Can be `Tensors` or `Dict[Tensors]` Bounding box deltas + must have the same type and structure as `anchors`. The + shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is + the number of boxes. + anchor_format: str. The format of the input `anchors`. + (e.g., `"xyxy"`, `"xywh"`, etc.) + box_format: str. The desired format for the output boxes. + (e.g., `"xyxy"`, `"xywh"`, etc.) + encoded_format: str. Raw output format from regression head. Defaults + to `"center_yxhw"`. + variance: `List[floats]`. A 4-element array/tensor representing + variance factors to scale the box deltas. If provided, the deltas + are multiplied by the variance before being applied to the anchors. + Defaults to None. + image_shape: The shape of the image (height, width). This is needed + if normalization to image size is required when converting between + formats. Defaults to None. + + Returns: + Decoded box coordinates. The return type matches the `box_format`. + + Raises: + ValueError: If `variance` is not None and its length is not 4. + ValueError: If `encoded_format` is not `"center_xywh"` or + `"center_yxhw"`. + + """ if variance is not None: variance = ops.convert_to_tensor(variance, "float32") var_len = variance.shape[-1] @@ -77,11 +161,17 @@ def _decode_deltas_to_boxes( if var_len != 4: raise ValueError(f"`variance` must be length 4, got {variance}") + if encoded_format not in ["center_xywh", "center_yxhw"]: + raise ValueError( + f"`encoded_format` should be 'center_xywh' or 'center_yxhw', " + f"but got '{encoded_format}'." 
+ ) + def decode_single_level(anchor, box_delta): encoded_anchor = convert_format( anchor, source=anchor_format, - target="center_yxhw", + target=encoded_format, image_shape=image_shape, ) if variance is not None: @@ -97,7 +187,7 @@ def decode_single_level(anchor, box_delta): ) box = convert_format( box, - source="center_yxhw", + source=encoded_format, target=box_format, image_shape=image_shape, ) diff --git a/keras_hub/src/models/image_object_detector.py b/keras_hub/src/models/image_object_detector.py index 4016d7dff2..aa8a54dc3e 100644 --- a/keras_hub/src/models/image_object_detector.py +++ b/keras_hub/src/models/image_object_detector.py @@ -74,7 +74,10 @@ def compile( if metrics is not None: raise ValueError("User metrics not yet supported") - losses = {"box": box_loss, "classification": classification_loss} + losses = { + "bbox_regression": box_loss, + "cls_logits": classification_loss, + } super().compile( optimizer=optimizer, diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index a4eab8389c..581a10d6d9 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -40,23 +40,6 @@ class ImageObjectDetectorPreprocessor(Preprocessor): preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset( "retinanet_resnet50", ) - - # Resize a single image for resnet 50. - x = np.ones((512, 512, 3)) - x = preprocessor(x) - - # Resize a labeled image. - x, y = np.ones((512, 512, 3)), 1 - x, y = preprocessor(x, y) - - # Resize a batch of labeled images. - x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] - x, y = preprocessor(x, y) - - # Use a `tf.data.Dataset`. - ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2) - ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) - ``` """ def __init__( diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index 06efca3b61..ad32c45efc 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -3,8 +3,8 @@ import keras from keras import ops -from keras_hub.src.bounding_box.converters import _encode_box_to_deltas from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.bounding_box.converters import encode_box_to_deltas from keras_hub.src.bounding_box.iou import compute_iou from keras_hub.src.models.retinanet.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils @@ -28,6 +28,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer): Args: anchor_generator: A `keras_hub.layers.AnchorGenerator`. bounding_box_format: str. Ground truth format of bounding boxes. + encoding_format: str. The desired target encoding format for the boxes. TODO: https://github.com/keras-team/keras-hub/issues/1907 positive_threshold: float. the threshold to set an anchor to positive match to gt box. Values above it are positive matches. 
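The positive/negative thresholds documented above translate to a three-way split over per-anchor IoU: positive match, background, or ignored. A toy sketch of that rule only (the layer itself delegates the matching to `BoxMatcher` over batched tensors):

```python
import numpy as np

positive_threshold = 0.5   # IoU >= 0.5 -> positive match
negative_threshold = 0.4   # IoU < 0.4  -> background
background_class = -1.0
ignore_class = -2.0

max_iou_per_anchor = np.array([0.72, 0.45, 0.10, 0.55, 0.38])
matched_gt_class = np.array([3.0, 3.0, 1.0, 7.0, 2.0])

labels = np.where(
    max_iou_per_anchor >= positive_threshold, matched_gt_class, ignore_class
)
labels = np.where(
    max_iou_per_anchor < negative_threshold, background_class, labels
)
print(labels)  # [ 3. -2. -1.  7. -1.]
```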
@@ -59,6 +60,7 @@ def __init__( self, anchor_generator, bounding_box_format, + encoding_format="center_yxhw", positive_threshold=0.5, negative_threshold=0.4, box_variance=[1.0, 1.0, 1.0, 1.0], @@ -71,6 +73,7 @@ def __init__( super().__init__(**kwargs) self.anchor_generator = anchor_generator self.bounding_box_format = bounding_box_format + self.encoding_format = encoding_format self.positive_threshold = positive_threshold self.box_variance = box_variance self.negative_threshold = negative_threshold @@ -175,11 +178,12 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) ) - box_target = _encode_box_to_deltas( + box_target = encode_box_to_deltas( anchors=anchor_boxes, boxes=matched_gt_boxes, anchor_format=self.bounding_box_format, box_format=self.bounding_box_format, + encoding_format=self.encoding_format, variance=self.box_variance, image_shape=image_shape, ) @@ -220,6 +224,7 @@ def get_config(self): self.anchor_generator ), "bounding_box_format": self.bounding_box_format, + "encoding_format": self.encoding_format, "positive_threshold": self.positive_threshold, "box_variance": self.box_variance, "negative_threshold": self.negative_threshold, diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 35f05c9df7..2473b00215 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -2,18 +2,19 @@ from keras import ops from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.bounding_box.converters import _decode_deltas_to_boxes from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes from keras_hub.src.models.image_object_detector import ImageObjectDetector from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression from keras_hub.src.models.retinanet.prediction_head import PredictionHead from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( RetinaNetObjectDetectorPreprocessor, ) -BOX_VARIANCE = [1.0, 1.0, 1.0, 1.0] - @keras_hub_export("keras_hub.models.RetinaNetObjectDetector") class RetinaNetObjectDetector(ImageObjectDetector): @@ -21,30 +22,38 @@ class RetinaNetObjectDetector(ImageObjectDetector): This class implements the RetinaNet object detection architecture. It consists of a feature extractor backbone, a feature pyramid network(FPN), - and two prediction heads for classification and regression. + and two prediction heads (for classification and bounding box regression). Args: - backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class, defining the - backbone network architecture. - label_encoder: `keras.layers.Layer`. A `RetinaNetLabelEncoder` class - that accepts an image Tensor, a bounding box Tensor and a bounding - box class Tensor to its `call()` method, and returns - `RetinaNetObjectDetector` training targets. - anchor_generator: A `keras_Hub.layers.AnchorGenerator`. - num_classes: The number of object classes to be detected. - bounding_box_format: Dataset bounding box format. + backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class, + defining the backbone network architecture. 
Provides feature maps + for detection. + anchor_generator: A `keras_hub.layers.AnchorGenerator` instance. + Generates anchor boxes at different scales and aspect ratios + across the image. + num_classes: int. The number of object classes to be detected. + bounding_box_format: str. Dataset bounding box format (e.g., "xyxy", + "yxyx"). Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 + label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes + ground truth boxes and classes into training targets for RetinaNet. + If None,a default encoder is created. use_prediction_head_norm: bool. Whether to use Group Normalization after - the convolution layers in prediction head. Defaults to `False`. - preprocessor: Optional. An instance of the - `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor. + the convolution layers in the prediction heads. Defaults to `False`. + classification_head_prior_probability: float. Prior probability for the + classification head (used for focal loss). Defaults to 0.01. + preprocessor: Optional. An instance of + `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor. + Handles image preprocessing before feeding into the backbone. activation: Optional. The activation function to be used in the - classification head. - dtype: Optional. The data type for the prediction heads. - prediction_decoder: Optional. A `keras.layers.Layer` that is - responsible for transforming RetinaNet predictions into usable - bounding box Tensors. - Defaults to `NonMaxSuppression` class instance. + classification head. If None, sigmoid is used. + dtype: Optional. The data type for the prediction heads. Defaults to the + backbone's dtype policy. + prediction_decoder: Optional. A `keras.layers.Layer` instance + responsible for transforming RetinaNet predictions + (box regressions and classifications) into final bounding boxes and + classes with confidence scores. Defaults to a `NonMaxSuppression` + instance. """ backbone_cls = RetinaNetBackbone @@ -53,10 +62,10 @@ class RetinaNetObjectDetector(ImageObjectDetector): def __init__( self, backbone, - label_encoder, anchor_generator, num_classes, bounding_box_format, + label_encoder=None, use_prediction_head_norm=False, classification_head_prior_probability=0.01, preprocessor=None, @@ -93,6 +102,8 @@ def __init__( cls_pred = [] box_pred = [] + + # Iterate through the feature pyramid levels (e.g., P3, P4, P5, P6, P7). for level in feature_map: box_pred.append( keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")( @@ -105,14 +116,15 @@ def __init__( )(classification_head(feature_map[level])) ) - cls_pred = keras.layers.Concatenate(axis=1, name="classification")( - cls_pred - ) - # box_pred is always in "center_yxhw" delta-encoded no matter what + # Concatenate predictions from all FPN levels. + cls_pred = keras.layers.Concatenate(axis=1, name="cls_logits")(cls_pred) + # box_pred is always in "center_xywh" delta-encoded no matter what # format you pass in. 
- box_pred = keras.layers.Concatenate(axis=1, name="box")(box_pred) + box_pred = keras.layers.Concatenate(axis=1, name="bbox_regression")( + box_pred + ) - outputs = {"box": box_pred, "classification": cls_pred} + outputs = {"bbox_regression": box_pred, "cls_logits": cls_pred} super().__init__( inputs=image_input, @@ -126,11 +138,17 @@ def __init__( self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor - self.label_encoder = label_encoder self.anchor_generator = anchor_generator self.activation = activation self.box_head = box_head self.classification_head = classification_head + # As weights are ported from torch they use encoded format + # as "center_xywh" + self.label_encoder = label_encoder or RetinaNetLabelEncoder( + anchor_generator, + bounding_box_format=bounding_box_format, + encoding_format="center_xywh", + ) self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), bounding_box_format=bounding_box_format, @@ -150,8 +168,8 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): gt_classes=y_for_label_encoder["classes"], ) - box_pred = y_pred["box"] - cls_pred = y_pred["classification"] + box_pred = y_pred["bbox_regression"] + cls_pred = y_pred["cls_logits"] if boxes.shape[-1] != 4: raise ValueError( @@ -184,16 +202,16 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): box_weights = positive_mask / normalizer y_true = { - "box": boxes, - "classification": cls_labels, + "bbox_regression": boxes, + "cls_logits": cls_labels, } sample_weights = { - "box": box_weights, - "classification": cls_weights, + "bbox_regression": box_weights, + "cls_logits": cls_weights, } zero_weight = { - "box": ops.zeros_like(box_weights), - "classification": ops.zeros_like(cls_weights), + "bbox_regression": ops.zeros_like(box_weights), + "cls_logits": ops.zeros_like(cls_weights), } sample_weight = ops.cond( @@ -232,7 +250,8 @@ def prediction_decoder(self, prediction_decoder): self.make_test_function(force=True) def decode_predictions(self, predictions, data): - box_pred, cls_pred = predictions["box"], predictions["classification"] + box_pred = predictions["bbox_regression"] + cls_pred = predictions["cls_logits"] # box_pred is on "center_yxhw" format, convert to target format. 
if isinstance(data, list) or isinstance(data, tuple): images, _ = data @@ -241,12 +260,12 @@ def decode_predictions(self, predictions, data): image_shape = ops.shape(images)[1:] anchor_boxes = self.anchor_generator(images) anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) - box_pred = _decode_deltas_to_boxes( + box_pred = decode_deltas_to_boxes( anchors=anchor_boxes, boxes_delta=box_pred, + encoded_format="center_xywh", anchor_format=self.anchor_generator.bounding_box_format, box_format=self.bounding_box_format, - variance=BOX_VARIANCE, image_shape=image_shape, ) # box_pred is now in "self.bounding_box_format" format diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 079e027331..1592595bd3 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -71,9 +71,9 @@ def setUp(self): self.input_size = 512 self.images = np.random.uniform( low=0, high=255, size=(1, self.input_size, self.input_size, 3) - ) + ).astype("float32") self.labels = { - "boxes": np.array([[[20, 10, 12, 11], [30, 20, 40, 12]]]), + "boxes": np.array([[[20.0, 10.0, 12, 11], [30, 20, 40, 12]]]), "classes": np.array([[0, 2]]), } From 2414f0083f1f7d8c0664d870f454af94695aa74d Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 15:27:59 -0700 Subject: [PATCH 30/35] make anchor generator optional --- .../retinanet/retinanet_object_detector.py | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 2473b00215..83a6788587 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -5,6 +5,7 @@ from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes from keras_hub.src.models.image_object_detector import ImageObjectDetector +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression from keras_hub.src.models.retinanet.prediction_head import PredictionHead from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone @@ -30,14 +31,39 @@ class RetinaNetObjectDetector(ImageObjectDetector): for detection. anchor_generator: A `keras_hub.layers.AnchorGenerator` instance. Generates anchor boxes at different scales and aspect ratios - across the image. + across the image. If None, a default `AnchorGenerator` is + created with the following parameters: + - `bounding_box_format`: Same as the model's + `bounding_box_format`. + - `min_level`: The backbone's `min_level`. + - `max_level`: The backbone's `max_level`. + - `num_scales`: 3. + - `aspect_ratios`: [0.5, 1.0, 2.0]. + - `anchor_size`: 4.0. + You can create a custom `AnchorGenerator` by instantiating the + `keras_hub.layers.AnchorGenerator` class and passing the desired + arguments. num_classes: int. The number of object classes to be detected. bounding_box_format: str. Dataset bounding box format (e.g., "xyxy", "yxyx"). Refer TODO: https://github.com/keras-team/keras-hub/issues/1907 label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes - ground truth boxes and classes into training targets for RetinaNet. 
- If None,a default encoder is created. + ground truth boxes and classes into training targets. It matches + ground truth boxes to anchors based on IoU and encodes box + coordinates as offsets. If `None`, a default encoder is created. + See the + `keras_hub.src.models.retinanet.retinanet_label_encoder.RetinaNetLabelEncoder` + class for details. If None, a default encoder is created with + standard parameters. + - `anchor_generator`: Same as the model's. + - `bounding_box_format`: Same as the model's + `bounding_box_format`. + - `positive_threshold`: 0.5 + - `negative_threshold`: 0.4 + - `encoding_format`: "center_xywh" + - `box_variance`: [1.0, 1.0, 1.0, 1.0] + - `background_class`: -1 + - `ignore_class`: -2 use_prediction_head_norm: bool. Whether to use Group Normalization after the convolution layers in the prediction heads. Defaults to `False`. classification_head_prior_probability: float. Prior probability for the @@ -62,9 +88,9 @@ class RetinaNetObjectDetector(ImageObjectDetector): def __init__( self, backbone, - anchor_generator, num_classes, bounding_box_format, + anchor_generator=None, label_encoder=None, use_prediction_head_norm=False, classification_head_prior_probability=0.01, @@ -138,10 +164,17 @@ def __init__( self.num_classes = num_classes self.backbone = backbone self.preprocessor = preprocessor - self.anchor_generator = anchor_generator self.activation = activation self.box_head = box_head self.classification_head = classification_head + self.anchor_generator = anchor_generator or AnchorGenerator( + self.bounding_box_format, + min_level=backbone.min_level, + max_level=backbone.max_level, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4, + ) # As weights are ported from torch they use encoded format # as "center_xywh" self.label_encoder = label_encoder or RetinaNetLabelEncoder( From 064c971f984fa0f76ffa076034f8669f41a3297b Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 15:53:11 -0700 Subject: [PATCH 31/35] init as layers for anchor generator and label encoder and as one more arg for prediction head configuration --- .../retinanet/retinanet_object_detector.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 83a6788587..cd80405213 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -68,6 +68,10 @@ class for details. If None, a default encoder is created with the convolution layers in the prediction heads. Defaults to `False`. classification_head_prior_probability: float. Prior probability for the classification head (used for focal loss). Defaults to 0.01. + pre_logits_num_conv_layers: int. Number of convolutional layers in the + head before the logits layer. These convolutional layers are applied + before the final linear layer (logits) that produces the output + predictions (bounding box regressions, classification scores). preprocessor: Optional. An instance of `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor. Handles image preprocessing before feeding into the backbone. 
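For reference, a minimal sketch of the two components the detector now builds when they are left as `None`, using the documented defaults above. The `"xyxy"` format, the 512x512 input, and the hard-coded `min_level`/`max_level` are illustrative assumptions; in the detector they come from `bounding_box_format` and the backbone.

```python
import numpy as np

from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
    RetinaNetLabelEncoder,
)

# Default anchor configuration described in the docstring above.
anchor_generator = AnchorGenerator(
    bounding_box_format="xyxy",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=4,
)

# Default label encoder; "center_xywh" matches the encoding of the ported
# torch weights.
label_encoder = RetinaNetLabelEncoder(
    anchor_generator,
    bounding_box_format="xyxy",
    encoding_format="center_xywh",
)

# The anchor generator can also be inspected on its own: it maps pyramid
# levels ("P3".."P7") to tensors of shape (anchors_at_level, 4).
images = np.ones((1, 512, 512, 3), dtype="float32")
multilevel_anchors = anchor_generator(images)
```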
@@ -94,6 +98,7 @@ def __init__( label_encoder=None, use_prediction_head_norm=False, classification_head_prior_probability=0.01, + pre_logits_num_conv_layers=4, preprocessor=None, activation=None, dtype=None, @@ -104,9 +109,25 @@ def __init__( image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format, + min_level=backbone.min_level, + max_level=backbone.max_level, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4, + ) + # As weights are ported from torch they use encoded format + # as "center_xywh" + label_encoder = label_encoder or RetinaNetLabelEncoder( + anchor_generator, + bounding_box_format=bounding_box_format, + encoding_format="center_xywh", + ) + box_head = PredictionHead( output_filters=anchor_generator.num_base_anchors * 4, - num_conv_layers=4, + num_conv_layers=pre_logits_num_conv_layers, num_filters=256, use_group_norm=use_prediction_head_norm, use_prior_probability=True, @@ -116,7 +137,7 @@ def __init__( ) classification_head = PredictionHead( output_filters=anchor_generator.num_base_anchors * num_classes, - num_conv_layers=4, + num_conv_layers=pre_logits_num_conv_layers, num_filters=256, use_group_norm=use_prediction_head_norm, dtype=head_dtype, @@ -165,23 +186,11 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor self.activation = activation + self.pre_logits_num_conv_layers = pre_logits_num_conv_layers self.box_head = box_head self.classification_head = classification_head - self.anchor_generator = anchor_generator or AnchorGenerator( - self.bounding_box_format, - min_level=backbone.min_level, - max_level=backbone.max_level, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=4, - ) - # As weights are ported from torch they use encoded format - # as "center_xywh" - self.label_encoder = label_encoder or RetinaNetLabelEncoder( - anchor_generator, - bounding_box_format=bounding_box_format, - encoding_format="center_xywh", - ) + self.anchor_generator = anchor_generator + self.label_encoder = label_encoder self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), bounding_box_format=bounding_box_format, @@ -325,6 +334,7 @@ def get_config(self): { "num_classes": self.num_classes, "use_prediction_head_norm": self.use_prediction_head_norm, + "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers, "bounding_box_format": self.bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator From 4ff8f13b176f3c464a9b7e34b266385d4d9f22a6 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 11 Oct 2024 16:27:42 -0700 Subject: [PATCH 32/35] nit --- .../models/retinanet/retinanet_backbone.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index 72888c5099..6be9b8ffff 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -14,19 +14,31 @@ class RetinaNetBackbone(FeaturePyramidBackbone): network (FPN)to extract multi-scale features for object detection. Args: - image_encoder: `keras.Model`. The backbone model used to extract features - from the input image. - It should have pyramid outputs. - min_level: int. The minimum feature pyramid level. - max_level: int. The maximum feature pyramid level. 
-        use_p5: bool. If True, uses the output of the last layer (`P5` from
-            Feature Pyramid Network) as input for creating coarser convolution
-            layers (`P6`, `P7`). If False, uses the direct input `P5`
-            for creating coarser convolution layers.
-        image_shape: tuple. The shape of the input image.
-        data_format: str. The data format of the input image (channels_first or channels_last).
+        image_encoder: `keras.Model`. The backbone model (e.g., ResNet50,
+            MobileNetV2) used to extract features from the input image.
+            It should have pyramid outputs (i.e., a dictionary mapping level
+            names like `"P2"`, `"P3"`, etc. to their corresponding feature
+            tensors).
+        min_level: int. The minimum level of the feature pyramid (e.g., 3).
+            This determines the coarsest level of features used.
+        max_level: int. The maximum level of the feature pyramid (e.g., 7).
+            This determines the finest level of features used.
+        use_p5: bool. Determines the input source for creating coarser
+            feature pyramid levels. If `True`, the output of the last backbone
+            layer (typically `'P5'` in an FPN) is used as input to create
+            higher-level feature maps (e.g., `'P6'`, `'P7'`) through
+            additional convolutional layers. If `False`, the original `'P5'`
+            feature map from the backbone is directly used as input for
+            creating the coarser levels, bypassing any further processing of
+            `'P5'` within the feature pyramid. Defaults to `False`.
+        use_fpn_batch_norm: bool. Whether to use batch normalization in the
+            feature pyramid network. Defaults to `False`.
+        image_shape: tuple. The shape of the input image (H, W, C).
+            The height and width can be `None` if they are variable.
+        data_format: str. The data format of the input image
+            (channels_first or channels_last).
         dtype: str. The data type of the input image.
-        **kwargs: Additional arguments passed to the base class.
+        **kwargs: Additional keyword arguments passed to the base class.

     Raises:
         ValueError: If `min_level` is greater than `max_level`.
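To tie these pieces together, here is a rough end-to-end construction sketch under the new optional arguments. It assumes `image_encoder` is any keras_hub backbone exposing `pyramid_outputs` (for example a ResNet), and the image shape and class count are illustrative only:

```python
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
from keras_hub.src.models.retinanet.retinanet_object_detector import (
    RetinaNetObjectDetector,
)

# `image_encoder` is assumed: a feature-pyramid backbone whose
# `pyramid_outputs` maps level names such as "P2".."P5" to feature tensors.
backbone = RetinaNetBackbone(
    image_encoder=image_encoder,
    min_level=3,
    max_level=7,
    use_p5=False,  # derive P6/P7 from the backbone's own "P5" feature map
    image_shape=(512, 512, 3),
)

# With `anchor_generator=None` and `label_encoder=None`, the detector falls
# back to the defaults sketched earlier.
detector = RetinaNetObjectDetector(
    backbone=backbone,
    num_classes=20,
    bounding_box_format="xyxy",
)
```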
From c4f752dd0e0f996fdaa6f02b152bc31e6baffcf8 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Sat, 12 Oct 2024 13:02:30 -0700 Subject: [PATCH 33/35] - only consider levels from min level to backbone maxlevel fro feature extraction from image encoder --- keras_hub/src/models/retinanet/retinanet_backbone.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index 6be9b8ffff..b20672b7f7 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -81,7 +81,7 @@ def __init__( inputs=image_encoder.inputs, outputs={ f"P{level}": image_encoder.pyramid_outputs[f"P{level}"] - for level in input_levels + for level in range(min_level, backbone_max_level + 1) }, name="backbone", ) @@ -105,6 +105,7 @@ def __init__( inputs=image_input, outputs=feature_pyramid_outputs, dtype=dtype, + name="retinanet_backbone", **kwargs, ) From bde84b941d3a0fce5d01cffa74ef38698bfe992c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Sat, 12 Oct 2024 13:47:02 -0700 Subject: [PATCH 34/35] nit --- keras_hub/src/models/retinanet/retinanet_backbone.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py index b20672b7f7..b43f30264d 100644 --- a/keras_hub/src/models/retinanet/retinanet_backbone.py +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -105,7 +105,6 @@ def __init__( inputs=image_input, outputs=feature_pyramid_outputs, dtype=dtype, - name="retinanet_backbone", **kwargs, ) From caacc9921771135269aa29467dc696a82f2fa3b6 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 15 Oct 2024 13:21:07 -0700 Subject: [PATCH 35/35] nit --- keras_hub/src/bounding_box/converters.py | 18 +++++++-------- .../src/models/retinanet/prediction_head.py | 22 +++++++++---------- .../models/retinanet/retinanet_backbone.py | 1 + .../retinanet/retinanet_label_encoder.py | 10 ++++----- .../retinanet/retinanet_object_detector.py | 12 +++++----- .../retinanet_object_detector_test.py | 1 - 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/keras_hub/src/bounding_box/converters.py b/keras_hub/src/bounding_box/converters.py index 6371673520..92ef27c15d 100644 --- a/keras_hub/src/bounding_box/converters.py +++ b/keras_hub/src/bounding_box/converters.py @@ -29,13 +29,14 @@ def encode_box_to_deltas( variance=None, image_shape=None, ): - """Encodes bounding boxes to deltas relative to anchors. + """Encodes bounding boxes relative to anchors as deltas. - This function converts bounding boxes to delta format, representing the - difference between the boxes and the provided anchors. The boxes and - anchors are first converted to the specified `encoding_format` - (`center_yxhw` by default) before the delta calculation. This allows for - consistent delta representation regardless of the original box format. + This function calculates the deltas that represent the difference between + bounding boxes and provided anchors. Deltas encode the offsets and scaling + factors to apply to anchors to obtain the target boxes. + + Boxes and anchors are first converted to the specified `encoding_format` + (defaulting to `center_yxhw`) for consistent delta representation. Args: anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the @@ -52,9 +53,8 @@ def encode_box_to_deltas( factors to scale the box deltas. 
            If provided, the calculated deltas are divided by the variance.
            Defaults to None.
        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
-            This might be needed for normalization during format conversion.
-            Defaults to None.
-
+            When using relative bounding box format for `box_format`, the
+            `image_shape` is used for normalization.
     Returns:
         Encoded box deltas. The return type matches the `encode_format`.
diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py
index 1581d901cb..007d4f32bd 100644
--- a/keras_hub/src/models/retinanet/prediction_head.py
+++ b/keras_hub/src/models/retinanet/prediction_head.py
@@ -4,24 +4,24 @@ class PredictionHead(keras.layers.Layer):
-    """The classification/box predictions head.
+    """A head for classification or bounding box regression predictions.

     Args:
-        output_filters: int. Number of convolution filters in the final layer.
+        output_filters: int. The number of convolution filters in the final layer.
            The number of output channels determines the prediction type:
            - **Classification**: `output_filters = num_anchors * num_classes`
                Predicts class probabilities for each anchor.
            - **Bounding Box Regression**:
-                `output_filters = num_anchors * 4`
-                Predicts bounding box offsets (x1, y1, x2, y2) for each anchor.
-        num_filters: int. Number of convolution filters used in base layers.
-            Defaults to `256`.
-        num_conv_layers: int. Number of convolution layers before final layer.
-            Defaults to `4`.
-        use_prior_probability: bool. Whether to use prior probability in the
-            bias initializer for the final convolution layer. Defaults to
-            `False`.
+                `output_filters = num_anchors * 4` Predicts bounding box
+                offsets (x1, y1, x2, y2) for each anchor.
+        num_filters: int. The number of convolution filters to use in the base
+            layer.
+        num_conv_layers: int. The number of convolution layers before the final
+            layer.
+        use_prior_probability: bool. Set to True to use prior probability in the
+            bias initializer for the final convolution layer.
+            Defaults to `False`.
        prior_probability: float. The prior probability value to use for
            initializing the bias. Only used if `use_prior_probability` is
            `True`. Defaults to `0.01`.
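As a quick illustration of the `output_filters` convention documented above, here is a hedged sketch of how the two heads might be instantiated; `num_anchors` and `num_classes` are made-up values, and 256 filters with 4 conv layers mirror the detector's defaults:

```python
from keras_hub.src.models.retinanet.prediction_head import PredictionHead

num_anchors = 9   # num_scales * len(aspect_ratios) for the default anchors
num_classes = 20  # illustrative only

# Box regression head: 4 offsets per anchor at every spatial location.
box_head = PredictionHead(
    output_filters=num_anchors * 4,
    num_filters=256,
    num_conv_layers=4,
)

# Classification head: one logit per class per anchor. The
# `use_prior_probability`/`prior_probability` arguments control the
# focal-loss style bias initialization described above.
classification_head = PredictionHead(
    output_filters=num_anchors * num_classes,
    num_filters=256,
    num_conv_layers=4,
)
```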
diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py
index b43f30264d..c6ebff9ef2 100644
--- a/keras_hub/src/models/retinanet/retinanet_backbone.py
+++ b/keras_hub/src/models/retinanet/retinanet_backbone.py
@@ -58,6 +58,7 @@ def __init__(
         **kwargs,
     ):
+        # === Layers ===
         if min_level > max_level:
             raise ValueError(
                 f"Minimum level ({min_level}) must be less than or equal to "
diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py
index ad32c45efc..0c6a972d17 100644
--- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py
+++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py
@@ -178,7 +178,7 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape):
             matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4)
         )
-        box_target = encode_box_to_deltas(
+        box_targets = encode_box_to_deltas(
             anchors=anchor_boxes,
             boxes=matched_gt_boxes,
             anchor_format=self.bounding_box_format,
@@ -191,16 +191,16 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape):
         matched_gt_cls_ids = tensor_utils.target_gather(
             gt_classes, matched_gt_idx
         )
-        cls_target = ops.where(
+        class_targets = ops.where(
             ops.not_equal(positive_mask, 1.0),
             self.background_class,
             matched_gt_cls_ids,
         )
-        cls_target = ops.where(
-            ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target
+        class_targets = ops.where(
+            ops.equal(ignore_mask, 1.0), self.ignore_class, class_targets
         )
         label = ops.concatenate(
-            [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1
+            [box_targets, ops.cast(class_targets, box_targets.dtype)], axis=-1
         )
         # In the case that a box in the corner of an image matches with an all
diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py
index cd80405213..b55c0f69de 100644
--- a/keras_hub/src/models/retinanet/retinanet_object_detector.py
+++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py
@@ -45,8 +45,8 @@ class RetinaNetObjectDetector(ImageObjectDetector):
            arguments.
         num_classes: int. The number of object classes to be detected.
         bounding_box_format: str. Dataset bounding box format (e.g., "xyxy",
-            "yxyx").
-            Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
+            "yxyx"). For the supported formats, refer to
+            TODO: https://github.com/keras-team/keras-hub/issues/1907
         label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes
             ground truth boxes and classes into training targets. It matches
             ground truth boxes to anchors based on IoU and encodes box
@@ -68,10 +68,10 @@ class for details. If None, a default encoder is created with
             the convolution layers in the prediction heads. Defaults to `False`.
         classification_head_prior_probability: float. Prior probability for the
             classification head (used for focal loss). Defaults to 0.01.
-        pre_logits_num_conv_layers: int. Number of convolutional layers in the
-            head before the logits layer. These convolutional layers are applied
-            before the final linear layer (logits) that produces the output
-            predictions (bounding box regressions, classification scores).
+        pre_logits_num_conv_layers: int. The number of convolutional layers in
+            the head before the logits layer. These convolutional layers are
+            applied before the final linear layer (logits) that produces the
+            output predictions (bounding box regressions, classification scores).
         preprocessor: Optional.
An instance of `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor. Handles image preprocessing before feeding into the backbone. diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 1592595bd3..4441d2765b 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -79,7 +79,6 @@ def setUp(self): self.train_data = (self.images, self.labels) - @pytest.mark.large def test_detection_basics(self): self.run_task_test( cls=RetinaNetObjectDetector,