Skip to content

Commit

Permalink
add video support to sam-2 pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
eliteprox committed Sep 10, 2024
1 parent 1455989 commit 3040cbd
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 86 deletions.
83 changes: 75 additions & 8 deletions runner/app/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import logging
import os
import shutil
import tempfile
from typing import List, Optional, Tuple

import PIL
from PIL import Image
from fastapi import UploadFile
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

ImageFile.LOAD_TRUNCATED_IMAGES = True
from sam2.sam2_video_predictor import SAM2VideoPredictor
import subprocess

logger = logging.getLogger(__name__)

Expand All @@ -19,21 +26,81 @@ def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
if torch_device.type == "cuda":
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
elif torch_device.type == "mps":
print(
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS. "
"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
)

self.tm = SAM2ImagePredictor.from_pretrained(
model_id=model_id,
device=torch_device,
**kwargs,
)

self.tm_vid = SAM2VideoPredictor.from_pretrained(
model_id=model_id,
device=torch_device,
)

def __call__(
    self, media_file: UploadFile, media_type: str, **kwargs
) -> Tuple[List[UploadFile], str, List[Optional[bool]]]:
    """Run SAM 2 segmentation on an uploaded image or video.

    Args:
        media_file: Uploaded file whose ``.file`` handle holds the raw bytes.
        media_type: Either ``"image"`` or ``"video"``.
        **kwargs: For images, forwarded unchanged to
            ``SAM2ImagePredictor.predict``. For videos, the optional keys
            ``frame_idx``, ``points`` and ``labels`` are forwarded to
            ``SAM2VideoPredictor.add_new_points_or_box``.

    Returns:
        For images, whatever ``self.tm.predict`` returns. For videos, a dict
        mapping each object id to a boolean mask (``logits > 0.0``) taken
        from the first frame yielded by ``propagate_in_video``.

    Raises:
        InferenceError: Wraps any exception raised while decoding or
            segmenting the media, and is also raised for an unsupported
            ``media_type`` or a video that yields no propagated frames.
    """
    if media_type == "image":
        try:
            image = Image.open(media_file.file).convert("RGB")
            self.tm.set_image(image)
            return self.tm.predict(**kwargs)
        except Exception as e:
            raise InferenceError(original_exception=e) from e

    if media_type == "video":
        # Cap the number of frames handed to the predictor to avoid running
        # out of memory.
        max_frames = 500
        try:
            # TemporaryDirectory guarantees cleanup even when ffmpeg or the
            # predictor raises, so no scratch directories are leaked.
            with tempfile.TemporaryDirectory() as frame_dir:
                with tempfile.TemporaryDirectory() as video_dir:
                    # TODO: Fix the file type dependency, try passing to
                    # ffmpeg without saving to file.
                    video_path = os.path.join(video_dir, "input.mp4")
                    with open(video_path, "wb") as video_file:
                        # Stream the upload to disk instead of reading the
                        # whole video into memory first.
                        shutil.copyfileobj(media_file.file, video_file)

                    # Extract frames with ffmpeg. An argument list (no
                    # shell) keeps the paths from being shell-interpreted.
                    output_pattern = os.path.join(frame_dir, "%05d.jpg")
                    subprocess.run(
                        [
                            "ffmpeg",
                            "-i", video_path,
                            "-q:v", "2",
                            "-start_number", "0",
                            output_pattern,
                        ],
                        check=True,
                    )

                # Keep only the FIRST ``max_frames`` frames; file names are
                # zero-padded, so lexicographic order is frame order.
                frame_files = sorted(
                    f for f in os.listdir(frame_dir) if f.endswith(".jpg")
                )
                for frame_file in frame_files[max_frames:]:
                    os.remove(os.path.join(frame_dir, frame_file))

                # NOTE(review): the frame directory is removed right after
                # this call, as in the original code -- presumably
                # init_state loads the frames eagerly; confirm.
                inference_state = self.tm_vid.init_state(video_path=frame_dir)

            self.tm_vid.add_new_points_or_box(
                inference_state,
                frame_idx=kwargs.get("frame_idx", None),
                obj_id=1,
                points=kwargs.get("points", None),
                labels=kwargs.get("labels", None),
            )

            # NOTE(review): only the first propagated frame is returned,
            # matching the existing behaviour; aggregating every frame looks
            # like the eventual intent -- confirm with the route's consumer.
            for _, out_obj_ids, out_mask_logits in self.tm_vid.propagate_in_video(
                inference_state
            ):
                return {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

            raise ValueError("Video propagation produced no frames")
        except Exception as e:
            raise InferenceError(original_exception=e) from e

    # Previously this fell through to an unbound ``prediction`` variable.
    raise InferenceError(
        original_exception=ValueError(f"Unsupported media type: {media_type}")
    )

Expand Down
75 changes: 56 additions & 19 deletions runner/app/routes/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
include_in_schema=False,
)
async def segment_anything_2(
image: Annotated[
media_file: Annotated[
UploadFile, File(description="Image to segment.", media_type="image/*")
],
model_id: Annotated[
Expand Down Expand Up @@ -112,6 +112,10 @@ async def segment_anything_2(
)
),
] = True,
frame_idx : Annotated[
int,
Form(description="Frame index reference for (required video file input only)")
] = -1,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand Down Expand Up @@ -144,18 +148,51 @@ async def segment_anything_2(
content=http_error(str(e)),
)

supported_video_types = ["video/mp4"]
supported_image_types = ["image/jpeg", "image/png", "image/jpg"]

try:
image = Image.open(image.file).convert("RGB")
masks, scores, low_res_mask_logits = pipeline(
image,
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
if media_file.content_type in supported_image_types:
masks, scores, low_res_mask_logits = pipeline(
media_file,
media_type="image",
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}

elif media_file.content_type in supported_video_types:
low_res_mask_logits = pipeline(
media_file,
media_type="video",
frame_idx=frame_idx,
points=point_coords,
labels=point_labels,
)

sadf = low_res_mask_logits


return {
"masks": str(""),
"logits": str(np.array(low_res_mask_logits)),
"scores": str(""),
}
else:
raise InferenceError(f"Unsupported media type: {media_file.content_type}")

except Exception as e:
logger.error(f"Segment Anything 2 error: {e}")
logger.exception(e)
Expand All @@ -170,10 +207,10 @@ async def segment_anything_2(
content=http_error("Segment Anything 2 error"),
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}
# # Return masks sorted by descending score as string.
# sorted_ind = np.argsort(scores)[::-1]
# return {
# "masks": str(masks[sorted_ind].tolist()),
# "scores": str(scores[sorted_ind].tolist()),
# "logits": str(low_res_mask_logits[sorted_ind].tolist()),
# }
11 changes: 8 additions & 3 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,10 @@ components:
title: Body_image_to_video_image_to_video_post
Body_segment_anything_2_segment_anything_2_post:
properties:
image:
media_file:
type: string
format: binary
title: Image
title: Media File
description: Image to segment.
model_id:
type: string
Expand Down Expand Up @@ -486,9 +486,14 @@ components:
description: If true, the point coordinates will be normalized to the range
[0,1], with point_coords expected to be with respect to image dimensions.
default: true
frame_idx:
type: integer
title: Frame Idx
description: Frame index reference for (required video file input only)
default: -1
type: object
required:
- image
- media_file
- model_id
title: Body_segment_anything_2_segment_anything_2_post
Body_upscale_upscale_post:
Expand Down
11 changes: 8 additions & 3 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,10 @@ components:
title: Body_image_to_video_image_to_video_post
Body_segment_anything_2_segment_anything_2_post:
properties:
image:
media_file:
type: string
format: binary
title: Image
title: Media File
description: Image to segment.
model_id:
type: string
Expand Down Expand Up @@ -494,9 +494,14 @@ components:
description: If true, the point coordinates will be normalized to the range
[0,1], with point_coords expected to be with respect to image dimensions.
default: true
frame_idx:
type: integer
title: Frame Idx
description: Frame index reference for (required video file input only)
default: -1
type: object
required:
- image
- media_file
title: Body_segment_anything_2_segment_anything_2_post
Body_upscale_upscale_post:
properties:
Expand Down
110 changes: 57 additions & 53 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3040cbd

Please sign in to comment.