# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from functools import lru_cache, partial
from typing import Any, Optional, TypedDict, Union

import numpy as np

from .image_processing_utils import (
    BaseImageProcessor,
    BatchFeature,
    get_size_dict,
)
from .image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    get_size_with_aspect_ratio,
    group_images_by_shape,
    reorder_images,
)
from .image_utils import (
    ChannelDimension,
    ImageInput,
    ImageType,
    SizeDict,
    get_image_size,
    get_image_size_for_max_height_width,
    get_image_type,
    infer_channel_dimension_format,
    make_flat_list_of_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from .processing_utils import Unpack
from .utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    is_vision_available,
    logging,
)


if is_vision_available():
    from .image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_torchvision_available():
    from .image_utils import pil_torch_interpolation_mapping

    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F

logger = logging.get_logger(__name__)


@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[SizeDict] = None,
    do_resize: Optional[bool] = None,
    size: Optional[SizeDict] = None,
    resample: Optional["PILImageResampling"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
    """
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.

    The function is wrapped in `lru_cache`, so every argument must be hashable (e.g. pass `image_mean` /
    `image_std` as tuples rather than lists).
    """
    # Checks shared with the slow image processors.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_pad=do_pad,
        size_divisibility=size_divisibility,
        do_center_crop=do_center_crop,
        crop_size=crop_size,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # Extra checks for ImageProcessorFast
    if return_tensors is not None and return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")


def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeezes a tensor, but only if the axis specified has dim 1.

    Args:
        tensor (`torch.Tensor`):
            The tensor to squeeze.
        axis (`int`, *optional*):
            The axis to squeeze. If unset, `tensor.squeeze()` is called without an axis.

    Returns:
        `torch.Tensor`: The squeezed tensor, or the input tensor if squeezing `axis` raised a `ValueError`.
    """
    if axis is None:
        return tensor.squeeze()

    try:
        return tensor.squeeze(axis=axis)
    except ValueError:
        return tensor


def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Return the maximum value across all indices of an iterable of values.

    The items of `values` are zipped together, so the result has one entry per position, truncated to the
    shortest item.
    """
    return [max(values_i) for values_i in zip(*values)]


def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    Each image shape is unpacked into exactly three values, the last two being read as height and width.
    """
    _, max_height, max_width = max_across_indices([img.shape for img in images])

    return (max_height, max_width)


def divide_to_patches(
    image: Union[np.ndarray, "torch.Tensor"], patch_size: int
) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Divides an image into patches of a specified size.

    Args:
        image (`Union[np.ndarray, "torch.Tensor"]`):
            The input image. It is indexed as `image[:, rows, cols]`, i.e. with the channel dimension first.
        patch_size (`int`):
            The size of each patch.

    Returns:
        list: A list of Union[np.ndarray, "torch.Tensor"] representing the patches, in row-major order. Patches on
        the bottom/right border are smaller when the image size is not a multiple of `patch_size`, because the
        slices are simply truncated at the image border.
    """
    patches = []
    height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            patch = image[:, i : i + patch_size, j : j + patch_size]
            patches.append(patch)

    return patches


# Keyword arguments accepted by `BaseImageProcessorFast.__init__` and `BaseImageProcessorFast.preprocess`.
# `total=False` makes every key optional.
class DefaultFastImageProcessorKwargs(TypedDict, total=False):
    do_resize: Optional[bool]
    size: Optional[dict[str, int]]
    default_to_square: Optional[bool]
    resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
    do_center_crop: Optional[bool]
    crop_size: Optional[dict[str, int]]
    do_rescale: Optional[bool]
    rescale_factor: Optional[Union[int, float]]
    do_normalize: Optional[bool]
    image_mean: Optional[Union[float, list[float]]]
    image_std: Optional[Union[float, list[float]]]
    do_convert_rgb: Optional[bool]
    return_tensors: Optional[Union[str, TensorType]]
    data_format: Optional[ChannelDimension]
    input_data_format: Optional[Union[str, ChannelDimension]]
    device: Optional["torch.device"]


BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""

    Args:
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `self.size`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
            Whether to default to a square image when resizing, if size is an int.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""

BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Describes the maximum input dimensions to the model.
        resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
            Whether to center crop the image.
        crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
            Size of the output image after applying `center_crop`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
            Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
        input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        device (`torch.device`, *optional*, defaults to `self.device`):
            The device to process the images on. If unset, the device is inferred from the input images."""


@add_start_docstrings(
    "Constructs a fast base image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class BaseImageProcessorFast(BaseImageProcessor):
    # Class-level defaults. `__init__` and `preprocess` fall back to these when a kwarg is not provided.
    resample = None
    image_mean = None
    image_std = None
    size = None
    default_to_square = True
    crop_size = None
    do_resize = None
    do_center_crop = None
    do_rescale = None
    rescale_factor = 1 / 255
    do_normalize = None
    do_convert_rgb = None
    return_tensors = None
    data_format = ChannelDimension.FIRST
    input_data_format = None
    device = None
    model_input_names = ["pixel_values"]
    # TypedDict whose annotations define the accepted kwargs of `__init__` and `preprocess`.
    valid_kwargs = DefaultFastImageProcessorKwargs
    # Names of kwargs that are dropped (with a warning) by `filter_out_unused_kwargs`.
    unused_kwargs = None

    def __init__(
        self,
        **kwargs: Unpack[DefaultFastImageProcessorKwargs],
    ) -> None:
        super().__init__(**kwargs)
        kwargs = self.filter_out_unused_kwargs(kwargs)
        # `size` and `crop_size` are normalized through `get_size_dict`; a missing value stays `None`.
        size = kwargs.pop("size", self.size)
        self.size = (
            get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
            if size is not None
            else None
        )
        crop_size = kwargs.pop("crop_size", self.crop_size)
        self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
        # Every valid kwarg becomes an instance attribute: the provided value if it is not None, otherwise
        # the current attribute value (class default, or the value set just above for size/crop_size).
        for key in self.valid_kwargs.__annotations__:
            kwarg = kwargs.pop(key, None)
            if kwarg is not None:
                setattr(self, key, kwarg)
            else:
                setattr(self, key, getattr(self, key, None))

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image according to `size`.

        The target size is derived from the first matching combination of `size` fields, in this order:
        `shortest_edge` and `longest_edge`, `shortest_edge` alone, `max_height` and `max_width`, `height` and `width`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size specification of the output image, e.g. `{"height": int, "width": int}`.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Forwarded as the `antialias` argument of the torchvision `resize` function.

        Returns:
            `torch.Tensor`: The resized image.

        Raises:
            `ValueError`: If `size` contains none of the supported field combinations.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size.shortest_edge,
                size.longest_edge,
            )
        elif size.shortest_edge:
            new_size = get_resize_output_image_size(
                image,
                size=size.shortest_edge,
                default_to_square=False,
                input_data_format=ChannelDimension.FIRST,
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
        elif size.height and size.width:
            new_size = (size.height, size.width)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size}."
            )
        return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

    def rescale(
        self,
        image: "torch.Tensor",
        scale: float,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.

        Returns:
            `torch.Tensor`: The rescaled image.
        """
        return image * scale

    def normalize(
        self,
        image: "torch.Tensor",
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        **kwargs,
    ) -> "torch.Tensor":
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`torch.Tensor`):
                Image to normalize.
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Image standard deviation to use for normalization.

        Returns:
            `torch.Tensor`: The normalized image.
        """
        return F.normalize(image, mean, std)

    # NOTE(review): `lru_cache` on an instance method makes `self` part of the cache key, so the cache keeps
    # references to the instances it has seen (up to `maxsize` entries). All arguments must also be hashable,
    # which is why `image_mean` / `image_std` are expected as tuples here. Left as is to preserve behavior.
    @lru_cache(maxsize=10)
    def _fuse_mean_std_and_rescale_factor(
        self,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        device: Optional["torch.device"] = None,
    ) -> tuple:
        """
        Fold the rescale factor into the normalization statistics when both steps are enabled.

        Since `(x * s - mean) / std == (x - mean / s) / (std / s)`, dividing `image_mean` and `image_std` by the
        rescale factor lets a single normalization replace rescale followed by normalize.

        Returns:
            `tuple`: `(image_mean, image_std, do_rescale)`. When both `do_rescale` and `do_normalize` are truthy, the
            statistics are returned as tensors on `device` and `do_rescale` is `False`; otherwise the inputs are
            returned unchanged.
        """
        if do_rescale and do_normalize:
            # Fused rescale and normalize
            image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
            image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
            do_rescale = False
        return image_mean, image_std, do_rescale

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.

        When both steps are requested they are fused into a single normalization (see
        `_fuse_mean_std_and_rescale_factor`). Images are cast to `torch.float32` before normalization.
        """
        image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            device=images.device,
        )
        # if/elif as we use fused rescale and normalize if both are set to True
        if do_normalize:
            images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
        elif do_rescale:
            images = self.rescale(images, rescale_factor)

        return images

    def center_crop(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Center crop an image to `(size.height, size.width)`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`"torch.Tensor"`):
                Image to center crop.
            size (`SizeDict`):
                Size of the output image. Both `height` and `width` must be set.

        Returns:
            `torch.Tensor`: The center cropped image.

        Raises:
            `ValueError`: If `size.height` or `size.width` is `None`.
        """
        if size.height is None or size.width is None:
            # Format `size` itself (as `resize` does) rather than calling a dict-only method on it, so that
            # building the message cannot fail and hide the intended `ValueError`.
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size}")
        return F.center_crop(image, (size.height, size.width))

    def convert_to_rgb(
        self,
        image: ImageInput,
    ) -> ImageInput:
        """
        Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
        as is.

        Args:
            image (ImageInput):
                The image to convert.

        Returns:
            ImageInput: The converted image.
        """
        return convert_to_rgb(image)

    def filter_out_unused_kwargs(self, kwargs: dict):
        """
        Filter out the unused kwargs from the kwargs dictionary.

        Every name listed in `self.unused_kwargs` that is present in `kwargs` is removed in place, with a one-time
        warning. The same dictionary object is returned.
        """
        if self.unused_kwargs is None:
            return kwargs

        for kwarg_name in self.unused_kwargs:
            if kwarg_name in kwargs:
                logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
                kwargs.pop(kwarg_name)
        return kwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
    ) -> ImageInput:
        """
        Prepare the images structure for processing.

        Args:
            images (`ImageInput`):
                The input images to process.

        Returns:
            `ImageInput`: The images with a valid nesting.
        """
        return make_flat_list_of_images(images)

    def _process_image(
        self,
        image: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
    ) -> "torch.Tensor":
        """
        Convert a single PIL / NumPy / torch image into a channels-first `torch.Tensor`.

        Args:
            image (`ImageInput`):
                The image to convert.
            do_convert_rgb (`bool`, *optional*):
                Whether to call `self.convert_to_rgb` on the image before converting it to a tensor.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of the input. Inferred from the image if unset.
            device (`torch.device`, *optional*):
                Device to move the resulting tensor to. The tensor is not moved if unset.

        Returns:
            `torch.Tensor`: The image as a tensor.

        Raises:
            `ValueError`: If the image is not a PIL image, a torch tensor or a NumPy array.
        """
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
            raise ValueError(f"Unsupported input image type {image_type}")

        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        if image_type == ImageType.PIL:
            image = F.pil_to_tensor(image)
        elif image_type == ImageType.NUMPY:
            # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
            image = torch.from_numpy(image).contiguous()

        # Infer the channel dimension format if not provided
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        if input_data_format == ChannelDimension.LAST:
            # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
            image = image.permute(2, 0, 1).contiguous()

        # Now that we have torch tensors, we can move them to the right device
        if device is not None:
            image = image.to(device)

        return image

    def _prepare_input_images(
        self,
        images: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
    ) -> list["torch.Tensor"]:
        """
        Prepare the input images for processing.

        The input is flattened with `_prepare_images_structure`, then each image goes through `_process_image`.
        """
        images = self._prepare_images_structure(images)
        process_image_fn = partial(
            self._process_image,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )
        # todo: yoni - check if we can parallelize this efficiently
        return [process_image_fn(image) for image in images]

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        crop_size: Optional[SizeDict] = None,
        default_to_square: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.

        `size` and `crop_size` are converted to `SizeDict`, list-valued `image_mean` / `image_std` are converted to
        tuples (the downstream `lru_cache`-decorated helpers need hashable arguments), and `data_format` defaults to
        `ChannelDimension.FIRST`.

        Returns:
            `dict`: The remaining kwargs, updated with the processed values.
        """
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if crop_size is not None:
            crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["crop_size"] = crop_size
        kwargs["default_to_square"] = default_to_square
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        return kwargs

    def _validate_preprocess_kwargs(
        self,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, tuple[float]]] = None,
        image_std: Optional[Union[float, tuple[float]]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[SizeDict] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[SizeDict] = None,
        resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ):
        """
        validate the kwargs for the preprocess method.

        Forwards the named arguments to `validate_fast_preprocess_arguments`; any extra kwargs are ignored.
        """
        validate_fast_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            resample=resample,
            return_tensors=return_tensors,
            data_format=data_format,
        )

    @add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
    def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Prepare input images
        images = self._prepare_input_images(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )

        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # torch resize uses interpolation instead of resample
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")

        return self._preprocess(images=images, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Run resize, center crop, rescale and normalize on already-prepared tensors.

        Images are grouped by shape so each step runs once per group of same-shaped images, then put back in their
        original order. If `return_tensors` is truthy the processed images are stacked into a single tensor.

        Returns:
            `BatchFeature`: A batch feature whose `pixel_values` entry holds the processed images.
        """
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self):
        """
        Return the parent class's dictionary representation without its `_valid_processor_keys` entry.
        """
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        return encoder_dict


class SemanticSegmentationMixin:
    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
        """
        Converts the output of a semantic segmentation model (e.g. [`MobileNetV2ForSemanticSegmentation`]) into
        semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`MobileNetV2ForSemanticSegmentation`]):
                Raw outputs of the model. Only the `logits` attribute is read.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.

        Raises:
            `ValueError`: If `target_sizes` is given and its length differs from the batch dimension of the logits.
        """
        logits = outputs.logits

        # Without target sizes, take the argmax over the class dimension and split the batch into a list.
        if target_sizes is None:
            return list(logits.argmax(dim=1))

        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

        # Resize each image's logits to its own target size, then take the argmax over the class dimension.
        semantic_segmentation = []
        for image_logits, target_size in zip(logits, target_sizes):
            resized_logits = torch.nn.functional.interpolate(
                image_logits.unsqueeze(dim=0), size=target_size, mode="bilinear", align_corners=False
            )
            semantic_segmentation.append(resized_logits[0].argmax(dim=0))

        return semantic_segmentation
Memory