diff --git a/pytorch3d/implicitron/dataset/blob_loader.py b/pytorch3d/implicitron/dataset/blob_loader.py new file mode 100644 index 000000000..83f39c78e --- /dev/null +++ b/pytorch3d/implicitron/dataset/blob_loader.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.utils import _get_bbox_from_mask +from pytorch3d.io import IO +from pytorch3d.renderer.cameras import PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds + + +@dataclass +class BlobLoader: + """ + A loader for correctly (according to setup) loading blobs for FrameData. + Beware that modification done in place + + Args: + dataset_root: The root folder of the dataset; all the paths in jsons are + specified relative to this root (but not json paths themselves). + load_images: Enable loading the frame RGB data. + load_depths: Enable loading the frame depth maps. + load_depth_masks: Enable loading the frame depth map masks denoting the + depth values used for evaluation (the points consistent across views). + load_masks: Enable loading frame foreground masks. + load_point_clouds: Enable loading sequence-level point clouds. + max_points: Cap on the number of loaded points in the point cloud; + if reached, they are randomly sampled without replacement. + mask_images: Whether to mask the images with the loaded foreground masks; + 0 value is used for background. + mask_depths: Whether to mask the depth maps with the loaded foreground + masks; 0 value is used for background. + image_height: The height of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + image_width: The width of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + box_crop: Enable cropping of the image around the bounding box inferred + from the foreground region of the loaded segmentation mask; masks + and depth maps are cropped accordingly; cameras are corrected. + box_crop_mask_thr: The threshold used to separate pixels into foreground + and background based on the foreground_probability mask; if no value + is greater than this threshold, the loader lowers it and repeats. + box_crop_context: The amount of additional padding added to each + dimension of the cropping bounding box, relative to box size. + """ + + dataset_root: str = "" + load_images: bool = True + load_depths: bool = True + load_depth_masks: bool = True + load_masks: bool = True + load_point_clouds: bool = False + max_points: int = 0 + mask_images: bool = False + mask_depths: bool = False + image_height: Optional[int] = 800 + image_width: Optional[int] = 800 + box_crop: bool = True + box_crop_mask_thr: float = 0.4 + box_crop_context: float = 0.3 + path_manager: Any = None + + def load_( + self, + frame_data: FrameData, + entry: types.FrameAnnotation, + seq_annotation: types.SequenceAnnotation, + bbox_xywh: Optional[torch.Tensor] = None, + ) -> FrameData: + """Main method for loader. 
+ FrameData modification done inplace + if bbox_xywh not provided bbox will be calculated from mask + """ + ( + frame_data.fg_probability, + frame_data.mask_path, + frame_data.bbox_xywh, + ) = self._load_fg_probability(entry, bbox_xywh) + + if self.load_images and entry.image is not None: + # original image size + frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) + + ( + frame_data.image_rgb, + frame_data.image_path, + ) = self._load_images(entry, frame_data.fg_probability) + + if self.load_depths and entry.depth is not None: + ( + frame_data.depth_map, + frame_data.depth_path, + frame_data.depth_mask, + ) = self._load_mask_depth(entry, frame_data.fg_probability) + + if self.load_point_clouds and seq_annotation.point_cloud is not None: + pcl_path = self._fix_point_cloud_path(seq_annotation.point_cloud.path) + frame_data.sequence_point_cloud = _load_pointcloud( + self._local_path(pcl_path), max_points=self.max_points + ) + frame_data.sequence_point_cloud_path = pcl_path + + clamp_bbox_xyxy = None + if self.box_crop: + clamp_bbox_xyxy = frame_data.crop_by_bbox_(self.box_crop_context) + + scale = ( + min( + self.image_height / entry.image.size[0], + # pyre-ignore + self.image_width / entry.image.size[1], + ) + if self.image_height is not None and self.image_width is not None + else 1.0 + ) + + if self.image_height is not None and self.image_width is not None: + optional_scale = frame_data.resize_frame_( + self.image_height, self.image_width + ) + scale = optional_scale or scale + + # creating camera taking to account bbox and resize scale + if entry.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera( + entry, scale, clamp_bbox_xyxy + ) + return frame_data + + def _load_fg_probability( + self, + entry: types.FrameAnnotation, + bbox_xywh: Optional[torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: + fg_probability = None + full_path = None + + if (self.load_masks) and entry.mask is not None: + full_path = os.path.join(self.dataset_root, entry.mask.path) + fg_probability = _load_mask(self._local_path(full_path)) + # we can use provided bbox_xywh or calculate it based on mask + if bbox_xywh is None: + bbox_xywh = _get_bbox_from_mask(fg_probability, self.box_crop_mask_thr) + if fg_probability.shape[-2:] != entry.image.size: + raise ValueError( + f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" + ) + + return ( + _safe_as_tensor(fg_probability, torch.float), + full_path, + _safe_as_tensor(bbox_xywh, torch.long), + ) + + def _load_images( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str]: + assert self.dataset_root is not None and entry.image is not None + path = os.path.join(self.dataset_root, entry.image.path) + image_rgb = _load_image(self._local_path(path)) + + if image_rgb.shape[-2:] != entry.image.size: + raise ValueError( + f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" 
+ ) + + if self.mask_images: + assert fg_probability is not None + image_rgb *= fg_probability + + return image_rgb, path + + def _load_mask_depth( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor]: + entry_depth = entry.depth + assert entry_depth is not None + path = os.path.join(self.dataset_root, entry_depth.path) + depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) + + if self.mask_depths: + assert fg_probability is not None + depth_map *= fg_probability + + if self.load_depth_masks: + assert entry_depth.mask_path is not None + mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + depth_mask = _load_depth_mask(self._local_path(mask_path)) + else: + depth_mask = torch.ones_like(depth_map) + + return torch.tensor(depth_map), path, torch.tensor(depth_mask) + + def _get_pytorch3d_camera( + self, + entry: types.FrameAnnotation, + scale: float, + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> PerspectiveCameras: + entry_viewpoint = entry.viewpoint + assert entry_viewpoint is not None + # principal point and focal length + principal_point = torch.tensor( + entry_viewpoint.principal_point, dtype=torch.float + ) + focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) + + half_image_size_wh_orig = ( + torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 + ) + + # first, we convert from the dataset's NDC convention to pixels + format = entry_viewpoint.intrinsics_format + if format.lower() == "ndc_norm_image_bounds": + # this is e.g. currently used in CO3D for storing intrinsics + rescale = half_image_size_wh_orig + elif format.lower() == "ndc_isotropic": + rescale = half_image_size_wh_orig.min() + else: + raise ValueError(f"Unknown intrinsics format: {format}") + + # principal point and focal length in pixels + principal_point_px = half_image_size_wh_orig - principal_point * rescale + focal_length_px = focal_length * rescale + + # changing principal_point according to bbox_crop + if clamp_bbox_xyxy is not None: + principal_point_px -= clamp_bbox_xyxy[:2] + + # now, convert from pixels to PyTorch3D v0.5+ NDC convention + if self.image_height is None or self.image_width is None: + out_size = list(reversed(entry.image.size)) + else: + out_size = [self.image_width, self.image_height] + + half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 + half_min_image_size_output = half_image_size_output.min() + + # rescaled principal point and focal length in ndc + principal_point = ( + half_image_size_output - principal_point_px * scale + ) / half_min_image_size_output + focal_length = focal_length_px * scale / half_min_image_size_output + + return PerspectiveCameras( + focal_length=focal_length[None], + principal_point=principal_point[None], + R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], + T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], + ) + + def _fix_point_cloud_path(self, path: str) -> str: + """ + Fix up a point cloud path from the dataset. + Some files in Co3Dv2 have an accidental absolute path stored. 
+ """ + unwanted_prefix = ( + "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" + ) + if path.startswith(unwanted_prefix): + path = path[len(unwanted_prefix) :] + return os.path.join(self.dataset_root, path) + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + +def _load_image(path) -> np.ndarray: + with Image.open(path) as pil_im: + im = np.array(pil_im.convert("RGB")) + im = im.transpose((2, 0, 1)) + im = im.astype(np.float32) / 255.0 + return im + + +def _load_mask(path) -> np.ndarray: + with Image.open(path) as pil_im: + mask = np.array(pil_im) + mask = mask.astype(np.float32) / 255.0 + return mask[None] # fake feature channel + + +def _load_depth(path, scale_adjustment) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth file name "%s"' % path) + + d = _load_16big_png_depth(path) * scale_adjustment + d[~np.isfinite(d)] = 0.0 + return d[None] # fake feature channel + + +def _load_16big_png_depth(depth_png) -> np.ndarray: + with Image.open(depth_png) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth + + +def _load_1bit_png_mask(file: str) -> np.ndarray: + with Image.open(file) as pil_im: + mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) + return mask + + +def _load_depth_mask(path: str) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth mask file name "%s"' % path) + m = _load_1bit_png_mask(path) + return m[None] # fake feature channel + + +def _safe_as_tensor(data, dtype): + return torch.tensor(data, dtype=dtype) if data is not None else None + + +# NOTE this cache is per-worker; they are implemented as processes. +# each batch is loaded and collated by a single worker; +# since sequences tend to co-occur within batches, this is useful. +@functools.lru_cache(maxsize=256) +def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: + pcl = IO().load_pointcloud(pcl_path) + if max_points > 0: + pcl = pcl.subsample(max_points) + + return pcl diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 283ef3dcd..7c4268fb9 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from collections import defaultdict from dataclasses import dataclass, field, fields from typing import ( @@ -23,6 +24,14 @@ import numpy as np import torch +from pytorch3d.implicitron.dataset.utils import ( + _bbox_xyxy_to_xywh, + _clamp_box_to_image_bounds_and_round, + _crop_around_box, + _get_clamp_bbox, + _rescale_bbox, + _resize_image, +) from pytorch3d.renderer.camera_utils import join_cameras_as_batch from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds @@ -90,6 +99,7 @@ class FrameData(Mapping[str, Any]): frame_type: The type of the loaded frame specified in `subset_lists_file`, if provided. 
meta: A dict for storing additional frame information. + cropped: Bool to avoid cropping FrameData twice """ frame_number: Optional[torch.LongTensor] @@ -116,6 +126,7 @@ class FrameData(Mapping[str, Any]): sequence_point_cloud_idx: Optional[torch.Tensor] = None frame_type: Union[str, List[str], None] = None # known | unseen meta: dict = field(default_factory=lambda: {}) + cropped: bool = False def to(self, *args, **kwargs): new_params = {} @@ -144,6 +155,109 @@ def __getitem__(self, key): def __len__(self): return len(fields(self)) + def crop_by_bbox_(self, box_crop_context) -> Optional[torch.Tensor]: + if self.cropped: + warnings.warn( + f"You called cropping on same frame twice " + f"sequence_name: {self.sequence_name}, skipping cropping" + ) + return None + + if ( + self.bbox_xywh is None + or self.fg_probability is None + or self.mask_path is None + or self.image_path is None + ): + warnings.warn( + "You called cropping without loading frame data" + "please call blob_loader.load_ first, skipping cropping" + ) + return None + + bbox_xyxy = _get_clamp_bbox( + self.bbox_xywh, + # pyre-ignore + image_path=self.image_path, + box_crop_context=box_crop_context, + ) + clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( + bbox_xyxy, + # pyre-ignore + image_size_hw=tuple(self.image_size_hw), + ) + self.crop_bbox_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) + + self.fg_probability = _crop_around_box( + self.fg_probability, + clamp_bbox_xyxy, + # pyre-ignore + self.mask_path, + ) + self.image_rgb = _crop_around_box( + self.image_rgb, + clamp_bbox_xyxy, + # pyre-ignore + self.image_path, + ) + + if self.depth_map is not None: + self.depth_map = _crop_around_box( + self.depth_map, + clamp_bbox_xyxy, + # pyre-ignore + self.depth_path, + ) + if self.depth_mask is not None: + self.depth_mask = _crop_around_box( + self.depth_mask, + clamp_bbox_xyxy, + # pyre-ignore + self.mask_path, + ) + self.cropped = True + return clamp_bbox_xyxy + + def resize_frame_(self, image_height, image_width) -> Optional[float]: + if self.bbox_xywh is not None: + self.bbox_xywh = _rescale_bbox( + self.bbox_xywh, + np.array(self.image_size_hw), + # pyre-ignore + self.image_rgb.shape[-2:], + ) + + scale = None + if self.image_rgb is not None: + self.image_rgb, scale, self.mask_crop = _resize_image( + self.image_rgb, image_height=image_height, image_width=image_width + ) + + if self.fg_probability is not None: + self.fg_probability, _, _ = _resize_image( + self.fg_probability, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.depth_map is not None: + self.depth_map, _, _ = _resize_image( + self.depth_map, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.depth_mask is not None: + self.depth_mask, _, _ = _resize_image( + self.depth_mask, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + return scale + @classmethod def collate(cls, batch): """ diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 669f4e9b6..5f9b2685a 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -15,7 +15,6 @@ import warnings from collections import defaultdict from itertools import islice -from pathlib import Path from typing import ( Any, ClassVar, @@ -30,20 +29,18 @@ Union, ) -import numpy as np import torch -from PIL import Image + +from pytorch3d.implicitron.dataset import types +from 
pytorch3d.implicitron.dataset.blob_loader import BlobLoader +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData +from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar + from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from pytorch3d.io import IO from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import Pointclouds +from pytorch3d.renderer.cameras import CamerasBase from tqdm import tqdm -from . import types -from .dataset_base import DatasetBase, FrameData -from .utils import is_known_frame_scalar - logger = logging.getLogger(__name__) @@ -65,7 +62,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): A dataset with annotations in json files like the Common Objects in 3D (CO3D) dataset. - Args: + Metadata-related args:: frame_annotations_file: A zipped json file containing metadata of the frames in the dataset, serialized List[types.FrameAnnotation]. sequence_annotations_file: A zipped json file containing metadata of the @@ -83,6 +80,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): pick_sequence: A list of sequence names to restrict the dataset to. exclude_sequence: A list of the names of the sequences to exclude. limit_category_to: Restrict the dataset to the given list of categories. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask after thresholding (see box_crop_mask_thr). + n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence + frames in each sequences uniformly without replacement if it has + more frames than that; applied before other frame-level filters. + seed: The seed of the random generator sampling #n_frames_per_sequence + random frames per sequence. + sort_frames: Enable frame annotations sorting to group frames from the + same sequences together and order them by timestamps + eval_batches: A list of batches that form the evaluation set; + list of batch-sized lists of indices corresponding to __getitem__ + of this class, thus it can be used directly as a batch sampler. + eval_batch_index: + ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) + A list of batches of frames described as (sequence_name, frame_idx) + that can form the evaluation set, `eval_batches` will be set from this. + + Blob-loading parameters: dataset_root: The root folder of the dataset; all the paths in jsons are specified relative to this root (but not json paths themselves). load_images: Enable loading the frame RGB data. @@ -109,23 +124,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): is greater than this threshold, the loader lowers it and repeats. box_crop_context: The amount of additional padding added to each dimension of the cropping bounding box, relative to box size. - remove_empty_masks: Removes the frames with no active foreground pixels - in the segmentation mask after thresholding (see box_crop_mask_thr). - n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence - frames in each sequences uniformly without replacement if it has - more frames than that; applied before other frame-level filters. - seed: The seed of the random generator sampling #n_frames_per_sequence - random frames per sequence. 
- sort_frames: Enable frame annotations sorting to group frames from the - same sequences together and order them by timestamps - eval_batches: A list of batches that form the evaluation set; - list of batch-sized lists of indices corresponding to __getitem__ - of this class, thus it can be used directly as a batch sampler. - eval_batch_index: - ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) - A list of batches of frames described as (sequence_name, frame_idx) - that can form the evaluation set, `eval_batches` will be set from this. - """ frame_annotations_type: ClassVar[ @@ -162,12 +160,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): sort_frames: bool = False eval_batches: Any = None eval_batch_index: Any = None + # initialised in __post_init__ + # commented because of OmegaConf (for tests to pass) + # blob_loader: BlobLoader = field(init=False) # frame_annots: List[FrameAnnotsEntry] = field(init=False) # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + # _seq_to_idx: Dict[str, List[int]] = field(init=False) def __post_init__(self) -> None: - # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`. - self.subset_to_image_path = None self._load_frames() self._load_sequences() if self.sort_frames: @@ -175,6 +175,24 @@ def __post_init__(self) -> None: self._load_subset_lists() self._filter_db() # also computes sequence indices self._extract_and_set_eval_batches() + + # pyre-ignore + self.blob_loader = BlobLoader( + dataset_root=self.dataset_root, + load_images=self.load_images, + load_depths=self.load_depths, + load_depth_masks=self.load_depth_masks, + load_masks=self.load_masks, + load_point_clouds=self.load_point_clouds, + max_points=self.max_points, + mask_images=self.mask_images, + mask_depths=self.mask_depths, + image_height=self.image_height, + image_width=self.image_width, + box_crop=self.box_crop, + box_crop_mask_thr=self.box_crop_mask_thr, + box_crop_context=self.box_crop_context, + ) logger.info(str(self)) def _extract_and_set_eval_batches(self): @@ -190,7 +208,8 @@ def _extract_and_set_eval_batches(self): self.eval_batch_index ) - def join(self, other_datasets: Iterable[DatasetBase]) -> None: + # pyre-ignore + def join(self, other_datasets: Iterable["JsonIndexDataset"]) -> None: """ Join the dataset with other JsonIndexDataset objects. @@ -200,19 +219,18 @@ def join(self, other_datasets: Iterable[DatasetBase]) -> None: """ if not all(isinstance(d, JsonIndexDataset) for d in other_datasets): raise ValueError("This function can only join a list of JsonIndexDataset") - # pyre-ignore[16] + # pyre-ignore self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots]) - # pyre-ignore[16] + # pyre-ignore self.seq_annots.update( # https://gist.github.com/treyhunner/f35292e676efa0be1728 functools.reduce( lambda a, b: {**a, **b}, - [d.seq_annots for d in other_datasets], # pyre-ignore[16] + [d.seq_annots for d in other_datasets], ) ) all_eval_batches = [ self.eval_batches, - # pyre-ignore *[d.eval_batches for d in other_datasets], ] if not ( @@ -251,7 +269,7 @@ def seq_frame_index_to_dataset_index( allow_missing_indices: bool = False, remove_missing_indices: bool = False, suppress_missing_index_warning: bool = True, - ) -> List[List[Union[Optional[int], int]]]: + ) -> Union[List[List[Optional[int]]], List[List[int]]]: """ Obtain indices into the dataset object given a list of frame ids. 
@@ -279,11 +297,11 @@ def seq_frame_index_to_dataset_index( """ _dataset_seq_frame_n_index = { seq: { - # pyre-ignore[16] + # pyre-ignore self.frame_annots[idx]["frame_annotation"].frame_number: idx for idx in seq_idx } - # pyre-ignore[16] + # pyre-ignore for seq, seq_idx in self._seq_to_idx.items() } @@ -306,7 +324,7 @@ def _get_dataset_idx( # Check that the loaded frame path is consistent # with the one stored in self.frame_annots. assert os.path.normpath( - # pyre-ignore[16] + # pyre-ignore self.frame_annots[idx]["frame_annotation"].image.path ) == os.path.normpath( path @@ -323,9 +341,7 @@ def _get_dataset_idx( valid_dataset_idx = [ [b for b in batch if b is not None] for batch in dataset_idx ] - return [ # pyre-ignore[7] - batch for batch in valid_dataset_idx if len(batch) > 0 - ] + return [batch for batch in valid_dataset_idx if len(batch) > 0] return dataset_idx @@ -358,7 +374,8 @@ def subset_from_frame_index( # Deep copy the whole dataset except frame_annots, which are large so we # deep copy only the requested subset of frame_annots. - memo = {id(self.frame_annots): None} # pyre-ignore[16] + # pyre-ignore + memo = {id(self.frame_annots): None} dataset_new = copy.deepcopy(self, memo) dataset_new.frame_annots = copy.deepcopy( [self.frame_annots[i] for i in valid_dataset_indices] @@ -386,11 +403,11 @@ def subset_from_frame_index( return dataset_new def __str__(self) -> str: - # pyre-ignore[16] + # pyre-ignore return f"JsonIndexDataset #frames={len(self.frame_annots)}" def __len__(self) -> int: - # pyre-ignore[16] + # pyre-ignore return len(self.frame_annots) def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: @@ -402,7 +419,7 @@ def get_all_train_cameras(self) -> CamerasBase: """ logger.info("Loading all train cameras.") cameras = [] - # pyre-ignore[16] + # pyre-ignore for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)): frame_type = self._get_frame_type(frame_annot) if frame_type is None: @@ -412,12 +429,12 @@ def get_all_train_cameras(self) -> CamerasBase: return join_cameras_as_batch(cameras) def __getitem__(self, index) -> FrameData: - # pyre-ignore[16] + # pyre-ignore if index >= len(self.frame_annots): raise IndexError(f"index {index} out of range {len(self.frame_annots)}") entry = self.frame_annots[index]["frame_annotation"] - # pyre-ignore[16] + # pyre-ignore point_cloud = self.seq_annots[entry.sequence_name].point_cloud frame_data = FrameData( frame_number=_safe_as_tensor(entry.frame_number, torch.long), @@ -435,237 +452,12 @@ def __getitem__(self, index) -> FrameData: else None, ) - # The rest of the fields are optional + # Optional field frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) - - ( - frame_data.fg_probability, - frame_data.mask_path, - frame_data.bbox_xywh, - clamp_bbox_xyxy, - frame_data.crop_bbox_xywh, - ) = self._load_crop_fg_probability(entry) - - scale = 1.0 - if self.load_images and entry.image is not None: - # original image size - frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) - - ( - frame_data.image_rgb, - frame_data.image_path, - frame_data.mask_crop, - scale, - ) = self._load_crop_images( - entry, frame_data.fg_probability, clamp_bbox_xyxy - ) - - if self.load_depths and entry.depth is not None: - ( - frame_data.depth_map, - frame_data.depth_path, - frame_data.depth_mask, - ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) - - if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera( - entry, - scale, - clamp_bbox_xyxy, - ) 
- - if self.load_point_clouds and point_cloud is not None: - pcl_path = self._fix_point_cloud_path(point_cloud.path) - frame_data.sequence_point_cloud = _load_pointcloud( - self._local_path(pcl_path), max_points=self.max_points - ) - frame_data.sequence_point_cloud_path = pcl_path - + # pyre-ignore + self.blob_loader.load_(frame_data, entry, self.seq_annots[entry.sequence_name]) return frame_data - def _fix_point_cloud_path(self, path: str) -> str: - """ - Fix up a point cloud path from the dataset. - Some files in Co3Dv2 have an accidental absolute path stored. - """ - unwanted_prefix = ( - "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" - ) - if path.startswith(unwanted_prefix): - path = path[len(unwanted_prefix) :] - return os.path.join(self.dataset_root, path) - - def _load_crop_fg_probability( - self, entry: types.FrameAnnotation - ) -> Tuple[ - Optional[torch.Tensor], - Optional[str], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: - fg_probability = None - full_path = None - bbox_xywh = None - clamp_bbox_xyxy = None - crop_box_xywh = None - - if (self.load_masks or self.box_crop) and entry.mask is not None: - full_path = os.path.join(self.dataset_root, entry.mask.path) - mask = _load_mask(self._local_path(full_path)) - - if mask.shape[-2:] != entry.image.size: - raise ValueError( - f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" - ) - - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if self.box_crop: - clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - bbox_xywh, - image_path=entry.image.path, - box_crop_context=self.box_crop_context, - ), - image_size_hw=tuple(mask.shape[-2:]), - ) - crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) - - mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) - - fg_probability, _, _ = self._resize_image(mask, mode="nearest") - - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh - - def _load_crop_images( - self, - entry: types.FrameAnnotation, - fg_probability: Optional[torch.Tensor], - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: - assert self.dataset_root is not None and entry.image is not None - path = os.path.join(self.dataset_root, entry.image.path) - image_rgb = _load_image(self._local_path(path)) - - if image_rgb.shape[-2:] != entry.image.size: - raise ValueError( - f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" 
- ) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) - - image_rgb, scale, mask_crop = self._resize_image(image_rgb) - - if self.mask_images: - assert fg_probability is not None - image_rgb *= fg_probability - - return image_rgb, path, mask_crop, scale - - def _load_mask_depth( - self, - entry: types.FrameAnnotation, - clamp_bbox_xyxy: Optional[torch.Tensor], - fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor]: - entry_depth = entry.depth - assert entry_depth is not None - path = os.path.join(self.dataset_root, entry_depth.path) - depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] - ) - depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) - - depth_map, _, _ = self._resize_image(depth_map, mode="nearest") - - if self.mask_depths: - assert fg_probability is not None - depth_map *= fg_probability - - if self.load_depth_masks: - assert entry_depth.mask_path is not None - mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) - depth_mask = _load_depth_mask(self._local_path(mask_path)) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_mask_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] - ) - depth_mask = _crop_around_box( - depth_mask, depth_mask_bbox_xyxy, mask_path - ) - - depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") - else: - depth_mask = torch.ones_like(depth_map) - - return depth_map, path, depth_mask - - def _get_pytorch3d_camera( - self, - entry: types.FrameAnnotation, - scale: float, - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> PerspectiveCameras: - entry_viewpoint = entry.viewpoint - assert entry_viewpoint is not None - # principal point and focal length - principal_point = torch.tensor( - entry_viewpoint.principal_point, dtype=torch.float - ) - focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) - - half_image_size_wh_orig = ( - torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 - ) - - # first, we convert from the dataset's NDC convention to pixels - format = entry_viewpoint.intrinsics_format - if format.lower() == "ndc_norm_image_bounds": - # this is e.g. 
currently used in CO3D for storing intrinsics - rescale = half_image_size_wh_orig - elif format.lower() == "ndc_isotropic": - rescale = half_image_size_wh_orig.min() - else: - raise ValueError(f"Unknown intrinsics format: {format}") - - # principal point and focal length in pixels - principal_point_px = half_image_size_wh_orig - principal_point * rescale - focal_length_px = focal_length * rescale - if self.box_crop: - assert clamp_bbox_xyxy is not None - principal_point_px -= clamp_bbox_xyxy[:2] - - # now, convert from pixels to PyTorch3D v0.5+ NDC convention - if self.image_height is None or self.image_width is None: - out_size = list(reversed(entry.image.size)) - else: - out_size = [self.image_width, self.image_height] - - half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 - half_min_image_size_output = half_image_size_output.min() - - # rescaled principal point and focal length in ndc - principal_point = ( - half_image_size_output - principal_point_px * scale - ) / half_min_image_size_output - focal_length = focal_length_px * scale / half_min_image_size_output - - return PerspectiveCameras( - focal_length=focal_length[None], - principal_point=principal_point[None], - R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], - T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], - ) - def _load_frames(self) -> None: logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.") local_file = self._local_path(self.frame_annotations_file) @@ -675,7 +467,7 @@ def _load_frames(self) -> None: ) if not frame_annots_list: raise ValueError("Empty dataset!") - # pyre-ignore[16] + # pyre-ignore self.frame_annots = [ FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list ] @@ -687,7 +479,7 @@ def _load_sequences(self) -> None: seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) if not seq_annots: raise ValueError("Empty sequences file!") - # pyre-ignore[16] + # pyre-ignore self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} def _load_subset_lists(self) -> None: @@ -703,7 +495,7 @@ def _load_subset_lists(self) -> None: for subset, frames in subset_to_seq_frame.items() for _, _, path in frames } - # pyre-ignore[16] + # pyre-ignore for frame in self.frame_annots: frame["subset"] = frame_path_to_subset.get( frame["frame_annotation"].image.path, None @@ -716,7 +508,7 @@ def _load_subset_lists(self) -> None: def _sort_frames(self) -> None: # Sort frames to have them grouped by sequence, ordered by timestamp - # pyre-ignore[16] + # pyre-ignore self.frame_annots = sorted( self.frame_annots, key=lambda f: ( @@ -728,7 +520,7 @@ def _sort_frames(self) -> None: def _filter_db(self) -> None: if self.remove_empty_masks: logger.info("Removing images with empty masks.") - # pyre-ignore[16] + # pyre-ignore old_len = len(self.frame_annots) msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." 
@@ -769,7 +561,7 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool: if len(self.limit_category_to) > 0: logger.info(f"Limiting dataset to categories: {self.limit_category_to}") - # pyre-ignore[16] + # pyre-ignore self.seq_annots = { name: entry for name, entry in self.seq_annots.items() @@ -807,7 +599,7 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool: if self.n_frames_per_sequence > 0: logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") keep_idx = [] - # pyre-ignore[16] + # pyre-ignore for seq, seq_indices in self._seq_to_idx.items(): # infer the seed from the sequence name, this is reproducible # and makes the selection differ for different sequences @@ -837,51 +629,21 @@ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: self._invalidate_seq_to_idx() if filter_seq_annots: - # pyre-ignore[16] + # pyre-ignore self.seq_annots = { k: v for k, v in self.seq_annots.items() - # pyre-ignore[16] - if k in self._seq_to_idx + if k in self._seq_to_idx # pyre-ignore } def _invalidate_seq_to_idx(self) -> None: seq_to_idx = defaultdict(list) - # pyre-ignore[16] + # pyre-ignore for idx, entry in enumerate(self.frame_annots): seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) - # pyre-ignore[16] + # pyre-ignore self._seq_to_idx = seq_to_idx - def _resize_image( - self, image, mode="bilinear" - ) -> Tuple[torch.Tensor, float, torch.Tensor]: - image_height, image_width = self.image_height, self.image_width - if image_height is None or image_width is None: - # skip the resizing - imre_ = torch.from_numpy(image) - return imre_, 1.0, torch.ones_like(imre_[:1]) - # takes numpy array, returns pytorch tensor - minscale = min( - image_height / image.shape[-2], - image_width / image.shape[-1], - ) - imre = torch.nn.functional.interpolate( - torch.from_numpy(image)[None], - scale_factor=minscale, - mode=mode, - align_corners=False if mode == "bilinear" else None, - recompute_scale_factor=True, - )[0] - # pyre-fixme[19]: Expected 1 positional argument. - imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) - imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre - # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. - # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. - mask = torch.zeros(1, self.image_height, self.image_width) - mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 - return imre_, minscale, mask - def _local_path(self, path: str) -> str: if self.path_manager is None: return path @@ -894,7 +656,7 @@ def get_frame_numbers_and_timestamps( for idx in idxs: if ( subset_filter is not None - # pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`. + # pyre-ignore and self.frame_annots[idx]["subset"] not in subset_filter ): continue @@ -920,167 +682,5 @@ def _seq_name_to_seed(seq_name) -> int: return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16) -def _load_image(path) -> np.ndarray: - with Image.open(path) as pil_im: - im = np.array(pil_im.convert("RGB")) - im = im.transpose((2, 0, 1)) - im = im.astype(np.float32) / 255.0 - return im - - -def _load_16big_png_depth(depth_png) -> np.ndarray: - with Image.open(depth_png) as depth_pil: - # the image is stored with 16-bit depth but PIL reads it as I (32 bit). 
- # we cast it to uint16, then reinterpret as float16, then cast to float32 - depth = ( - np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) - .astype(np.float32) - .reshape((depth_pil.size[1], depth_pil.size[0])) - ) - return depth - - -def _load_1bit_png_mask(file: str) -> np.ndarray: - with Image.open(file) as pil_im: - mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) - return mask - - -def _load_depth_mask(path: str) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth mask file name "%s"' % path) - m = _load_1bit_png_mask(path) - return m[None] # fake feature channel - - -def _load_depth(path, scale_adjustment) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth file name "%s"' % path) - - d = _load_16big_png_depth(path) * scale_adjustment - d[~np.isfinite(d)] = 0.0 - return d[None] # fake feature channel - - -def _load_mask(path) -> np.ndarray: - with Image.open(path) as pil_im: - mask = np.array(pil_im) - mask = mask.astype(np.float32) / 255.0 - return mask[None] # fake feature channel - - -def _get_1d_bounds(arr) -> Tuple[int, int]: - nz = np.flatnonzero(arr) - return nz[0], nz[-1] + 1 - - -def _get_bbox_from_mask( - mask, thr, decrease_quant: float = 0.05 -) -> Tuple[int, int, int, int]: - # bbox in xywh - masks_for_box = np.zeros_like(mask) - while masks_for_box.sum() <= 1.0: - masks_for_box = (mask > thr).astype(np.float32) - thr -= decrease_quant - if thr <= 0.0: - warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.") - - x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) - y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) - - return x0, y0, x1 - x0, y1 - y0 - - -def _get_clamp_bbox( - bbox: torch.Tensor, - box_crop_context: float = 0.0, - image_path: str = "", -) -> torch.Tensor: - # box_crop_context: rate of expansion for bbox - # returns possibly expanded bbox xyxy as float - - bbox = bbox.clone() # do not edit bbox in place - - # increase box size - if box_crop_context > 0.0: - c = box_crop_context - bbox = bbox.float() - bbox[0] -= bbox[2] * c / 2 - bbox[1] -= bbox[3] * c / 2 - bbox[2] += bbox[2] * c - bbox[3] += bbox[3] * c - - if (bbox[2:] <= 1.0).any(): - raise ValueError( - f"squashed image {image_path}!! The bounding box contains no pixels." 
- ) - - bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes - bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) - - return bbox_xyxy - - -def _crop_around_box(tensor, bbox, impath: str = ""): - # bbox is xyxy, where the upper bound is corrected with +1 - bbox = _clamp_box_to_image_bounds_and_round( - bbox, - image_size_hw=tensor.shape[-2:], - ) - tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] - assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" - return tensor - - -def _clamp_box_to_image_bounds_and_round( - bbox_xyxy: torch.Tensor, - image_size_hw: Tuple[int, int], -) -> torch.LongTensor: - bbox_xyxy = bbox_xyxy.clone() - bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) - bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) - if not isinstance(bbox_xyxy, torch.LongTensor): - bbox_xyxy = bbox_xyxy.round().long() - return bbox_xyxy # pyre-ignore [7] - - -def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: - assert bbox is not None - assert np.prod(orig_res) > 1e-8 - # average ratio of dimensions - rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 - return bbox * rel_size - - -def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: - wh = xyxy[2:] - xyxy[:2] - xywh = torch.cat([xyxy[:2], wh]) - return xywh - - -def _bbox_xywh_to_xyxy( - xywh: torch.Tensor, clamp_size: Optional[int] = None -) -> torch.Tensor: - xyxy = xywh.clone() - if clamp_size is not None: - xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) - xyxy[2:] += xyxy[:2] - return xyxy - - def _safe_as_tensor(data, dtype): - if data is None: - return None - return torch.tensor(data, dtype=dtype) - - -# NOTE this cache is per-worker; they are implemented as processes. -# each batch is loaded and collated by a single worker; -# since sequences tend to co-occur within batches, this is useful. -@functools.lru_cache(maxsize=256) -def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: - pcl = IO().load_pointcloud(pcl_path) - if max_points > 0: - pcl = pcl.subsample(max_points) - - return pcl + return torch.tensor(data, dtype=dtype) if data is not None else None diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 05252aff1..aca0507dd 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. 
-from typing import List, Optional +import warnings +from typing import List, Optional, Tuple + +import numpy as np import torch @@ -52,3 +55,133 @@ def is_train_frame( dtype=torch.bool, device=device, ) + + +def _get_bbox_from_mask( + mask, thr, decrease_quant: float = 0.05 +) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn( + f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 + ) + + x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 + + +def _crop_around_box(tensor, bbox, impath: str = ""): + # bbox is xyxy, where the upper bound is corrected with +1 + bbox = _clamp_box_to_image_bounds_and_round( + bbox, + image_size_hw=tensor.shape[-2:], + ) + tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] + assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" + return tensor + + +def _clamp_box_to_image_bounds_and_round( + bbox_xyxy: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.LongTensor: + bbox_xyxy = bbox_xyxy.clone() + bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) + bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) + if not isinstance(bbox_xyxy, torch.LongTensor): + bbox_xyxy = bbox_xyxy.round().long() + return bbox_xyxy # pyre-ignore [7] + + +def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + wh = xyxy[2:] - xyxy[:2] + xywh = torch.cat([xyxy[:2], wh]) + return xywh + + +def _get_clamp_bbox( + bbox: torch.Tensor, + box_crop_context: float = 0.0, + image_path: str = "", +) -> torch.Tensor: + # box_crop_context: rate of expansion for bbox + # returns possibly expanded bbox xyxy as float + + bbox = bbox.clone() # do not edit bbox in place + + # increase box size + if box_crop_context > 0.0: + c = box_crop_context + bbox = bbox.float() + bbox[0] -= bbox[2] * c / 2 + bbox[1] -= bbox[3] * c / 2 + bbox[2] += bbox[2] * c + bbox[3] += bbox[3] * c + + if (bbox[2:] <= 1.0).any(): + raise ValueError( + f"squashed image {image_path}!! The bounding box contains no pixels." 
+ ) + + bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes + bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) + + return bbox_xyxy + + +def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: + assert bbox is not None + assert np.prod(orig_res) > 1e-8 + # average ratio of dimensions + rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 + return bbox * rel_size + + +def _bbox_xywh_to_xyxy( + xywh: torch.Tensor, clamp_size: Optional[int] = None +) -> torch.Tensor: + xyxy = xywh.clone() + if clamp_size is not None: + xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) + xyxy[2:] += xyxy[:2] + return xyxy + + +def _get_1d_bounds(arr) -> Tuple[int, int]: + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + 1 + + +def _resize_image( + image, image_height, image_width, mode="bilinear" +) -> Tuple[torch.Tensor, float, torch.Tensor]: + + if type(image) == np.ndarray: + image = torch.from_numpy(image) + + if image_height is None or image_width is None: + # skip the resizing + return image, 1.0, torch.ones_like(image[:1]) + # takes numpy array or tensor, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + image[None], + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + imre_ = torch.zeros(image.shape[0], image_height, image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + mask = torch.zeros(1, image_height, image_width) + mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 + return imre_, minscale, mask diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py index 999dfc924..3c45ee793 100644 --- a/tests/implicitron/test_bbox.py +++ b/tests/implicitron/test_bbox.py @@ -9,11 +9,19 @@ import numpy as np import torch -from pytorch3d.implicitron.dataset.json_index_dataset import ( + +from pytorch3d.implicitron.dataset.utils import ( _bbox_xywh_to_xyxy, _bbox_xyxy_to_xywh, + _clamp_box_to_image_bounds_and_round, + _crop_around_box, + _get_1d_bounds, _get_bbox_from_mask, + _get_clamp_bbox, + _rescale_bbox, + _resize_image, ) + from tests.common_testing import TestCaseMixin @@ -76,3 +84,59 @@ def test_mask_to_bbox(self): expected_bbox_xywh = [2, 1, 2, 1] bbox_xywh = _get_bbox_from_mask(mask, 0.5) self.assertClose(bbox_xywh, expected_bbox_xywh) + + def test_crop_around_box(self): + bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max) + image = torch.LongTensor( + [ + [0, 0, 10, 20], + [10, 20, 5, 1], + [10, 20, 1, 1], + [5, 4, 0, 1], + ] + ) + cropped = _crop_around_box(image, bbox) + self.assertClose(cropped, image[1:3, 0:2]) + + def test_clamp_box_to_image_bounds_and_round(self): + bbox = torch.LongTensor([0, 1, 10, 12]) + image_size = (5, 6) + expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]]) + clamped_bbox = _clamp_box_to_image_bounds_and_round(bbox, image_size) + self.assertClose(clamped_bbox, expected_clamped_bbox) + + def test_get_clamp_bbox(self): + bbox_xywh = torch.LongTensor([1, 1, 4, 5]) + clamped_bbox_xyxy = _get_clamp_bbox(bbox_xywh, box_crop_context=2) + # size multiplied by 2 and added coordinates + self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11])) + + def test_rescale_bbox(self): + bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0]) + original_resolution = (4, 4) + new_resolution = (8, 8) # twice bigger + rescaled_bbox = _rescale_bbox(bbox, original_resolution, 
new_resolution) + self.assertClose(bbox * 2, rescaled_bbox) + + def test_get_1d_bounds(self): + array = [0, 1, 2] + bounds = _get_1d_bounds(array) + # make nonzero 1d bounds of image + self.assertClose(bounds, [1, 3]) + + def test_resize_image(self): + image = np.random.rand(3, 300, 500) # rgb image 300x500 + expected_shape = (150, 250) + + resized_image, scale, mask_crop = _resize_image( + image, image_height=expected_shape[0], image_width=expected_shape[1] + ) + + original_shape = image.shape[-2:] + expected_scale = min( + expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1] + ) + + self.assertEqual(scale, expected_scale) + self.assertEqual(resized_image.shape[-2:], expected_shape) + self.assertEqual(mask_crop.shape[-2:], expected_shape) diff --git a/tests/implicitron/test_blob_loader.py b/tests/implicitron/test_blob_loader.py new file mode 100644 index 000000000..ef18d6258 --- /dev/null +++ b/tests/implicitron/test_blob_loader.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import gzip +import os +import unittest +from typing import List + +import numpy as np +import torch + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.blob_loader import ( + _load_16big_png_depth, + _load_1bit_png_mask, + _load_depth, + _load_depth_mask, + _load_image, + _load_mask, + _safe_as_tensor, + BlobLoader, +) +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.tools.config import get_default_args +from pytorch3d.renderer.cameras import PerspectiveCameras + +from tests.common_testing import TestCaseMixin +from tests.implicitron.common_resources import get_skateboard_data + + +class TestBlobLoader(TestCaseMixin, unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + category = "skateboard" + stack = contextlib.ExitStack() + self.dataset_root, self.path_manager = stack.enter_context( + get_skateboard_data() + ) + self.addCleanup(stack.close) + self.image_height = 768 + self.image_width = 512 + + self.blob_loader = BlobLoader( + image_height=self.image_height, + image_width=self.image_width, + dataset_root=self.dataset_root, + path_manager=self.path_manager, + ) + + # loading single frame annotation of dataset (see JsonIndexDataset._load_frames()) + frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz") + local_file = self.path_manager.get_local_path(frame_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + frame_annots_list = types.load_dataclass( + zipfile, List[types.FrameAnnotation] + ) + self.frame_annotation = frame_annots_list[0] + + sequence_annotations_file = os.path.join( + self.dataset_root, category, "sequence_annotations.jgz" + ) + local_file = self.path_manager.get_local_path(sequence_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + seq_annots_list = types.load_dataclass( + zipfile, List[types.SequenceAnnotation] + ) + seq_annots = {entry.sequence_name: entry for entry in seq_annots_list} + self.seq_annotation = seq_annots[self.frame_annotation.sequence_name] + + point_cloud = self.seq_annotation.point_cloud + self.frame_data = FrameData( + frame_number=_safe_as_tensor( + self.frame_annotation.frame_number, torch.long + ), + frame_timestamp=_safe_as_tensor( + self.frame_annotation.frame_timestamp, 
torch.float + ), + sequence_name=self.frame_annotation.sequence_name, + sequence_category=self.seq_annotation.category, + camera_quality_score=_safe_as_tensor( + self.seq_annotation.viewpoint_quality_score, torch.float + ), + point_cloud_quality_score=_safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + + def test_BlobLoader_args(self): + # test that BlobLoader works with get_default_args + get_default_args(BlobLoader) + + def test_fix_point_cloud_path(self): + """Some files in Co3Dv2 have an accidental absolute path stored.""" + original_path = "some_file_path" + modified_path = self.blob_loader._fix_point_cloud_path(original_path) + assert original_path in modified_path + assert self.blob_loader.dataset_root in modified_path + + def test_load_(self): + bbox_xywh = None + self.frame_data.image_size_hw = _safe_as_tensor( + self.frame_annotation.image.size, torch.long + ) + ( + self.frame_data.fg_probability, + self.frame_data.mask_path, + self.frame_data.bbox_xywh, + ) = self.blob_loader._load_fg_probability(self.frame_annotation, bbox_xywh) + + assert self.frame_data.mask_path + assert torch.is_tensor(self.frame_data.fg_probability) + assert torch.is_tensor(self.frame_data.bbox_xywh) + # assert bboxes shape + self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4])) + ( + self.frame_data.image_rgb, + self.frame_data.image_path, + ) = self.blob_loader._load_images( + self.frame_annotation, self.frame_data.fg_probability + ) + self.assertEqual(type(self.frame_data.image_rgb), np.ndarray) + assert self.frame_data.image_path + + ( + self.frame_data.depth_map, + depth_path, + self.frame_data.depth_mask, + ) = self.blob_loader._load_mask_depth( + self.frame_annotation, + self.frame_data.fg_probability, + ) + assert torch.is_tensor(self.frame_data.depth_map) + assert depth_path + assert torch.is_tensor(self.frame_data.depth_mask) + + clamp_bbox_xyxy = None + if self.blob_loader.box_crop: + clamp_bbox_xyxy = self.frame_data.crop_by_bbox_( + self.blob_loader.box_crop_context + ) + + # assert image and mask shapes after resize + scale = self.frame_data.resize_frame_(self.image_height, self.image_width) + assert scale + self.assertEqual( + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.image_rgb.shape, + torch.Size([3, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.fg_probability.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.depth_map.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.depth_mask.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + + self.frame_data.camera = self.blob_loader._get_pytorch3d_camera( + self.frame_annotation, + scale, + clamp_bbox_xyxy, + ) + self.assertEqual(type(self.frame_data.camera), PerspectiveCameras) + + def test_load_image(self): + path = os.path.join(self.dataset_root, self.frame_annotation.image.path) + local_path = self.path_manager.get_local_path(path) + image = _load_image(local_path) + self.assertEqual(image.dtype, np.float32) + assert np.max(image) <= 1.0 + assert np.min(image) >= 0.0 + + def test_load_mask(self): + path = os.path.join(self.dataset_root, self.frame_annotation.mask.path) + mask = _load_mask(path) + 
self.assertEqual(mask.dtype, np.float32) + assert np.max(mask) <= 1.0 + assert np.min(mask) >= 0.0 + + def test_load_depth(self): + path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) + depth_map = _load_depth(path, self.frame_annotation.depth.scale_adjustment) + self.assertEqual(depth_map.dtype, np.float32) + self.assertEqual(len(depth_map.shape), 3) + + def test_load_16big_png_depth(self): + path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) + depth_map = _load_16big_png_depth(path) + self.assertEqual(depth_map.dtype, np.float32) + self.assertEqual(len(depth_map.shape), 2) + + def test_load_1bit_png_mask(self): + mask_path = os.path.join( + self.dataset_root, self.frame_annotation.depth.mask_path + ) + mask = _load_1bit_png_mask(mask_path) + self.assertEqual(mask.dtype, np.float32) + self.assertEqual(len(mask.shape), 2) + + def test_load_depth_mask(self): + mask_path = os.path.join( + self.dataset_root, self.frame_annotation.depth.mask_path + ) + mask = _load_depth_mask(mask_path) + self.assertEqual(mask.dtype, np.float32) + self.assertEqual(len(mask.shape), 3)
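
For reference, a minimal usage sketch of the blob-loading API introduced in this diff. It mirrors what JsonIndexDataset.__getitem__ and tests/implicitron/test_blob_loader.py do; the dataset root and category below are placeholders, and only the BlobLoader / FrameData signatures are taken from the code above.

import gzip
import os
from typing import List

import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.blob_loader import BlobLoader, _safe_as_tensor
from pytorch3d.implicitron.dataset.dataset_base import FrameData

# Placeholders: assumes a CO3D-style dataset laid out under `dataset_root`.
dataset_root = "/path/to/co3d"
category = "skateboard"

# Load one frame annotation and its sequence annotation, as the new test does.
with gzip.open(
    os.path.join(dataset_root, category, "frame_annotations.jgz"), "rt", encoding="utf8"
) as f:
    entry = types.load_dataclass(f, List[types.FrameAnnotation])[0]
with gzip.open(
    os.path.join(dataset_root, category, "sequence_annotations.jgz"), "rt", encoding="utf8"
) as f:
    seq_annots = types.load_dataclass(f, List[types.SequenceAnnotation])
seq_annotation = {s.sequence_name: s for s in seq_annots}[entry.sequence_name]

blob_loader = BlobLoader(
    dataset_root=dataset_root, image_height=768, image_width=512, box_crop=True
)

frame_data = FrameData(
    frame_number=_safe_as_tensor(entry.frame_number, torch.long),
    frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
    sequence_name=entry.sequence_name,
    sequence_category=seq_annotation.category,
)

# load_() fills fg_probability, image_rgb, depth_map, camera, ... in place,
# applying the box crop and resize configured on the loader. The crop and
# resize steps can also be driven manually via frame_data.crop_by_bbox_(context)
# and frame_data.resize_frame_(image_height, image_width).
blob_loader.load_(frame_data, entry, seq_annotation)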