From 4b60e69c2c868da0096e163a5d752c2fef6c777a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 23 Oct 2024 03:23:59 +0000 Subject: [PATCH 1/8] add fast image processor rtdetr --- src/transformers/__init__.py | 4 +- .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/rt_detr/__init__.py | 2 + .../rt_detr/image_processing_rt_detr_fast.py | 856 ++++++++++++++++++ .../utils/dummy_vision_objects.py | 7 + .../rt_detr/test_image_processing_rt_detr.py | 374 ++++---- tests/test_image_processing_common.py | 2 + 7 files changed, 1061 insertions(+), 186 deletions(-) create mode 100644 src/transformers/models/rt_detr/image_processing_rt_detr_fast.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cc8b07395024a8..e6789c77fb825a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1228,7 +1228,7 @@ _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) - _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) + _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor", "RTDetrImageProcessorFast"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) @@ -6152,7 +6152,7 @@ ) from .models.pvt import PvtImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor - from .models.rt_detr import RTDetrImageProcessor + from .models.rt_detr import RTDetrImageProcessor, RTDetrImageProcessorFast from .models.sam import SamImageProcessor from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index d181afeb2d4d0d..5698abe15c8029 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -123,7 +123,7 @@ ("qwen2_vl", ("Qwen2VLImageProcessor",)), ("regnet", ("ConvNextImageProcessor",)), ("resnet", ("ConvNextImageProcessor",)), - ("rt_detr", "RTDetrImageProcessor"), + ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), diff --git a/src/transformers/models/rt_detr/__init__.py b/src/transformers/models/rt_detr/__init__.py index 94a428c66685a6..52453f38b2c4f4 100644 --- a/src/transformers/models/rt_detr/__init__.py +++ b/src/transformers/models/rt_detr/__init__.py @@ -26,6 +26,7 @@ pass else: _import_structure["image_processing_rt_detr"] = ["RTDetrImageProcessor"] + _import_structure["image_processing_rt_detr_fast"] = ["RTDetrImageProcessorFast"] try: if not is_torch_available(): @@ -55,6 +56,7 @@ pass else: from .image_processing_rt_detr import RTDetrImageProcessor + from .image_processing_rt_detr_fast import RTDetrImageProcessorFast try: if not is_torch_available(): diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py new file mode 100644 index 00000000000000..a15ba17cab71de --- /dev/null +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -0,0 +1,856 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for RT-DETR.""" + +import functools +import pathlib +from typing import Any, Dict, List, Optional, Tuple, Union + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_transforms import ( + center_to_corners_format, + corners_to_center_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_size, + get_image_type, + infer_channel_dimension_format, + make_list_of_images, + pil_torch_interpolation_mapping, + validate_annotations, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, + requires_backends, +) +from .image_processing_rt_detr import ( + get_size_with_aspect_ratio, + max_across_indices, +) + + +if is_torch_available(): + import torch + +if is_vision_available(): + pass + + +if is_torchvision_available(): + from ...image_utils import pil_torch_interpolation_mapping + + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) + + +# Copied from transformers.models.detr.image_processing_detr_fast.get_image_size_for_max_height_width +def get_image_size_for_max_height_width( + image_size: Tuple[int, int], + max_height: int, + max_width: int, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + image_size (`Tuple[int, int]`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + """ + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +# Copied from transformers.models.detr.image_processing_detr_fast.safe_squeeze +def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: + """ + Squeezes a tensor, but only if the axis specified has dim 1. + """ + if axis is None: + return tensor.squeeze() + + try: + return tensor.squeeze(axis=axis) + except ValueError: + return tensor + + +# Copied from transformers.models.detr.image_processing_detr_fast.get_max_height_width +def get_max_height_width(images: List[torch.Tensor]) -> Tuple[int]: + """ + Get the maximum height and width across all images in a batch. + """ + + _, max_height, max_width = max_across_indices([img.shape for img in images]) + + return (max_height, max_width) + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RTDETR. + """ + image_height, image_width = image.size()[-2:] + + image_id = target["image_id"] + image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) + + # for conversion to coco api + area = torch.as_tensor([obj["area"] for obj in annotations], dtype=torch.float32, device=image.device) + iscrowd = torch.as_tensor( + [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=torch.int64, device=image.device + ) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = torch.as_tensor( + [int(image_height), int(image_width)], dtype=torch.int64, device=image.device + ) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +class RTDetrImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast RT-DETR DETR image processor. + + Args: + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `False`): + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `False`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: Union[PILImageResampling, F.InterpolationMode] = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = False, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + do_convert_annotations: bool = True, + do_pad: bool = False, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> None: + size = size if size is not None else {"height": 640, "width": 640} + size = get_size_dict(size, default_to_square=False) + + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + + def prepare_annotation( + self, + image: torch.Tensor, + target: Dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: bool = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Dict: + """ + Prepare an annotation for feeding into RTDETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast.resize + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, + **kwargs, + ) -> torch.Tensor: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use if resizing the image. + """ + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. + new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size["shortest_edge"], + size["longest_edge"], + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"]) + elif size.height and size.width: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + + image = F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + return image + + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast.resize_annotation + def resize_annotation( + self, + annotation: Dict[str, Any], + orig_size: Tuple[int, int], + target_size: Tuple[int, int], + threshold: float = 0.5, + interpolation: F.InterpolationMode = F.InterpolationMode.NEAREST, + ): + """ + Resizes an annotation to a target size. + + Args: + annotation (`Dict[str, Any]`): + The annotation dictionary. + orig_size (`Tuple[int, int]`): + The original size of the input image. + target_size (`Tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`InterpolationMode`, defaults to `InterpolationMode.NEAREST`): + The resampling filter to use when resizing the masks. + """ + ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)] + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device + ) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks] + masks = torch.stack(masks).to(torch.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast.normalize_annotation + def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= torch.as_tensor( + [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device + ) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast._update_annotation_for_padded_image + def _update_annotation_for_padded_image( + self, + annotation: Dict, + input_image_size: Tuple[int, int], + output_image_size: Tuple[int, int], + padding, + update_bboxes, + ) -> Dict: + """ + Update the annotation for a padded image. + """ + new_annotation = {} + new_annotation["size"] = output_image_size + ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size)) + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = F.pad( + masks, + padding, + fill=0, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast.pad + def pad( + self, + image: torch.Tensor, + padded_size: Tuple[int, int], + annotation: Optional[Dict[str, Any]] = None, + update_bboxes: bool = True, + fill: int = 0, + ): + original_size = image.size()[-2:] + padding_bottom = padded_size[0] - original_size[0] + padding_right = padded_size[1] - original_size[1] + if padding_bottom < 0 or padding_right < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the padded size is larger than the " + f"original size. Got padded size: {padded_size}, original size: {original_size}." + ) + if original_size != padded_size: + padding = [0, 0, padding_right, padding_bottom] + image = F.pad(image, padding, fill=fill) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, original_size, padded_size, padding, update_bboxes + ) + + # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device) + pixel_mask[: original_size[0], : original_size[1]] = 1 + + return image, pixel_mask, annotation + + @functools.lru_cache(maxsize=1) + # Copied from transformers.models.detr.image_processing_detr_fast.DetrImageProcessorFast._validate_input_arguments + def _validate_input_arguments( + self, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, List[float]], + image_std: Union[float, List[float]], + do_resize: bool, + size: Dict[str, int], + resample: "PILImageResampling", + data_format: Union[str, ChannelDimension], + return_tensors: Union[TensorType, str], + ): + if return_tensors != "pt": + raise ValueError("Only returning PyTorch tensors is currently supported.") + + if data_format != ChannelDimension.FIRST: + raise ValueError("Only channel first data format is currently supported.") + + if do_resize and None in (size, resample): + raise ValueError("Size and resample must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and None in (image_mean, image_std): + raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.") + + @filter_out_non_signature_kwargs(extra=["device"]) + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + return_segmentation_masks: bool = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[Union[PILImageResampling, F.InterpolationMode]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `List[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. + return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=True) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + return_tensors = "pt" if return_tensors is None else return_tensors + device = kwargs.pop("device", None) + + # Make hashable for cache + size = SizeDict(**size) + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + images = make_list_of_images(images) + image_type = get_image_type(images[0]) + + if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: + raise ValueError(f"Unsupported input image type {image_type}") + + self._validate_input_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + return_tensors=return_tensors, + data_format=data_format, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + data = {} + if image_type == ImageType.PIL: + images = [F.pil_to_tensor(image) for image in images] + elif image_type == ImageType.NUMPY: + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + images = [torch.from_numpy(image).contiguous() for image in images] + + if device is not None: + images = [image.to(device) for image in images] + + # We assume that all images have the same channel dimension format. + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + if input_data_format == ChannelDimension.LAST: + images = [image.permute(2, 0, 1).contiguous() for image in images] + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + if do_resize: + if isinstance(resample, (PILImageResampling, int)): + interpolation = pil_torch_interpolation_mapping[resample] + else: + interpolation = resample + resized_images = [self.resize(image, size=size, interpolation=interpolation) for image in images] + if annotations is not None: + for i, (image, target) in enumerate(zip(resized_images, annotations)): + annotations[i] = self.resize_annotation( + target, + orig_size=images[i].size()[-2:], + target_size=image.size()[-2:], + ) + images = resized_images + del resized_images + + if do_rescale and do_normalize: + # fused rescale and normalize + new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) + new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) + images = [F.normalize(image.to(dtype=torch.float32), new_mean, new_std) for image in images] + elif do_rescale: + images = [image * rescale_factor for image in images] + elif do_normalize: + images = [F.normalize(image, image_mean, image_std) for image in images] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + pixel_masks = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + if padded_size == image.size()[-2:]: + padded_images.append(image) + pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) + padded_annotations.append(annotation) + continue + padded_image, pixel_mask, padded_annotation = self.pad( + image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations + ) + padded_images.append(padded_image) + pixel_masks.append(pixel_mask) + padded_annotations.append(padded_annotation) + images = padded_images + if annotations is not None: + annotations = padded_annotations + del padded_images, padded_annotations + data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) + + data.update({"pixel_values": torch.stack(images, dim=0)}) + encoded_inputs = BatchFeature(data, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + return encoded_inputs + + # Copied from transformers.models.rt_detr.image_processing_rt_detr.RTDetrImageProcessor.post_process_object_detection + def post_process_object_detection( + self, + outputs, + threshold: float = 0.5, + target_sizes: Union[TensorType, List[Tuple]] = None, + use_focal_loss: bool = True, + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + use_focal_loss (`bool` defaults to `True`): + Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied + to compute the scores of each detection, otherwise, a softmax function is used. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + requires_backends(self, ["torch"]) + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + # convert from relative cxcywh to absolute xyxy + boxes = center_to_corners_format(out_bbox) + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if isinstance(target_sizes, List): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + num_top_queries = out_logits.shape[1] + num_classes = out_logits.shape[2] + + if use_focal_loss: + scores = torch.nn.functional.sigmoid(out_logits) + scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1) + labels = index % num_classes + index = index // num_classes + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + else: + scores = torch.nn.functional.softmax(out_logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > num_top_queries: + scores, index = torch.topk(scores, num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index d7f87717ca834a..19cf02a4e85826 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -569,6 +569,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class RTDetrImageProcessorFast(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SamImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/rt_detr/test_image_processing_rt_detr.py b/tests/models/rt_detr/test_image_processing_rt_detr.py index 2a38664d433fea..00a2d57c34051a 100644 --- a/tests/models/rt_detr/test_image_processing_rt_detr.py +++ b/tests/models/rt_detr/test_image_processing_rt_detr.py @@ -17,7 +17,7 @@ import requests from transformers.testing_utils import require_torch, require_vision, slow -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -25,7 +25,7 @@ if is_vision_available(): from PIL import Image - from transformers import RTDetrImageProcessor + from transformers import RTDetrImageProcessor, RTDetrImageProcessorFast if is_torch_available(): import torch @@ -91,6 +91,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = RTDetrImageProcessor if is_vision_available() else None + fast_image_processing_class = RTDetrImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -101,17 +102,19 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "resample")) - self.assertTrue(hasattr(image_processing, "do_rescale")) - self.assertTrue(hasattr(image_processing, "rescale_factor")) - self.assertTrue(hasattr(image_processing, "return_tensors")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "resample")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "return_tensors")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 640, "width": 640}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 640, "width": 640}) def test_valid_coco_detection_annotations(self): # prepare image and target @@ -121,32 +124,33 @@ def test_valid_coco_detection_annotations(self): params = {"image_id": 39769, "annotations": target} - # encode them - image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + for image_processing_class in self.image_processor_list: + # encode them + image_processing = image_processing_class.from_pretrained("PekingU/rtdetr_r50vd") - # legal encodings (single image) - _ = image_processing(images=image, annotations=params, return_tensors="pt") - _ = image_processing(images=image, annotations=[params], return_tensors="pt") + # legal encodings (single image) + _ = image_processing(images=image, annotations=params, return_tensors="pt") + _ = image_processing(images=image, annotations=[params], return_tensors="pt") - # legal encodings (batch of one image) - _ = image_processing(images=[image], annotations=params, return_tensors="pt") - _ = image_processing(images=[image], annotations=[params], return_tensors="pt") + # legal encodings (batch of one image) + _ = image_processing(images=[image], annotations=params, return_tensors="pt") + _ = image_processing(images=[image], annotations=[params], return_tensors="pt") - # legal encoding (batch of more than one image) - n = 5 - _ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt") + # legal encoding (batch of more than one image) + n = 5 + _ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt") - # example of an illegal encoding (missing the 'image_id' key) - with self.assertRaises(ValueError) as e: - image_processing(images=image, annotations={"annotations": target}, return_tensors="pt") + # example of an illegal encoding (missing the 'image_id' key) + with self.assertRaises(ValueError) as e: + image_processing(images=image, annotations={"annotations": target}, return_tensors="pt") - self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations")) + self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations")) - # example of an illegal encoding (unequal lengths of images and annotations) - with self.assertRaises(ValueError) as e: - image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt") + # example of an illegal encoding (unequal lengths of images and annotations) + with self.assertRaises(ValueError) as e: + image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt") - self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.") + self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.") @slow def test_call_pytorch_with_coco_detection_annotations(self): @@ -157,55 +161,57 @@ def test_call_pytorch_with_coco_detection_annotations(self): target = {"image_id": 39769, "annotations": target} - # encode them - image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") - encoding = image_processing(images=image, annotations=target, return_tensors="pt") - - # verify pixel values - expected_shape = torch.Size([1, 3, 640, 640]) - self.assertEqual(encoding["pixel_values"].shape, expected_shape) - - expected_slice = torch.tensor([0.5490, 0.5647, 0.5725]) - self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)) - - # verify area - expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812]) - self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area)) - # verify boxes - expected_boxes_shape = torch.Size([6, 4]) - self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape) - expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215]) - self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)) - # verify image_id - expected_image_id = torch.tensor([39769]) - self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)) - # verify is_crowd - expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0]) - self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)) - # verify class_labels - expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17]) - self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)) - # verify orig_size - expected_orig_size = torch.tensor([480, 640]) - self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)) - # verify size - expected_size = torch.tensor([640, 640]) - self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size)) + for image_processing_class in self.image_processor_list: + # encode them + image_processing = image_processing_class.from_pretrained("PekingU/rtdetr_r50vd") + encoding = image_processing(images=image, annotations=target, return_tensors="pt") + + # verify pixel values + expected_shape = torch.Size([1, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + expected_slice = torch.tensor([0.5490, 0.5647, 0.5725]) + self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)) + + # verify area + expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812]) + self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area)) + # verify boxes + expected_boxes_shape = torch.Size([6, 4]) + self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape) + expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215]) + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)) + # verify image_id + expected_image_id = torch.tensor([39769]) + self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)) + # verify is_crowd + expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0]) + self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)) + # verify class_labels + expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17]) + self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)) + # verify orig_size + expected_orig_size = torch.tensor([480, 640]) + self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)) + # verify size + expected_size = torch.tensor([640, 640]) + self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size)) @slow def test_image_processor_outputs(self): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") - image_processing = self.image_processing_class(**self.image_processor_dict) - encoding = image_processing(images=image, return_tensors="pt") + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + encoding = image_processing(images=image, return_tensors="pt") - # verify pixel values: shape - expected_shape = torch.Size([1, 3, 640, 640]) - self.assertEqual(encoding["pixel_values"].shape, expected_shape) + # verify pixel values: shape + expected_shape = torch.Size([1, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) - # verify pixel values: output values - expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907]) - self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5)) + # verify pixel values: output values + expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907]) + self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5)) def test_multiple_images_processor_outputs(self): images_urls = [ @@ -224,31 +230,32 @@ def test_multiple_images_processor_outputs(self): image = Image.open(requests.get(url, stream=True).raw) images.append(image) - # apply image processing - image_processing = self.image_processing_class(**self.image_processor_dict) - encoding = image_processing(images=images, return_tensors="pt") - - # verify if pixel_values is part of the encoding - self.assertIn("pixel_values", encoding) - - # verify pixel values: shape - expected_shape = torch.Size([8, 3, 640, 640]) - self.assertEqual(encoding["pixel_values"].shape, expected_shape) - - # verify pixel values: output values - expected_slices = torch.tensor( - [ - [0.5333333611488342, 0.5568627715110779, 0.5647059082984924], - [0.5372549295425415, 0.4705882668495178, 0.4274510145187378], - [0.3960784673690796, 0.35686275362968445, 0.3686274588108063], - [0.20784315466880798, 0.1882353127002716, 0.15294118225574493], - [0.364705890417099, 0.364705890417099, 0.3686274588108063], - [0.8078432083129883, 0.8078432083129883, 0.8078432083129883], - [0.4431372880935669, 0.4431372880935669, 0.4431372880935669], - [0.19607844948768616, 0.21176472306251526, 0.3607843220233917], - ] - ) - self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5)) + for image_processing_class in self.image_processor_list: + # apply image processing + image_processing = image_processing_class(**self.image_processor_dict) + encoding = image_processing(images=images, return_tensors="pt") + + # verify if pixel_values is part of the encoding + self.assertIn("pixel_values", encoding) + + # verify pixel values: shape + expected_shape = torch.Size([8, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + # verify pixel values: output values + expected_slices = torch.tensor( + [ + [0.5333333611488342, 0.5568627715110779, 0.5647059082984924], + [0.5372549295425415, 0.4705882668495178, 0.4274510145187378], + [0.3960784673690796, 0.35686275362968445, 0.3686274588108063], + [0.20784315466880798, 0.1882353127002716, 0.15294118225574493], + [0.364705890417099, 0.364705890417099, 0.3686274588108063], + [0.8078432083129883, 0.8078432083129883, 0.8078432083129883], + [0.4431372880935669, 0.4431372880935669, 0.4431372880935669], + [0.19607844948768616, 0.21176472306251526, 0.3607843220233917], + ] + ) + self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5)) @slow def test_batched_coco_detection_annotations(self): @@ -277,89 +284,90 @@ def test_batched_coco_detection_annotations(self): images = [image_0, image_1] annotations = [annotations_0, annotations_1] - image_processing = RTDetrImageProcessor() - encoding = image_processing( - images=images, - annotations=annotations, - return_segmentation_masks=True, - return_tensors="pt", # do_convert_annotations=True - ) - - # Check the pixel values have been padded - postprocessed_height, postprocessed_width = 640, 640 - expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width]) - self.assertEqual(encoding["pixel_values"].shape, expected_shape) - - # Check the bounding boxes have been adjusted for padded images - self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) - self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) - expected_boxes_0 = torch.tensor( - [ - [0.6879, 0.4609, 0.0755, 0.3691], - [0.2118, 0.3359, 0.2601, 0.1566], - [0.5011, 0.5000, 0.9979, 1.0000], - [0.5010, 0.5020, 0.9979, 0.9959], - [0.3284, 0.5944, 0.5884, 0.8112], - [0.8394, 0.5445, 0.3213, 0.9110], - ] - ) - expected_boxes_1 = torch.tensor( - [ - [0.5503, 0.2765, 0.0604, 0.2215], - [0.1695, 0.2016, 0.2080, 0.0940], - [0.5006, 0.4933, 0.9977, 0.9865], - [0.5008, 0.5002, 0.9983, 0.9955], - [0.2627, 0.5456, 0.4707, 0.8646], - [0.7715, 0.4115, 0.4570, 0.7161], - ] - ) - self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3)) - self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3)) - - # Check if do_convert_annotations=False, then the annotations are not converted to centre_x, centre_y, width, height - # format and not in the range [0, 1] - encoding = image_processing( - images=images, - annotations=annotations, - return_segmentation_masks=True, - do_convert_annotations=False, - return_tensors="pt", - ) - self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) - self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) - # Convert to absolute coordinates - unnormalized_boxes_0 = torch.vstack( - [ - expected_boxes_0[:, 0] * postprocessed_width, - expected_boxes_0[:, 1] * postprocessed_height, - expected_boxes_0[:, 2] * postprocessed_width, - expected_boxes_0[:, 3] * postprocessed_height, - ] - ).T - unnormalized_boxes_1 = torch.vstack( - [ - expected_boxes_1[:, 0] * postprocessed_width, - expected_boxes_1[:, 1] * postprocessed_height, - expected_boxes_1[:, 2] * postprocessed_width, - expected_boxes_1[:, 3] * postprocessed_height, - ] - ).T - # Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max - expected_boxes_0 = torch.vstack( - [ - unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2, - unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2, - unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2, - unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2, - ] - ).T - expected_boxes_1 = torch.vstack( - [ - unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2, - unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2, - unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2, - unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2, - ] - ).T - self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1)) - self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1)) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class() + encoding = image_processing( + images=images, + annotations=annotations, + return_segmentation_masks=True, + return_tensors="pt", # do_convert_annotations=True + ) + + # Check the pixel values have been padded + postprocessed_height, postprocessed_width = 640, 640 + expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + # Check the bounding boxes have been adjusted for padded images + self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) + self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) + expected_boxes_0 = torch.tensor( + [ + [0.6879, 0.4609, 0.0755, 0.3691], + [0.2118, 0.3359, 0.2601, 0.1566], + [0.5011, 0.5000, 0.9979, 1.0000], + [0.5010, 0.5020, 0.9979, 0.9959], + [0.3284, 0.5944, 0.5884, 0.8112], + [0.8394, 0.5445, 0.3213, 0.9110], + ] + ) + expected_boxes_1 = torch.tensor( + [ + [0.5503, 0.2765, 0.0604, 0.2215], + [0.1695, 0.2016, 0.2080, 0.0940], + [0.5006, 0.4933, 0.9977, 0.9865], + [0.5008, 0.5002, 0.9983, 0.9955], + [0.2627, 0.5456, 0.4707, 0.8646], + [0.7715, 0.4115, 0.4570, 0.7161], + ] + ) + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3)) + self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3)) + + # Check if do_convert_annotations=False, then the annotations are not converted to centre_x, centre_y, width, height + # format and not in the range [0, 1] + encoding = image_processing( + images=images, + annotations=annotations, + return_segmentation_masks=True, + do_convert_annotations=False, + return_tensors="pt", + ) + self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) + self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) + # Convert to absolute coordinates + unnormalized_boxes_0 = torch.vstack( + [ + expected_boxes_0[:, 0] * postprocessed_width, + expected_boxes_0[:, 1] * postprocessed_height, + expected_boxes_0[:, 2] * postprocessed_width, + expected_boxes_0[:, 3] * postprocessed_height, + ] + ).T + unnormalized_boxes_1 = torch.vstack( + [ + expected_boxes_1[:, 0] * postprocessed_width, + expected_boxes_1[:, 1] * postprocessed_height, + expected_boxes_1[:, 2] * postprocessed_width, + expected_boxes_1[:, 3] * postprocessed_height, + ] + ).T + # Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max + expected_boxes_0 = torch.vstack( + [ + unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2, + unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2, + unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2, + unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2, + ] + ).T + expected_boxes_1 = torch.vstack( + [ + unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2, + unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2, + unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2, + unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2, + ] + ).T + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1)) + self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1)) diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 7d89b43ce35ba4..3e9579b2e59841 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -283,6 +283,8 @@ def test_save_load_fast_slow_auto(self): image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False) self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict()) + print(image_processor_fast_0.to_dict()) + print(image_processor_fast_1.to_dict()) self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict()) def test_init_without_params(self): From e4c57c3e5828b6c3d1cea9fde870c90303dd03b0 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 23 Oct 2024 17:26:32 +0000 Subject: [PATCH 2/8] add gpu/cpu test and fix docstring --- .../rt_detr/image_processing_rt_detr_fast.py | 4 +- .../models/detr/test_image_processing_detr.py | 4 +- .../rt_detr/test_image_processing_rt_detr.py | 58 ++++++++++++++++++- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index a15ba17cab71de..e0a353c58e3bc2 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -224,7 +224,9 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast): `preprocess` method. Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method. - do_normalize (`bool`, *optional*, defaults to `False`): + do_normalize (`bool`, *optional*, defaults to `False`): + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): Mean values to use when normalizing the image. Can be a single value or a list of values, one for each channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. diff --git a/tests/models/detr/test_image_processing_detr.py b/tests/models/detr/test_image_processing_detr.py index 976b306115b68a..f91c520873668f 100644 --- a/tests/models/detr/test_image_processing_detr.py +++ b/tests/models/detr/test_image_processing_detr.py @@ -677,7 +677,7 @@ def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self): target = {"image_id": 39769, "annotations": target} - processor = self.image_processor_list[1].from_pretrained("facebook/detr-resnet-50") + processor = self.image_processor_list[1]() # 1. run processor on CPU encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu") # 2. run processor on GPU @@ -734,7 +734,7 @@ def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self): masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic") - processor = self.image_processor_list[1].from_pretrained("facebook/detr-resnet-50-panoptic") + processor = self.image_processor_list[1](format="coco_panoptic") # 1. run processor on CPU encoding_cpu = processor( images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device="cpu" diff --git a/tests/models/rt_detr/test_image_processing_rt_detr.py b/tests/models/rt_detr/test_image_processing_rt_detr.py index 00a2d57c34051a..e7bfbae3f9c27a 100644 --- a/tests/models/rt_detr/test_image_processing_rt_detr.py +++ b/tests/models/rt_detr/test_image_processing_rt_detr.py @@ -16,7 +16,7 @@ import requests -from transformers.testing_utils import require_torch, require_vision, slow +from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -371,3 +371,59 @@ def test_batched_coco_detection_annotations(self): ).T self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1)) self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1)) + + @slow + @require_torch_gpu + # Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations + def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self): + # prepare image and target + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f: + target = json.loads(f.read()) + + target = {"image_id": 39769, "annotations": target} + + processor = self.image_processor_list[1]() + # 1. run processor on CPU + encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu") + # 2. run processor on GPU + encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device="cuda") + + # verify pixel values + self.assertEqual(encoding_cpu["pixel_values"].shape, encoding_gpu["pixel_values"].shape) + self.assertTrue( + torch.allclose( + encoding_cpu["pixel_values"][0, 0, 0, :3], + encoding_gpu["pixel_values"][0, 0, 0, :3].to("cpu"), + atol=1e-4, + ) + ) + # verify area + self.assertTrue(torch.allclose(encoding_cpu["labels"][0]["area"], encoding_gpu["labels"][0]["area"].to("cpu"))) + # verify boxes + self.assertEqual(encoding_cpu["labels"][0]["boxes"].shape, encoding_gpu["labels"][0]["boxes"].shape) + self.assertTrue( + torch.allclose( + encoding_cpu["labels"][0]["boxes"][0], encoding_gpu["labels"][0]["boxes"][0].to("cpu"), atol=1e-3 + ) + ) + # verify image_id + self.assertTrue( + torch.allclose(encoding_cpu["labels"][0]["image_id"], encoding_gpu["labels"][0]["image_id"].to("cpu")) + ) + # verify is_crowd + self.assertTrue( + torch.allclose(encoding_cpu["labels"][0]["iscrowd"], encoding_gpu["labels"][0]["iscrowd"].to("cpu")) + ) + # verify class_labels + self.assertTrue( + torch.allclose( + encoding_cpu["labels"][0]["class_labels"], encoding_gpu["labels"][0]["class_labels"].to("cpu") + ) + ) + # verify orig_size + self.assertTrue( + torch.allclose(encoding_cpu["labels"][0]["orig_size"], encoding_gpu["labels"][0]["orig_size"].to("cpu")) + ) + # verify size + self.assertTrue(torch.allclose(encoding_cpu["labels"][0]["size"], encoding_gpu["labels"][0]["size"].to("cpu"))) From bf21b518278df25c18a3ad20e011b1b2858e751a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 23 Oct 2024 17:29:15 +0000 Subject: [PATCH 3/8] remove prints --- tests/test_image_processing_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 3e9579b2e59841..7d89b43ce35ba4 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -283,8 +283,6 @@ def test_save_load_fast_slow_auto(self): image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False) self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict()) - print(image_processor_fast_0.to_dict()) - print(image_processor_fast_1.to_dict()) self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict()) def test_init_without_params(self): From a2c957728e4b5a77342d2a0af2a64016cff9e7ca Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 23 Oct 2024 17:34:31 +0000 Subject: [PATCH 4/8] add to doc --- docs/source/en/model_doc/rt_detr.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/rt_detr.md b/docs/source/en/model_doc/rt_detr.md index 5540266c6215de..8ad220dc4bd113 100644 --- a/docs/source/en/model_doc/rt_detr.md +++ b/docs/source/en/model_doc/rt_detr.md @@ -46,7 +46,7 @@ Initially, an image is processed using a pre-trained convolutional neural networ >>> from PIL import Image >>> from transformers import RTDetrForObjectDetection, RTDetrImageProcessor ->>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' >>> image = Image.open(requests.get(url, stream=True).raw) >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") @@ -95,6 +95,12 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h - preprocess - post_process_object_detection +## RTDetrImageProcessorFast + +[[autodoc]] RTDetrImageProcessorFast + - preprocess + - post_process_object_detection + ## RTDetrModel [[autodoc]] RTDetrModel From 827d1c20fd134e0d7a4a33e6e4b09d5292ae50d3 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 24 Oct 2024 15:49:53 +0000 Subject: [PATCH 5/8] nit docstring --- .../models/rt_detr/image_processing_rt_detr_fast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index e0a353c58e3bc2..47cfde0f4d7e18 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -141,7 +141,7 @@ def prepare_coco_detection_annotation( input_data_format: Optional[Union[ChannelDimension, str]] = None, ): """ - Convert the target in COCO format into the format expected by RTDETR. + Convert the target in COCO format into the format expected by RT-DETR. """ image_height, image_width = image.size()[-2:] @@ -222,8 +222,6 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast): rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the `preprocess` method. - Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the - `preprocess` method. do_normalize (`bool`, *optional*, defaults to `False`): Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method. From 64b244964210542e3914db4375ab0f2a9920aa1c Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 29 Oct 2024 17:53:27 +0000 Subject: [PATCH 6/8] avoid iterating over images/annotations several times --- .../image_processing_utils_fast.py | 67 ++++- .../models/detr/image_processing_detr_fast.py | 208 ++++++---------- .../rt_detr/image_processing_rt_detr.py | 17 +- .../rt_detr/image_processing_rt_detr_fast.py | 234 +++++++----------- 4 files changed, 241 insertions(+), 285 deletions(-) diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index d1a08132d73d89..ff482abbf19dfc 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -15,14 +15,18 @@ import functools from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Tuple from .image_processing_utils import BaseImageProcessor -from .utils.import_utils import is_torchvision_available +from .utils.import_utils import is_torch_available, is_torchvision_available if is_torchvision_available(): from torchvision.transforms import Compose +if is_torch_available(): + import torch + @dataclass(frozen=True) class SizeDict: @@ -66,3 +70,64 @@ def to_dict(self): encoder_dict = super().to_dict() encoder_dict.pop("_transform_params", None) return encoder_dict + + +def get_image_size_for_max_height_width( + image_size: Tuple[int, int], + max_height: int, + max_width: int, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + image_size (`Tuple[int, int]`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + """ + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: + """ + Squeezes a tensor, but only if the axis specified has dim 1. + """ + if axis is None: + return tensor.squeeze() + + try: + return tensor.squeeze(axis=axis) + except ValueError: + return tensor + + +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +def get_max_height_width(images: List[torch.Tensor]) -> Tuple[int]: + """ + Get the maximum height and width across all images in a batch. + """ + + _, max_height, max_width = max_across_indices([img.shape for img in images]) + + return (max_height, max_width) diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py index 0fa1d0ffd9dba9..eadde59e55e475 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -21,7 +21,13 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union from ...image_processing_utils import BatchFeature, get_size_dict -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + SizeDict, + get_image_size_for_max_height_width, + get_max_height_width, + safe_squeeze, +) from ...image_transforms import ( center_to_corners_format, corners_to_center_format, @@ -55,7 +61,6 @@ compute_segments, convert_segmentation_to_rle, get_size_with_aspect_ratio, - max_across_indices, remove_low_and_no_objects, ) @@ -85,60 +90,6 @@ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) -def get_image_size_for_max_height_width( - image_size: Tuple[int, int], - max_height: int, - max_width: int, -) -> Tuple[int, int]: - """ - Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. - Important, even if image_height < max_height and image_width < max_width, the image will be resized - to at least one of the edges be equal to max_height or max_width. - - For example: - - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) - - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) - - Args: - image_size (`Tuple[int, int]`): - The image to resize. - max_height (`int`): - The maximum allowed height. - max_width (`int`): - The maximum allowed width. - """ - height, width = image_size - height_scale = max_height / height - width_scale = max_width / width - min_scale = min(height_scale, width_scale) - new_height = int(height * min_scale) - new_width = int(width * min_scale) - return new_height, new_width - - -def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: - """ - Squeezes a tensor, but only if the axis specified has dim 1. - """ - if axis is None: - return tensor.squeeze() - - try: - return tensor.squeeze(axis=axis) - except ValueError: - return tensor - - -def get_max_height_width(images: List[torch.Tensor]) -> Tuple[int]: - """ - Get the maximum height and width across all images in a batch. - """ - - _, max_height, max_width = max_across_indices([img.shape for img in images]) - - return (max_height, max_width) - - # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33 def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor: """ @@ -191,18 +142,21 @@ def prepare_coco_detection_annotation( # Get all COCO annotations for the given image. annotations = target["annotations"] - annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) - classes = [obj["category_id"] for obj in annotations] classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) - - # for conversion to coco api - area = torch.as_tensor([obj["area"] for obj in annotations], dtype=torch.float32, device=image.device) - iscrowd = torch.as_tensor( - [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=torch.int64, device=image.device - ) - - boxes = [obj["bbox"] for obj in annotations] + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) # guard against no boxes via resizing boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] @@ -211,19 +165,16 @@ def prepare_coco_detection_annotation( keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) - new_target = {} - new_target["image_id"] = image_id - new_target["class_labels"] = classes[keep] - new_target["boxes"] = boxes[keep] - new_target["area"] = area[keep] - new_target["iscrowd"] = iscrowd[keep] - new_target["orig_size"] = torch.as_tensor( - [int(image_height), int(image_width)], dtype=torch.int64, device=image.device - ) + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } - if annotations and "keypoints" in annotations[0]: - keypoints = [obj["keypoints"] for obj in annotations] - # Converting the filtered keypoints list to a numpy array + if keypoints: keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) # Apply the keep mask here to filter the relevant annotations keypoints = keypoints[keep] @@ -911,84 +862,81 @@ def preprocess( if input_data_format == ChannelDimension.LAST: images = [image.permute(2, 0, 1).contiguous() for image in images] - # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) - if annotations is not None: - prepared_images = [] - prepared_annotations = [] - for image, target in zip(images, annotations): - target = self.prepare_annotation( + if do_rescale and do_normalize: + # fused rescale and normalize + new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) + new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) + + processed_images = [] + processed_annotations = [] + pixel_masks = [] # Initialize pixel_masks here + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + annotation = self.prepare_annotation( image, - target, + annotation, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path, input_data_format=input_data_format, ) - prepared_images.append(image) - prepared_annotations.append(target) - images = prepared_images - annotations = prepared_annotations - del prepared_images, prepared_annotations - - if do_resize: - if isinstance(resample, (PILImageResampling, int)): - interpolation = pil_torch_interpolation_mapping[resample] - else: - interpolation = resample - resized_images = [self.resize(image, size=size, interpolation=interpolation) for image in images] - if annotations is not None: - for i, (image, target) in enumerate(zip(resized_images, annotations)): - annotations[i] = self.resize_annotation( - target, - orig_size=images[i].size()[-2:], - target_size=image.size()[-2:], + + if do_resize: + interpolation = ( + pil_torch_interpolation_mapping[resample] + if isinstance(resample, (PILImageResampling, int)) + else resample + ) + resized_image = self.resize(image, size=size, interpolation=interpolation) + if annotations is not None: + annotation = self.resize_annotation( + annotation, + orig_size=image.size()[-2:], + target_size=resized_image.size()[-2:], ) - images = resized_images - del resized_images + image = resized_image - if do_rescale and do_normalize: - # fused rescale and normalize - new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) - new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) - images = [F.normalize(image.to(dtype=torch.float32), new_mean, new_std) for image in images] - elif do_rescale: - images = [image * rescale_factor for image in images] - elif do_normalize: - images = [F.normalize(image, image_mean, image_std) for image in images] - - if do_convert_annotations and annotations is not None: - annotations = [ - self.normalize_annotation(annotation, get_image_size(image, input_data_format)) - for annotation, image in zip(annotations, images) - ] + if do_rescale and do_normalize: + # fused rescale and normalize + image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) + elif do_rescale: + image = image * rescale_factor + elif do_normalize: + image = F.normalize(image, image_mean, image_std) + + if do_convert_annotations and annotations is not None: + annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + + processed_images.append(image) + processed_annotations.append(annotation) + images = processed_images + annotations = processed_annotations if annotations is not None else None if do_pad: - # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + # depends on all resized image shapes so we need another loop if pad_size is not None: padded_size = (pad_size["height"], pad_size["width"]) else: padded_size = get_max_height_width(images) - annotation_list = annotations if annotations is not None else [None] * len(images) padded_images = [] - pixel_masks = [] padded_annotations = [] - for image, annotation in zip(images, annotation_list): + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} if padded_size == image.size()[-2:]: padded_images.append(image) pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) padded_annotations.append(annotation) continue - padded_image, pixel_mask, padded_annotation = self.pad( + image, pixel_mask, annotation = self.pad( image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations ) - padded_images.append(padded_image) + padded_images.append(image) + padded_annotations.append(annotation) pixel_masks.append(pixel_mask) - padded_annotations.append(padded_annotation) images = padded_images - if annotations is not None: - annotations = padded_annotations - del padded_images, padded_annotations + annotations = padded_annotations if annotations is not None else None data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) data.update({"pixel_values": torch.stack(images, dim=0)}) diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr.py b/src/transformers/models/rt_detr/image_processing_rt_detr.py index 44b2702aa634bc..eead5b18693d2f 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -1062,10 +1062,8 @@ def post_process_object_detection( raise ValueError( "Make sure that you pass in as many target sizes as the batch dimension of the logits" ) - if isinstance(target_sizes, List): - img_h = torch.Tensor([i[0] for i in target_sizes]) - img_w = torch.Tensor([i[1] for i in target_sizes]) + img_h, img_w = torch.as_tensor(target_sizes).unbind(1) else: img_h, img_w = target_sizes.unbind(1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) @@ -1089,10 +1087,13 @@ def post_process_object_detection( boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) results = [] - for s, l, b in zip(scores, labels, boxes): - score = s[s > threshold] - label = l[s > threshold] - box = b[s > threshold] - results.append({"scores": score, "labels": label, "boxes": box}) + for score, label, box in zip(scores, labels, boxes): + results.append( + { + "scores": score[score > threshold], + "labels": label[score > threshold], + "boxes": box[score > threshold], + } + ) return results diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index 47cfde0f4d7e18..9f63b5b7ced467 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -19,7 +19,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union from ...image_processing_utils import BatchFeature, get_size_dict -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + SizeDict, + get_image_size_for_max_height_width, + get_max_height_width, + safe_squeeze, +) from ...image_transforms import ( center_to_corners_format, corners_to_center_format, @@ -46,22 +52,17 @@ is_torch_available, is_torchvision_available, is_torchvision_v2_available, - is_vision_available, logging, requires_backends, ) from .image_processing_rt_detr import ( get_size_with_aspect_ratio, - max_across_indices, ) if is_torch_available(): import torch -if is_vision_available(): - pass - if is_torchvision_available(): from ...image_utils import pil_torch_interpolation_mapping @@ -77,63 +78,6 @@ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) -# Copied from transformers.models.detr.image_processing_detr_fast.get_image_size_for_max_height_width -def get_image_size_for_max_height_width( - image_size: Tuple[int, int], - max_height: int, - max_width: int, -) -> Tuple[int, int]: - """ - Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. - Important, even if image_height < max_height and image_width < max_width, the image will be resized - to at least one of the edges be equal to max_height or max_width. - - For example: - - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) - - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) - - Args: - image_size (`Tuple[int, int]`): - The image to resize. - max_height (`int`): - The maximum allowed height. - max_width (`int`): - The maximum allowed width. - """ - height, width = image_size - height_scale = max_height / height - width_scale = max_width / width - min_scale = min(height_scale, width_scale) - new_height = int(height * min_scale) - new_width = int(width * min_scale) - return new_height, new_width - - -# Copied from transformers.models.detr.image_processing_detr_fast.safe_squeeze -def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: - """ - Squeezes a tensor, but only if the axis specified has dim 1. - """ - if axis is None: - return tensor.squeeze() - - try: - return tensor.squeeze(axis=axis) - except ValueError: - return tensor - - -# Copied from transformers.models.detr.image_processing_detr_fast.get_max_height_width -def get_max_height_width(images: List[torch.Tensor]) -> Tuple[int]: - """ - Get the maximum height and width across all images in a batch. - """ - - _, max_height, max_width = max_across_indices([img.shape for img in images]) - - return (max_height, max_width) - - def prepare_coco_detection_annotation( image, target, @@ -150,18 +94,21 @@ def prepare_coco_detection_annotation( # Get all COCO annotations for the given image. annotations = target["annotations"] - annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) - classes = [obj["category_id"] for obj in annotations] classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) - - # for conversion to coco api - area = torch.as_tensor([obj["area"] for obj in annotations], dtype=torch.float32, device=image.device) - iscrowd = torch.as_tensor( - [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=torch.int64, device=image.device - ) - - boxes = [obj["bbox"] for obj in annotations] + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) # guard against no boxes via resizing boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] @@ -170,19 +117,16 @@ def prepare_coco_detection_annotation( keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) - new_target = {} - new_target["image_id"] = image_id - new_target["class_labels"] = classes[keep] - new_target["boxes"] = boxes[keep] - new_target["area"] = area[keep] - new_target["iscrowd"] = iscrowd[keep] - new_target["orig_size"] = torch.as_tensor( - [int(image_height), int(image_width)], dtype=torch.int64, device=image.device - ) - - if annotations and "keypoints" in annotations[0]: - keypoints = [obj["keypoints"] for obj in annotations] - # Converting the filtered keypoints list to a numpy array + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } + + if keypoints: keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) # Apply the keep mask here to filter the relevant annotations keypoints = keypoints[keep] @@ -695,84 +639,81 @@ def preprocess( if input_data_format == ChannelDimension.LAST: images = [image.permute(2, 0, 1).contiguous() for image in images] - # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) - if annotations is not None: - prepared_images = [] - prepared_annotations = [] - for image, target in zip(images, annotations): - target = self.prepare_annotation( + if do_rescale and do_normalize: + # fused rescale and normalize + new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) + new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) + + processed_images = [] + processed_annotations = [] + pixel_masks = [] # Initialize pixel_masks here + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + annotation = self.prepare_annotation( image, - target, + annotation, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path, input_data_format=input_data_format, ) - prepared_images.append(image) - prepared_annotations.append(target) - images = prepared_images - annotations = prepared_annotations - del prepared_images, prepared_annotations - - if do_resize: - if isinstance(resample, (PILImageResampling, int)): - interpolation = pil_torch_interpolation_mapping[resample] - else: - interpolation = resample - resized_images = [self.resize(image, size=size, interpolation=interpolation) for image in images] - if annotations is not None: - for i, (image, target) in enumerate(zip(resized_images, annotations)): - annotations[i] = self.resize_annotation( - target, - orig_size=images[i].size()[-2:], - target_size=image.size()[-2:], + + if do_resize: + interpolation = ( + pil_torch_interpolation_mapping[resample] + if isinstance(resample, (PILImageResampling, int)) + else resample + ) + resized_image = self.resize(image, size=size, interpolation=interpolation) + if annotations is not None: + annotation = self.resize_annotation( + annotation, + orig_size=image.size()[-2:], + target_size=resized_image.size()[-2:], ) - images = resized_images - del resized_images + image = resized_image - if do_rescale and do_normalize: - # fused rescale and normalize - new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor) - new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor) - images = [F.normalize(image.to(dtype=torch.float32), new_mean, new_std) for image in images] - elif do_rescale: - images = [image * rescale_factor for image in images] - elif do_normalize: - images = [F.normalize(image, image_mean, image_std) for image in images] - - if do_convert_annotations and annotations is not None: - annotations = [ - self.normalize_annotation(annotation, get_image_size(image, input_data_format)) - for annotation, image in zip(annotations, images) - ] + if do_rescale and do_normalize: + # fused rescale and normalize + image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) + elif do_rescale: + image = image * rescale_factor + elif do_normalize: + image = F.normalize(image, image_mean, image_std) + + if do_convert_annotations and annotations is not None: + annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + + processed_images.append(image) + processed_annotations.append(annotation) + images = processed_images + annotations = processed_annotations if annotations is not None else None if do_pad: - # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + # depends on all resized image shapes so we need another loop if pad_size is not None: padded_size = (pad_size["height"], pad_size["width"]) else: padded_size = get_max_height_width(images) - annotation_list = annotations if annotations is not None else [None] * len(images) padded_images = [] - pixel_masks = [] padded_annotations = [] - for image, annotation in zip(images, annotation_list): + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} if padded_size == image.size()[-2:]: padded_images.append(image) pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) padded_annotations.append(annotation) continue - padded_image, pixel_mask, padded_annotation = self.pad( + image, pixel_mask, annotation = self.pad( image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations ) - padded_images.append(padded_image) + padded_images.append(image) + padded_annotations.append(annotation) pixel_masks.append(pixel_mask) - padded_annotations.append(padded_annotation) images = padded_images - if annotations is not None: - annotations = padded_annotations - del padded_images, padded_annotations + annotations = padded_annotations if annotations is not None else None data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) data.update({"pixel_values": torch.stack(images, dim=0)}) @@ -820,10 +761,8 @@ def post_process_object_detection( raise ValueError( "Make sure that you pass in as many target sizes as the batch dimension of the logits" ) - if isinstance(target_sizes, List): - img_h = torch.Tensor([i[0] for i in target_sizes]) - img_w = torch.Tensor([i[1] for i in target_sizes]) + img_h, img_w = torch.as_tensor(target_sizes).unbind(1) else: img_h, img_w = target_sizes.unbind(1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) @@ -847,10 +786,13 @@ def post_process_object_detection( boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) results = [] - for s, l, b in zip(scores, labels, boxes): - score = s[s > threshold] - label = l[s > threshold] - box = b[s > threshold] - results.append({"scores": score, "labels": label, "boxes": box}) + for score, label, box in zip(scores, labels, boxes): + results.append( + { + "scores": score[score > threshold], + "labels": label[score > threshold], + "boxes": box[score > threshold], + } + ) return results From 29608598086f6e6e6a742340c43bfc0713ad769f Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 29 Oct 2024 18:07:49 +0000 Subject: [PATCH 7/8] change torch typing --- src/transformers/image_processing_utils_fast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index ff482abbf19dfc..3c1be325b7eb30 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -103,7 +103,7 @@ def get_image_size_for_max_height_width( return new_height, new_width -def safe_squeeze(tensor: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: +def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor": """ Squeezes a tensor, but only if the axis specified has dim 1. """ @@ -123,7 +123,7 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]: return [max(values_i) for values_i in zip(*values)] -def get_max_height_width(images: List[torch.Tensor]) -> Tuple[int]: +def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]: """ Get the maximum height and width across all images in a batch. """ From 0b18cf39e5556a3e1c6e42e9caca1c1afb3f5d3d Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 29 Oct 2024 19:26:08 +0000 Subject: [PATCH 8/8] Add image processor fast documentation --- .../source/en/main_classes/image_processor.md | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/docs/source/en/main_classes/image_processor.md b/docs/source/en/main_classes/image_processor.md index 59a78e68214d6d..320916f1ce9421 100644 --- a/docs/source/en/main_classes/image_processor.md +++ b/docs/source/en/main_classes/image_processor.md @@ -18,6 +18,49 @@ rendered properly in your Markdown viewer. An image processor is in charge of preparing input features for vision models and post processing their outputs. This includes transformations such as resizing, normalization, and conversion to PyTorch, TensorFlow, Flax and Numpy tensors. It may also include model specific post-processing such as converting logits to segmentation masks. +Fast image processors are available for a few models and more will be added in the future. They are based on the [torchvision](https://pytorch.org/vision/stable/index.html) library and provide a significant speed-up, especially when processing on GPU. +They have the same API as the base image processors and can be used as drop-in replacements. +To use a fast image processor, you need to install the `torchvision` library, and set the `use_fast` argument to `True` when instantiating the image processor: + +```python +from transformers import AutoImageProcessor + +processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True) +``` + +When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise. + +```python +from torchvision.io import read_image +from transformers import DetrImageProcessorFast + +images = read_image("image.jpg") +processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50") +images_processed = processor(images, return_tensors="pt", device="cuda") +``` + +Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time: + +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU. + ## ImageProcessingMixin