add video support to segment-anything-2 pipeline #181

Draft: wants to merge 1 commit into base: main
83 changes: 75 additions & 8 deletions runner/app/pipelines/segment_anything_2.py
@@ -1,14 +1,21 @@
import logging
import os
import shutil
import tempfile
from typing import List, Optional, Tuple

import PIL
from PIL import Image
from fastapi import UploadFile
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

from sam2.sam2_video_predictor import SAM2VideoPredictor
import subprocess

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

@@ -19,21 +26,81 @@ def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
if torch_device.type == "cuda":
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
elif torch_device.type == "mps":
logger.warning(
"Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS. "
"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
)

self.tm = SAM2ImagePredictor.from_pretrained(
model_id=model_id,
device=torch_device,
**kwargs,
)

self.tm_vid = SAM2VideoPredictor.from_pretrained(
model_id=model_id,
device=torch_device,
)

def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
try:
self.tm.set_image(image)
prediction = self.tm.predict(**kwargs)
except Exception as e:
raise InferenceError(original_exception=e)
self, media_file: UploadFile, media_type: str, **kwargs
) -> Tuple[List[UploadFile], str, List[Optional[bool]]]:
if media_type == "image":
try:
image = Image.open(media_file.file).convert("RGB")
self.tm.set_image(image)
prediction = self.tm.predict(**kwargs)
except Exception as e:
raise InferenceError(original_exception=e)
elif media_type == "video":
try:
temp_dir = tempfile.mkdtemp()
# TODO: Fix the file type dependency, try passing to ffmpeg without saving to file
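# A possible direction for this TODO (untested sketch): ffmpeg can read the
# upload from stdin, avoiding the temp file entirely, e.g.
#   subprocess.run(
#       ["ffmpeg", "-i", "pipe:0", "-q:v", "2", "-start_number", "0", output_pattern],
#       input=media_file.file.read(), check=True,
#   )
# Caveat: MP4 files with a trailing moov atom cannot be demuxed from a pipe.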
video_path = f"{temp_dir}/input.mp4"
with open(video_path, "wb") as video_file:
video_file.write(media_file.file.read())

# Run ffmpeg command to extract frames from video
frame_dir = tempfile.mkdtemp()
output_pattern = f"{frame_dir}/%05d.jpg"
# Use an argument list rather than shell=True to avoid quoting issues
ffmpeg_command = ["ffmpeg", "-i", video_path, "-q:v", "2", "-start_number", "0", output_pattern]
subprocess.run(ffmpeg_command, check=True)
shutil.rmtree(temp_dir)

# Limit to the first 500 frames to avoid running out of memory
frame_files = sorted(
[f for f in os.listdir(frame_dir) if f.endswith('.jpg')]
)
for frame_file in frame_files[500:]:
os.remove(os.path.join(frame_dir, frame_file))

inference_state = self.tm_vid.init_state(video_path=frame_dir)
shutil.rmtree(frame_dir)

_, out_obj_ids, out_mask_logits = self.tm_vid.add_new_points_or_box(
inference_state,
frame_idx=kwargs.get('frame_idx', None),
obj_id=1,
points=kwargs.get('points', None),
labels=kwargs.get('labels', None),
)

for out_frame_idx, out_obj_ids, out_mask_logits in self.tm_vid.propagate_in_video(inference_state):
Collaborator:

I have a feeling we should return the full triple instead of creating a video segment, leaving post-processing to the consumer of the API (though I recognize this is a good quick way to validate the sanity of the mask outputs).

Collaborator Author:

Thanks for the guidance on this. I see how we can just return the results of self.tm_vid.propagate_in_video(inference_state).
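A minimal sketch of that approach (identifiers are taken from the diff; collecting into a list rather than streaming, and the per-frame payload shape, are assumptions):

    # Return the raw per-frame triples and leave post-processing to the API consumer.
    results = []
    for out_frame_idx, out_obj_ids, out_mask_logits in self.tm_vid.propagate_in_video(inference_state):
        results.append({
            "frame_idx": out_frame_idx,
            "obj_ids": out_obj_ids,
            "masks": (out_mask_logits > 0.0).cpu().numpy(),
        })
    return results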

Collaborator Author (@eliteprox, Sep 10, 2024):

> I have a feeling we should return the full triple instead of creating a video segment, leaving post-processing to the consumer of the API (though I recognize this is a good quick way to validate the sanity of the mask outputs).

I had some issues trying to return the correct values. I added frame index as an input parameter; normally propagate_in_video loops, returning results for each frame from the given frame index to the end of the video, but now it should return only a single frame. The data doesn't look correct, though. Can you take a look? @pschroedl

return {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
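One way to make the loop yield exactly the requested frame, assuming the pinned sam2 version exposes propagate_in_video's start_frame_idx and max_frame_num_to_track parameters (they exist upstream, but worth verifying against the version in use):

    # Sketch: propagate only the requested frame.
    for out_frame_idx, out_obj_ids, out_mask_logits in self.tm_vid.propagate_in_video(
        inference_state,
        start_frame_idx=kwargs.get("frame_idx", 0),
        max_frame_num_to_track=0,  # do not track past the start frame
    ):
        return {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }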

except Exception as e:
raise InferenceError(original_exception=e)

return prediction
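For context, a hypothetical call into the updated pipeline for video input (the class name and the UploadFile construction are illustrative, not taken from the diff):

    from fastapi import UploadFile

    pipeline = SegmentAnything2("facebook/sam2-hiera-large")  # hypothetical class name
    with open("input.mp4", "rb") as f:
        masks_by_obj_id = pipeline(
            UploadFile(filename="input.mp4", file=f),
            media_type="video",
            frame_idx=0,
            points=[[210, 350]],  # illustrative click, (x, y) pixels
            labels=[1],  # 1 marks a foreground point
        )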

75 changes: 56 additions & 19 deletions runner/app/routes/segment_anything_2.py
@@ -45,7 +45,7 @@
include_in_schema=False,
)
async def segment_anything_2(
image: Annotated[
media_file: Annotated[
UploadFile, File(description="Media file to segment (image or video).")
],
model_id: Annotated[
@@ -112,6 +112,10 @@ def segment_anything_2(
)
),
] = True,
frame_idx: Annotated[
int,
Form(description="Frame index to segment (required for video input only).")
] = -1,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
@@ -144,18 +148,51 @@
content=http_error(str(e)),
)

supported_video_types = ["video/mp4"]
supported_image_types = ["image/jpeg", "image/png", "image/jpg"]

try:
image = Image.open(image.file).convert("RGB")
masks, scores, low_res_mask_logits = pipeline(
image,
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
if media_file.content_type in supported_image_types:
masks, scores, low_res_mask_logits = pipeline(
media_file,
media_type="image",
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}

elif media_file.content_type in supported_video_types:
low_res_mask_logits = pipeline(
media_file,
media_type="video",
frame_idx=frame_idx,
points=point_coords,
labels=point_labels,
)

# NOTE: serialization of the video outputs is still WIP (see the review
# discussion above); masks and scores are not populated for video yet.
return {
"masks": "",
"logits": str(np.array(low_res_mask_logits)),
"scores": "",
}
else:
raise InferenceError(f"Unsupported media type: {media_file.content_type}")

except Exception as e:
logger.error(f"Segment Anything 2 error: {e}")
logger.exception(e)
@@ -170,10 +207,10 @@ async def segment_anything_2(
content=http_error("Segment Anything 2 error"),
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}
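For reference, a hypothetical request against the updated endpoint (host, path, token, model ID, and point values are all illustrative):

    import requests

    with open("input.mp4", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/segment-anything-2",  # illustrative host/path
            headers={"Authorization": "Bearer <token>"},
            files={"media_file": ("input.mp4", f, "video/mp4")},
            data={
                "model_id": "facebook/sam2-hiera-large",
                "point_coords": "[[210, 350]]",
                "point_labels": "[1]",
                "frame_idx": 0,
            },
        )
    print(resp.json()["logits"])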
11 changes: 8 additions & 3 deletions runner/gateway.openapi.yaml
@@ -438,10 +438,10 @@ components:
title: Body_image_to_video_image_to_video_post
Body_segment_anything_2_segment_anything_2_post:
properties:
image:
media_file:
type: string
format: binary
title: Image
title: Media File
description: Media file to segment (image or video).
model_id:
type: string
@@ -486,9 +486,14 @@
description: If true, the point coordinates will be normalized to the range
[0,1], with point_coords expected to be with respect to image dimensions.
default: true
frame_idx:
type: integer
title: Frame Idx
description: Frame index to segment (required for video input only).
default: -1
type: object
required:
- image
- media_file
- model_id
title: Body_segment_anything_2_segment_anything_2_post
Body_upscale_upscale_post:
11 changes: 8 additions & 3 deletions runner/openapi.yaml
@@ -446,10 +446,10 @@ components:
title: Body_image_to_video_image_to_video_post
Body_segment_anything_2_segment_anything_2_post:
properties:
image:
media_file:
type: string
format: binary
title: Image
title: Media File
description: Media file to segment (image or video).
model_id:
type: string
@@ -494,9 +494,14 @@
description: If true, the point coordinates will be normalized to the range
[0,1], with point_coords expected to be with respect to image dimensions.
default: true
frame_idx:
type: integer
title: Frame Idx
description: Frame index to segment (required for video input only).
default: -1
type: object
required:
- image
- media_file
title: Body_segment_anything_2_segment_anything_2_post
Body_upscale_upscale_post:
properties:
110 changes: 57 additions & 53 deletions worker/runner.gen.go

Some generated files are not rendered by default.