Skip to content

Commit

Permalink
init as layers for anchor generator and label encoder and as one more…
Browse files Browse the repository at this point in the history
… arg for prediction head configuration
  • Loading branch information
sineeli committed Oct 11, 2024
1 parent 2414f00 commit 064c971
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions keras_hub/src/models/retinanet/retinanet_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class for details. If None, a default encoder is created with
the convolution layers in the prediction heads. Defaults to `False`.
classification_head_prior_probability: float. Prior probability for the
classification head (used for focal loss). Defaults to 0.01.
pre_logits_num_conv_layers: int. Number of convolutional layers in the
head before the logits layer. These convolutional layers are applied
before the final linear layer (logits) that produces the output
predictions (bounding box regressions, classification scores).
preprocessor: Optional. An instance of
`RetinaNetObjectDetectorPreprocessor`or a custom preprocessor.
Handles image preprocessing before feeding into the backbone.
Expand All @@ -94,6 +98,7 @@ def __init__(
label_encoder=None,
use_prediction_head_norm=False,
classification_head_prior_probability=0.01,
pre_logits_num_conv_layers=4,
preprocessor=None,
activation=None,
dtype=None,
Expand All @@ -104,9 +109,25 @@ def __init__(
image_input = keras.layers.Input(backbone.image_shape, name="images")
head_dtype = dtype or backbone.dtype_policy

anchor_generator = anchor_generator or AnchorGenerator(
bounding_box_format,
min_level=backbone.min_level,
max_level=backbone.max_level,
num_scales=3,
aspect_ratios=[0.5, 1.0, 2.0],
anchor_size=4,
)
# As weights are ported from torch they use encoded format
# as "center_xywh"
label_encoder = label_encoder or RetinaNetLabelEncoder(
anchor_generator,
bounding_box_format=bounding_box_format,
encoding_format="center_xywh",
)

box_head = PredictionHead(
output_filters=anchor_generator.num_base_anchors * 4,
num_conv_layers=4,
num_conv_layers=pre_logits_num_conv_layers,
num_filters=256,
use_group_norm=use_prediction_head_norm,
use_prior_probability=True,
Expand All @@ -116,7 +137,7 @@ def __init__(
)
classification_head = PredictionHead(
output_filters=anchor_generator.num_base_anchors * num_classes,
num_conv_layers=4,
num_conv_layers=pre_logits_num_conv_layers,
num_filters=256,
use_group_norm=use_prediction_head_norm,
dtype=head_dtype,
Expand Down Expand Up @@ -165,23 +186,11 @@ def __init__(
self.backbone = backbone
self.preprocessor = preprocessor
self.activation = activation
self.pre_logits_num_conv_layers = pre_logits_num_conv_layers
self.box_head = box_head
self.classification_head = classification_head
self.anchor_generator = anchor_generator or AnchorGenerator(
self.bounding_box_format,
min_level=backbone.min_level,
max_level=backbone.max_level,
num_scales=3,
aspect_ratios=[0.5, 1.0, 2.0],
anchor_size=4,
)
# As weights are ported from torch they use encoded format
# as "center_xywh"
self.label_encoder = label_encoder or RetinaNetLabelEncoder(
anchor_generator,
bounding_box_format=bounding_box_format,
encoding_format="center_xywh",
)
self.anchor_generator = anchor_generator
self.label_encoder = label_encoder
self._prediction_decoder = prediction_decoder or NonMaxSuppression(
from_logits=(activation != keras.activations.sigmoid),
bounding_box_format=bounding_box_format,
Expand Down Expand Up @@ -325,6 +334,7 @@ def get_config(self):
{
"num_classes": self.num_classes,
"use_prediction_head_norm": self.use_prediction_head_norm,
"pre_logits_num_conv_layers": self.pre_logits_num_conv_layers,
"bounding_box_format": self.bounding_box_format,
"anchor_generator": keras.layers.serialize(
self.anchor_generator
Expand Down

0 comments on commit 064c971

Please sign in to comment.