diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py index 83a6788587..cd80405213 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -68,6 +68,10 @@ class for details. If None, a default encoder is created with the convolution layers in the prediction heads. Defaults to `False`. classification_head_prior_probability: float. Prior probability for the classification head (used for focal loss). Defaults to 0.01. + pre_logits_num_conv_layers: int. Number of convolutional layers in the + head before the logits layer. These convolutional layers are applied + before the final linear layer (logits) that produces the output + predictions (bounding box regressions, classification scores). preprocessor: Optional. An instance of `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor. Handles image preprocessing before feeding into the backbone. @@ -94,6 +98,7 @@ def __init__( label_encoder=None, use_prediction_head_norm=False, classification_head_prior_probability=0.01, + pre_logits_num_conv_layers=4, preprocessor=None, activation=None, dtype=None, @@ -104,9 +109,25 @@ def __init__( image_input = keras.layers.Input(backbone.image_shape, name="images") head_dtype = dtype or backbone.dtype_policy + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format, + min_level=backbone.min_level, + max_level=backbone.max_level, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4, + ) + # As weights are ported from torch they use encoded format + # as "center_xywh" + label_encoder = label_encoder or RetinaNetLabelEncoder( + anchor_generator, + bounding_box_format=bounding_box_format, + encoding_format="center_xywh", + ) + box_head = PredictionHead( output_filters=anchor_generator.num_base_anchors * 4, - num_conv_layers=4, + num_conv_layers=pre_logits_num_conv_layers, num_filters=256, use_group_norm=use_prediction_head_norm, use_prior_probability=True, @@ -116,7 +137,7 @@ def __init__( ) classification_head = PredictionHead( output_filters=anchor_generator.num_base_anchors * num_classes, - num_conv_layers=4, + num_conv_layers=pre_logits_num_conv_layers, num_filters=256, use_group_norm=use_prediction_head_norm, dtype=head_dtype, @@ -165,23 +186,11 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor self.activation = activation + self.pre_logits_num_conv_layers = pre_logits_num_conv_layers self.box_head = box_head self.classification_head = classification_head - self.anchor_generator = anchor_generator or AnchorGenerator( - self.bounding_box_format, - min_level=backbone.min_level, - max_level=backbone.max_level, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=4, - ) - # As weights are ported from torch they use encoded format - # as "center_xywh" - self.label_encoder = label_encoder or RetinaNetLabelEncoder( - anchor_generator, - bounding_box_format=bounding_box_format, - encoding_format="center_xywh", - ) + self.anchor_generator = anchor_generator + self.label_encoder = label_encoder self._prediction_decoder = prediction_decoder or NonMaxSuppression( from_logits=(activation != keras.activations.sigmoid), bounding_box_format=bounding_box_format, @@ -325,6 +334,7 @@ def get_config(self): { "num_classes": self.num_classes, "use_prediction_head_norm": self.use_prediction_head_norm, + "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers, "bounding_box_format": self.bounding_box_format, "anchor_generator": keras.layers.serialize( self.anchor_generator