From 20ab713d2a54bd669727fafddbe9ff6d1ca74329 Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Mon, 23 Sep 2024 13:47:31 +0000 Subject: [PATCH] Fix missing image_shape in YOLOv8 prediction_decoder --- .../src/models/object_detection/yolo_v8/yolo_v8_detector.py | 5 ++++- .../yolo_v8_segmentation/yolo_v8_segmentation.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector.py index 8347e0fc7e..814de64709 100644 --- a/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector.py +++ b/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector.py @@ -593,6 +593,7 @@ def decode_predictions( boxes = pred["boxes"] scores = pred["classes"] + image_shape = tuple(images[0].shape) boxes = decode_regression_to_boxes(boxes) anchor_points, stride_tensor = get_anchors(image_shape=images.shape[1:]) @@ -606,7 +607,9 @@ def decode_predictions( images=images, ) - return self.prediction_decoder(box_preds, scores) + return self.prediction_decoder( + box_preds, scores, image_shape=image_shape + ) def predict_step(self, *args): outputs = super().predict_step(*args) diff --git a/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py index dcf625df0a..ef4aef12df 100644 --- a/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py +++ b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py @@ -729,6 +729,8 @@ def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs): def decode_predictions(self, pred, images): boxes = pred["boxes"] scores = pred["classes"] + + image_shape = tuple(images[0].shape) boxes = decode_regression_to_boxes(boxes) anchor_points, stride_tensor = get_anchors(image_shape=images.shape[1:]) @@ -742,7 +744,9 @@ def decode_predictions(self, pred, images): images=images, ) - return self.prediction_decoder(box_preds, scores) + return self.prediction_decoder( + box_preds, scores, image_shape=image_shape + ) def predict_step(self, *args): outputs = super().predict_step(*args)