Skip to content

Commit

Permalink
Updated return type logic for OD detections (#147)
Browse files Browse the repository at this point in the history
* Added logic to return list

* lint fixes
  • Loading branch information
Advitya17 authored Aug 21, 2023
1 parent b6ce3d6 commit 93bdbf8
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions python/ml_wrappers/model/image_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _get_device(device: str) -> str:
:rtype: str
"""
if (device in [member.value for member in Device]
or type(device) == int
or type(device) is int
or device.isdigit()
or device is None):
if device == Device.AUTO.value:
Expand Down Expand Up @@ -525,7 +525,7 @@ def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5):
"""
detections = []
for image in x:
if type(image) == Tensor:
if type(image) is Tensor:
raw_detections = self._model(
image.to(self._device).unsqueeze(0))
else:
Expand All @@ -543,7 +543,10 @@ def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5):

detections.append(image_predictions.detach().cpu().numpy()
.tolist())
return np.array(detections)
try:
return np.array(detections)
except ValueError:
return detections

def predict_proba(self, dataset, iou_threshold=0.1):
"""Predict the output probability using the wrapped model.
Expand Down Expand Up @@ -657,7 +660,10 @@ def predict(self, dataset: pd.DataFrame, iou_threshold: float = 0.5,
detections.append(
image_predictions.detach().cpu().numpy().tolist())

return np.array(detections)
try:
return np.array(detections)
except ValueError:
return detections

def predict_proba(self, dataset: pd.DataFrame,
iou_threshold=0.1) -> np.ndarray:
Expand Down

0 comments on commit 93bdbf8

Please sign in to comment.