# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import os
from collections.abc import Iterable
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Callable, Optional, Union
import numpy as np
import requests
from packaging import version
from .utils import (
ExplicitEnum,
is_av_available,
is_cv2_available,
is_decord_available,
is_jax_tensor,
is_numpy_array,
is_tf_tensor,
is_torch_available,
is_torch_tensor,
is_torchvision_available,
is_vision_available,
is_yt_dlp_available,
logging,
requires_backends,
to_numpy,
)
from .utils.constants import ( # noqa: F401
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PILImageResampling = PIL.Image.Resampling
else:
PILImageResampling = PIL.Image
if is_torchvision_available():
from torchvision import io as torchvision_io
from torchvision.transforms import InterpolationMode
pil_torch_interpolation_mapping = {
PILImageResampling.NEAREST: InterpolationMode.NEAREST,
PILImageResampling.BOX: InterpolationMode.BOX,
PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
}
if TYPE_CHECKING:
if is_torch_available():
import torch
# Module-level logger, used below to warn about ambiguous channel layouts.
logger = logging.get_logger(__name__)

# Type alias for image inputs: a single PIL image, NumPy array or torch tensor, or a list of any one of those.
# PIL and torch are referenced as strings so the alias can be built without importing them.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
]  # noqa

# Type alias for video inputs: a list of PIL frames, a single array/tensor, a list of arrays/tensors,
# or a nested list (list of lists) of PIL frames, arrays or tensors.
VideoInput = Union[
    list["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    list["np.ndarray"],
    list["torch.Tensor"],
    list[list["PIL.Image.Image"]],
    list[list["np.ndarray"]],
    list[list["torch.Tensor"]],
]  # noqa
class ChannelDimension(ExplicitEnum):
    """
    Position of the channel axis in an image array, as used by the shape helpers in this module.
    """

    # Channels precede height and width (height/width are the last two axes).
    FIRST = "channels_first"
    # Channels are the last axis (height/width are the two axes before it).
    LAST = "channels_last"
class AnnotationFormat(ExplicitEnum):
    """
    Identifiers of the annotation formats named in this module (COCO detection and COCO panoptic).
    """

    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"
class AnnotionFormat(ExplicitEnum):
    """
    Enum exposing the same member names and values as `AnnotationFormat`.

    NOTE(review): the class name looks like a misspelling of `AnnotationFormat`, presumably kept so that existing
    imports keep working -- confirm before removing.
    """

    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
@dataclass
class VideoMetadata:
    """
    Basic properties of a video, filled in by the `read_video_*` decoders of this module.

    Besides the declared fields, the decoders also set a `frames_indices` attribute on the instance holding the
    indices returned by the sampling function.
    """

    # Number of frames reported by the decoding backend.
    total_num_frames: int
    # Frame rate reported by the decoding backend.
    fps: float
    # Computed by the decoders as `total_num_frames / fps`, or 0 when the reported fps is falsy.
    duration: float
    # Name of the backend that produced the metadata ("opencv", "decord", "pyav" or "torchvision").
    video_backend: str
# Type alias for one annotation: a dict with string keys whose values are ints, strings or lists of dicts.
AnnotationType = dict[str, Union[int, str, list[dict]]]
def is_pil_image(img):
    """
    Tell whether `img` is a `PIL.Image.Image`.

    The vision backend is checked first so that `PIL` is never referenced when it is not installed.
    """
    if not is_vision_available():
        return False
    return isinstance(img, PIL.Image.Image)
class ImageType(ExplicitEnum):
    """
    Kinds of image container reported by `get_image_type`.
    """

    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
def get_image_type(image):
    """
    Return the `ImageType` member matching the container type of `image`.

    The detectors are tried in a fixed order (PIL, torch, NumPy, TensorFlow, JAX) and the first match wins.

    Raises:
        `ValueError`: If `image` is none of the recognised types.
    """
    detectors = (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
        (is_tf_tensor, ImageType.TENSORFLOW),
        (is_jax_tensor, ImageType.JAX),
    )
    for matches, image_type in detectors:
        if matches(image):
            return image_type
    raise ValueError(f"Unrecognised image type {type(image)}")
def is_valid_image(img):
    """
    Tell whether `img` is a PIL image, a NumPy array, or a torch / TensorFlow / JAX tensor.
    """
    checks = (is_pil_image, is_numpy_array, is_torch_tensor, is_tf_tensor, is_jax_tensor)
    # `any` over a generator stops at the first matching check, like the original `or` chain.
    return any(check(img) for check in checks)
def is_valid_list_of_images(images: list):
    """
    Tell whether `images` is a non-empty sequence in which every element is a valid image.

    Args:
        images (`list`):
            The candidate list of images.

    Returns:
        `bool`: `True` if `images` is non-empty and every element passes `is_valid_image`, `False` otherwise.
        An empty sequence yields `False` (a real boolean, not the empty sequence itself).
    """
    # `bool(...)` keeps the result a boolean; `images and ...` would return the empty container itself.
    return bool(images) and all(is_valid_image(image) for image in images)
def valid_images(imgs):
    """
    Recursively check that `imgs` contains only valid images.

    A list or tuple is valid when every element is itself valid (so arbitrarily nested lists are accepted and an
    empty list is valid). Anything else is treated as a single image or batched tensor and checked with
    `is_valid_image`.
    """
    if isinstance(imgs, (list, tuple)):
        return all(valid_images(item) for item in imgs)
    # Not a list or tuple: a single image or a batched tensor of images.
    return is_valid_image(imgs)
def is_batched(img):
    """
    Tell whether `img` is a batch given as a list or tuple of images.

    Only the first element is inspected. An empty list or tuple is reported as not batched instead of raising
    `IndexError`.

    Returns:
        `bool`: `True` if `img` is a non-empty list/tuple whose first element is a valid image.
    """
    if isinstance(img, (list, tuple)):
        # Guard against empty sequences before looking at the first element.
        return len(img) > 0 and is_valid_image(img[0])
    return False
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].

    `uint8` images are never considered rescaled. For any other dtype the actual value range is inspected, because
    an image may hold values in [0, 255] while being of floating type.
    """
    if image.dtype != np.uint8:
        return np.min(image) >= 0 and np.max(image) <= 1
    return False
def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
    """
    Turn the input into a list of images.

    A list/tuple of images is returned unchanged, a single image becomes a one-element list, and a batched
    array/tensor is split along its first axis.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. An array with one more dimension is treated
            as a batch; any other number of dimensions raises an error.

    Raises:
        `ValueError`: If the input is not a supported image type, or has an unexpected number of dimensions.
    """
    if is_batched(images):
        return images

    # PIL images are never batched.
    if is_pil_image(images):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
            f"jax.ndarray, but got {type(images)}."
        )

    if images.ndim == expected_ndims + 1:
        # Batch of images: iterate over the leading axis.
        return list(images)
    if images.ndim == expected_ndims:
        # Single image.
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
def make_flat_list_of_images(
    images: Union[list[ImageInput], ImageInput],
) -> ImageInput:
    """
    Turn the input into a flat list of images.

    Args:
        images (`Union[List[ImageInput], ImageInput]`):
            A single image, a 4d batch, a list of images / 4d batches, or a nested list of images.

    Returns:
        list: A flat list of images. A list of PIL images or 3d arrays is returned as is.

    Raises:
        `ValueError`: If the input matches none of the supported layouts.
    """
    is_sequence = isinstance(images, (list, tuple))

    # Nested list of images: concatenate the inner lists.
    if (
        is_sequence
        and all(isinstance(group, (list, tuple)) for group in images)
        and all(is_valid_list_of_images(group) for group in images)
    ):
        return [img for group in images for img in group]

    if is_sequence and is_valid_list_of_images(images):
        first = images[0]
        if is_pil_image(first) or first.ndim == 3:
            # Already a flat list of single images.
            return images
        if first.ndim == 4:
            # List of batches: unroll every batch along its first axis.
            return [img for batch in images for img in batch]

    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == 3:
            return [images]
        if images.ndim == 4:
            return list(images)

    raise ValueError(f"Could not make a flat list of images from {images}")
def make_nested_list_of_images(
    images: Union[list[ImageInput], ImageInput],
) -> ImageInput:
    """
    Turn the input into a nested list (list of lists) of images.

    Args:
        images (`Union[List[ImageInput], ImageInput]`):
            A single image, a 4d batch, a list of images / 4d batches, or an already nested list of images.

    Returns:
        list: A list of lists of images. An already nested list is returned as is.

    Raises:
        `ValueError`: If the input matches none of the supported layouts.
    """
    is_sequence = isinstance(images, (list, tuple))

    # A list of lists of valid images is already in the target format.
    if (
        is_sequence
        and all(isinstance(group, (list, tuple)) for group in images)
        and all(is_valid_list_of_images(group) for group in images)
    ):
        return images

    # A flat list of images is one batch; a list of 4d arrays is one batch per array.
    if is_sequence and is_valid_list_of_images(images):
        first = images[0]
        if is_pil_image(first) or first.ndim == 3:
            return [images]
        if first.ndim == 4:
            return [list(batch) for batch in images]

    # A single image (or single 4d array) becomes one batch.
    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == 3:
            return [[images]]
        if images.ndim == 4:
            return [list(images)]

    raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
def make_batched_videos(videos) -> VideoInput:
    """
    Turn the input into a list of videos.

    Args:
        videos (`VideoInput`):
            Video or videos to turn into a list of videos.

    Returns:
        list: A list of videos.

    Raises:
        `ValueError`: If the input matches none of the supported layouts.
    """
    is_sequence = isinstance(videos, (list, tuple))

    if is_sequence and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
        first_item = videos[0][0]
        if not is_pil_image(first_item) and first_item.ndim == 4:
            # Nested batch of 4d videos: flatten each inner batch.
            return [[frame for batch in group for frame in batch] for group in videos]
        # Already a list of videos, each given as a list of frames.
        return videos

    if is_sequence and is_valid_image(videos[0]):
        first_item = videos[0]
        if is_pil_image(first_item) or first_item.ndim == 3:
            # A flat list of frames is a single video.
            return [videos]
        if first_item.ndim == 4:
            # A list of 4d arrays: one video per array.
            return [list(video) for video in videos]
    elif is_valid_image(videos):
        if is_pil_image(videos) or videos.ndim == 3:
            # A single frame becomes a one-frame video.
            return [[videos]]
        if videos.ndim == 4:
            return [list(videos)]

    raise ValueError(f"Could not make batched video from {videos}")
def to_numpy_array(img) -> np.ndarray:
    """
    Convert a valid image to a NumPy array.

    PIL images are converted with `np.array`; every other supported type goes through `to_numpy`.

    Raises:
        `ValueError`: If `img` is not a valid image.
    """
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")
    return np.array(img) if is_pil_image(img) else to_numpy(img)
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of. Must have 3 or 4 dimensions.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image. When both candidate axes match, a warning is logged and
        `ChannelDimension.FIRST` is returned.

    Raises:
        `ValueError`: If the image has an unsupported number of dimensions or neither candidate axis matches.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    # Candidate (channels-first, channels-last) axes; a 4d input carries a leading batch axis.
    candidate_axes = {3: (0, 2), 4: (1, 3)}
    if image.ndim not in candidate_axes:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    first_dim, last_dim = candidate_axes[image.ndim]

    first_matches = image.shape[first_dim] in num_channels
    last_matches = image.shape[last_dim] in num_channels

    if first_matches:
        if last_matches:
            logger.warning(
                f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
            )
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image: `image.ndim - 3` for channels-first, `image.ndim - 1` for
        channels-last.

    Raises:
        `ValueError`: If the data format is neither channels-first nor channels-last.
    """
    data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format

    if data_format == ChannelDimension.FIRST:
        axis = image.ndim - 3
    elif data_format == ChannelDimension.LAST:
        axis = image.ndim - 1
    else:
        raise ValueError(f"Unsupported data format: {data_format}")
    return axis
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.

    Raises:
        `ValueError`: If `channel_dim` is neither channels-first nor channels-last.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    # Height and width are read from the end of the shape, so a leading batch axis is ignored.
    if channel_dim == ChannelDimension.FIRST:
        height, width = image.shape[-2], image.shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        height, width = image.shape[-3], image.shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
    return height, width
def get_image_size_for_max_height_width(
    image_size: tuple[int, int],
    max_height: int,
    max_width: int,
) -> tuple[int, int]:
    """
    Computes the aspect-ratio-preserving output size that fits inside `(max_height, max_width)`.

    The image is always scaled so that at least one edge reaches its maximum, even when the input is smaller than
    both limits (i.e. small images are scaled up).

    For example:
    - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
    - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`Tuple[int, int]`):
            The (height, width) of the image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.

    Returns:
        `Tuple[int, int]`: The new (height, width), each truncated to an integer.
    """
    height, width = image_size
    # The smaller of the two ratios is the largest scale at which both edges still fit.
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
def is_valid_annotation_coco_detection(annotation: dict[str, Union[list, tuple]]) -> bool:
    """
    Check that `annotation` has the structure expected of a COCO detection annotation.

    It must be a dict with an `"image_id"` key and an `"annotations"` key holding a list or tuple that is either
    empty (an image can have no annotations) or starts with a dict. Only the first entry is inspected.
    """
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "annotations" not in annotation:
        return False

    entries = annotation["annotations"]
    if not isinstance(entries, (list, tuple)):
        return False
    return len(entries) == 0 or isinstance(entries[0], dict)
def is_valid_annotation_coco_panoptic(annotation: dict[str, Union[list, tuple]]) -> bool:
    """
    Check that `annotation` has the structure expected of a COCO panoptic annotation.

    It must be a dict with `"image_id"`, `"segments_info"` and `"file_name"` keys, where `"segments_info"` is a
    list or tuple that is either empty (an image can have no segments) or starts with a dict. Only the first entry
    is inspected.
    """
    if not isinstance(annotation, dict):
        return False
    required_keys = ("image_id", "segments_info", "file_name")
    if not all(key in annotation for key in required_keys):
        return False

    segments = annotation["segments_info"]
    if not isinstance(segments, (list, tuple)):
        return False
    return len(segments) == 0 or isinstance(segments[0], dict)
def valid_coco_detection_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
    """
    Tell whether every element of `annotations` is a valid COCO detection annotation.

    Stops at the first invalid annotation. An empty iterable is considered valid.
    """
    for annotation in annotations:
        if not is_valid_annotation_coco_detection(annotation):
            return False
    return True
def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
    """
    Tell whether every element of `annotations` is a valid COCO panoptic annotation.

    Stops at the first invalid annotation. An empty iterable is considered valid.
    """
    for annotation in annotations:
        if not is_valid_annotation_coco_panoptic(annotation):
            return False
    return True
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    A string input is resolved in this order: an `http://` / `https://` URL, a path to an existing local file,
    and finally a base64 encoded image (optionally prefixed with a `data:image/...,` header).

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image, EXIF-transposed and converted to RGB.

    Raises:
        `ValueError`: If a string input is neither a URL, an existing file, nor decodable base64 image data.
        `TypeError`: If `image` is neither a string nor a PIL image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        # A real protocol prefix is required, otherwise a local file such as `http_huggingface_co.png`
        # could never be loaded.
        if image.startswith(("http://", "https://")):
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            # Last resort: treat the string as base64. Stripping the data-URI header happens inside the `try`
            # so that a malformed header is reported as a `ValueError` like any other bad source.
            try:
                if image.startswith("data:image/"):
                    image = image.split(",")[1]
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
    elif not isinstance(image, PIL.Image.Image):
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
    """
    Default frame sampling: evenly spaced indices over the whole video.

    Args:
        metadata (`VideoMetadata`):
            `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. Takes priority over `fps` if both are provided.
        fps (`int`, *optional*):
            Desired frames per second. Only used when `num_frames` is `None`, in which case the number of frames
            is derived from the video length and frame rate.

    Returns:
        `np.ndarray`: Array of frame indices to sample. All frame indices are returned when neither `num_frames`
        nor `fps` is given.

    Raises:
        `ValueError`: If the video frame rate is missing/zero while `fps` is requested, if the number of frames
        derived from `fps` exceeds the video length, or if the number of frames to sample is zero.
    """
    total_num_frames = metadata.total_num_frames
    video_fps = metadata.fps

    # If num_frames is not given but fps is, calculate num_frames from fps
    if num_frames is None and fps is not None:
        if not video_fps:
            # Without a usable frame rate the division below is undefined.
            raise ValueError(
                f"Cannot derive the number of frames from fps={fps} because the video metadata reports "
                f"fps={video_fps}. Pass `num_frames` instead."
            )
        num_frames = int(total_num_frames / video_fps * fps)
        if num_frames > total_num_frames:
            raise ValueError(
                f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
                f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
            )

    if num_frames is None:
        return np.arange(0, total_num_frames, dtype=int)

    if num_frames == 0:
        # A zero frame count would otherwise surface as a ZeroDivisionError when computing the step.
        raise ValueError(
            f"Cannot sample num_frames=0 frames from the video (fps={fps}, total_num_frames={total_num_frames}). "
            "Check fps, num_frames or video metadata."
        )
    return np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
def read_video_opencv(
    video_path: str,
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Decode a video using the OpenCV backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`. If not provided, `default_sample_indices_fn` is used
            (uniform sampling driven by the `num_frames` / `fps` keyword arguments).
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object, with the sampled indices stored in `frames_indices`.
    """
    # Lazy import cv2
    requires_backends(read_video_opencv, ["cv2"])
    import cv2

    if sample_indices_fn is None:
        sample_indices_fn = default_sample_indices_fn

    video = cv2.VideoCapture(video_path)
    try:
        total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = video.get(cv2.CAP_PROP_FPS)
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames),
            fps=float(video_fps),
            duration=float(duration),
            video_backend="opencv",
        )
        indices = sample_indices_fn(metadata=metadata, **kwargs)
        # Build the lookup once: membership in a set is O(1), while `in` on an array scans it for every frame.
        wanted = set(np.asarray(indices).ravel().tolist())

        frames = []
        index = 0
        while video.isOpened():
            success, frame = video.read()
            if not success:
                break
            if index in wanted:
                # OpenCV decodes to BGR; callers expect RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            index += 1
            if index >= total_num_frames:
                break
    finally:
        # Always free the capture handle, even if sampling or decoding raised.
        video.release()

    metadata.frames_indices = indices
    return np.stack(frames), metadata
def read_video_decord(
    video_path: str,
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Decode a video using the Decord backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`. If not provided, `default_sample_indices_fn` is used
            (uniform sampling driven by the `num_frames` / `fps` keyword arguments).
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object, with the sampled indices stored in `frames_indices`.
    """
    # Lazy import from decord
    requires_backends(read_video_decord, ["decord"])
    from decord import VideoReader, cpu

    # The parameter defaults to `None`; fall back to the default sampler instead of calling `None`.
    if sample_indices_fn is None:
        sample_indices_fn = default_sample_indices_fn

    vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
    video_fps = vr.get_avg_fps()
    total_num_frames = len(vr)
    duration = total_num_frames / video_fps if video_fps else 0
    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)

    frames = vr.get_batch(indices).asnumpy()
    metadata.frames_indices = indices
    return frames, metadata
def read_video_pyav(
    video_path: str,
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Decode the video with PyAV decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`. If not provided, `default_sample_indices_fn` is used
            (uniform sampling driven by the `num_frames` / `fps` keyword arguments). Decoding stops after the
            last element of the returned indices, so they are expected in increasing order.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object, with the sampled indices stored in `frames_indices`.
    """
    # Lazy import av
    requires_backends(read_video_pyav, ["av"])
    import av

    if sample_indices_fn is None:
        sample_indices_fn = default_sample_indices_fn

    # The context manager closes the container (and its file handle) even if decoding fails.
    with av.open(video_path) as container:
        total_num_frames = container.streams.video[0].frames
        video_fps = container.streams.video[0].average_rate  # should we better use `av_guess_frame_rate`?
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
        )
        indices = sample_indices_fn(metadata=metadata, **kwargs)
        # Build the lookup once: membership in a set is O(1), while `in` on an array scans it for every frame.
        wanted = set(np.asarray(indices).ravel().tolist())

        frames = []
        container.seek(0)
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i in wanted:
                # Convert while the container is still open.
                frames.append(frame.to_ndarray(format="rgb24"))

    video = np.stack(frames)
    metadata.frames_indices = indices
    return video, metadata
def read_video_torchvision(
    video_path: str,
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Decode the video with torchvision decoder.

    The whole video is decoded first and the sampled frames are selected afterwards.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`. If not provided, `default_sample_indices_fn` is used
            (uniform sampling driven by the `num_frames` / `fps` keyword arguments).
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object, with the sampled indices stored in `frames_indices`.
    """
    if sample_indices_fn is None:
        sample_indices_fn = default_sample_indices_fn

    video, _, info = torchvision_io.read_video(
        video_path,
        start_pts=0.0,
        end_pts=None,
        pts_unit="sec",
        output_format="THWC",
    )
    video_fps = info["video_fps"]
    total_num_frames = video.size(0)
    duration = total_num_frames / video_fps if video_fps else 0
    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames),
        fps=float(video_fps),
        duration=float(duration),
        video_backend="torchvision",
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)

    video = video[indices].contiguous().numpy()
    metadata.frames_indices = indices
    return video, metadata
# Maps each `backend` name accepted by `load_video` to the function that decodes a video with that backend.
VIDEO_DECODERS = {
    "decord": read_video_decord,
    "opencv": read_video_opencv,
    "pyav": read_video_pyav,
    "torchvision": read_video_torchvision,
}
def load_video(
    video: Union[str, "VideoInput"],
    num_frames: Optional[int] = None,
    fps: Optional[int] = None,
    backend: str = "opencv",
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
) -> np.array:
    """
    Loads `video` to a numpy array.

    Args:
        video (`str` or `VideoInput`):
            The video to convert to the numpy array format. Can be a link to video or local path. An already
            decoded video (an image-like array/tensor, or a list of frames) is returned unchanged.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not passed, the whole video is loaded.
        fps (`int`, *optional*):
            Number of frames to sample per second. Should be passed only when `num_frames=None`.
            If not specified and `num_frames==None`, all frames are sampled.
        backend (`str`, *optional*, defaults to `"opencv"`):
            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"].
            Videos given as URLs can only be decoded with "pyav" or "decord".
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. If the video has to
            be sampled differently than what `num_frames` or `fps` provide, pass your own `sample_indices_fn`.
            If not provided, simple uniform sampling with fps is performed, otherwise `sample_indices_fn` has
            priority over other args. The function receives the video metadata along with all kwargs passed to
            `load_video` and should output valid indices at which the video should be sampled. For example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`np.array`, `VideoMetadata`]: For a URL or file path, a tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - Metadata object returned by the decoder.
        For an already decoded input, the input itself (no metadata).

    Raises:
        `ValueError`: If both `fps` and `num_frames` are given without a custom `sample_indices_fn`, or if a URL
            is combined with the "opencv" or "torchvision" backend.
        `TypeError`: If `video` is neither a URL, an existing file path, nor an already decoded video.
        `ImportError`: If the library required by `backend` (or `yt_dlp` for YouTube links) is not installed.
    """
    # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
    if fps is not None and num_frames is not None and sample_indices_fn is None:
        raise ValueError(
            "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
        )

    # If user didn't pass a sampling function, create one on the fly with default logic
    if sample_indices_fn is None:

        def sample_indices_fn_func(metadata, **fn_kwargs):
            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)

        sample_indices_fn = sample_indices_fn_func

    # Already decoded input must be handled before any string method is called on `video`,
    # otherwise arrays and frame lists fail with an AttributeError on `startswith`.
    if not isinstance(video, str):
        if is_valid_image(video) or (
            isinstance(video, (list, tuple)) and len(video) > 0 and is_valid_image(video[0])
        ):
            return video
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    # URLs can be decoded with pyav or decord, but not with cv2/torchvision.
    # Checked up front so that nothing is downloaded when the call cannot succeed.
    video_is_url = video.startswith(("http://", "https://"))
    if video_is_url and backend in ["opencv", "torchvision"]:
        raise ValueError(
            "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
        )

    if video.startswith(("https://www.youtube.com", "http://www.youtube.com")):
        if not is_yt_dlp_available():
            raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
        # Lazy import from yt_dlp
        requires_backends(load_video, ["yt_dlp"])
        from yt_dlp import YoutubeDL

        # NOTE(review): the buffer receives whatever `YoutubeDL` writes to stdout during the download;
        # confirm that this is the video payload for the default options used here.
        buffer = BytesIO()
        with redirect_stdout(buffer), YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video_is_url:
        file_obj = BytesIO(requests.get(video).content)
    elif os.path.isfile(video):
        file_obj = video
    else:
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_av_available() and backend == "pyav")
        or (not is_cv2_available() and backend == "opencv")
        or (not is_torchvision_available() and backend == "torchvision")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
            f"Make sure to install {backend} before loading the video."
        )

    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
    return video, metadata
def load_images(
    images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
    """Loads images with `load_image`, preserving the nesting of the input.

    Whether the input is nested is decided from its first element only.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, a list of lists of images.
    """
    if not isinstance(images, (list, tuple)):
        return load_image(images, timeout=timeout)

    is_nested = len(images) > 0 and isinstance(images[0], (list, tuple))
    if is_nested:
        return [[load_image(item, timeout=timeout) for item in group] for group in images]
    return [load_image(item, timeout=timeout) for item in images]
def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.

    Each enabled `do_*` step must come with the values it needs; the first missing one raises a `ValueError`.
    Steps are checked in this order: rescale, pad, normalize, center crop, resize.

    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.
    """
    if do_rescale:
        if rescale_factor is None:
            raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

    if do_pad:
        # Here, size_divisor might be passed as the value of size
        if size_divisibility is None:
            raise ValueError(
                "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
            )

    if do_normalize:
        if image_mean is None or image_std is None:
            raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")

    if do_center_crop:
        if crop_size is None:
            raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")

    if do_resize:
        if size is None or resample is None:
            raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Mixin that contain utilities for preparing image features.
"""
def _ensure_format_supported(self, image):
    """
    Raise a `ValueError` unless `image` is a PIL image, a NumPy array or a torch tensor.
    """
    supported = isinstance(image, (PIL.Image.Image, np.ndarray)) or is_torch_tensor(image)
    if not supported:
        raise ValueError(
            f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
            "`torch.Tensor` are."
        )
def to_pil_image(self, image, rescale=None):
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
            The image to convert to the PIL Image format. A PIL image is returned unchanged.
        rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
            default to `True` if the image type is a floating type, `False` otherwise.
    """
    self._ensure_format_supported(image)

    if is_torch_tensor(image):
        image = image.numpy()

    if not isinstance(image, np.ndarray):
        # Only PIL images reach this point; there is nothing to convert.
        return image

    # By default, rescale exactly when the array holds floating point values.
    should_rescale = isinstance(image.flat[0], np.floating) if rescale is None else rescale

    # If the channel has been moved to the first dim, put it back at the end.
    if image.ndim == 3 and image.shape[0] in [1, 3]:
        image = image.transpose(1, 2, 0)
    if should_rescale:
        image = image * 255
    return PIL.Image.fromarray(image.astype(np.uint8))
def convert_rgb(self, image):
    """
    Converts `PIL.Image.Image` to RGB format.

    Args:
        image (`PIL.Image.Image`):
            The image to convert. Supported non-PIL inputs (arrays, tensors) are returned unchanged.
    """
    self._ensure_format_supported(image)
    return image.convert("RGB") if isinstance(image, PIL.Image.Image) else image
def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
    """
    Multiply every value of `image` by `scale`, after checking that the image format is supported.
    """
    self._ensure_format_supported(image)
    rescaled = image * scale
    return rescaled
def to_numpy_array(self, image, rescale=None, channel_first=True):
"""
Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
dimension.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to convert to a NumPy array.
rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
channel_first (`bool`, *optional*, defaults to `True`):
Whether or not to permute the dimensions of the image to put the channel dimension first.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = np.array(image)
if is_torch_tensor(image):
image = image.numpy()
rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
if rescale:
image = self.rescale(image.astype(np.float32), 1 / 255.0)
if channel_first and image.ndim == 3:
image = image.transpose(2, 0, 1)
return image
def expand_dims(self, image):
"""
Expands 2-dimensional `image` to 3 dimensions.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to expand.
"""
self._ensure_format_supported(image)
# Do nothing if PIL image
if isinstance(image, PIL.Image.Image):
return image
if is_torch_tensor(image):
image = image.unsqueeze(0)
else:
image = np.expand_dims(image, axis=0)
return image
def normalize(self, image, mean, std, rescale=False):
"""
Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
if it's a PIL Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to normalize.
mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
The mean (per channel) to use for normalization.
std (`List[float]` or `np.ndarray` or `torch.Tensor`):
The standard deviation (per channel) to use for normalization.
rescale (`bool`, *optional*, defaults to `False`):
Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
happen automatically.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = self.to_numpy_array(image, rescale=True)
# If the input image is a PIL image, it automatically gets rescaled. If it's another
# type it may need rescaling.
elif rescale:
if isinstance(image, np.ndarray):
image = self.rescale(image.astype(np.float32), 1 / 255.0)
elif is_torch_tensor(image):
image = self.rescale(image.float(), 1 / 255.0)
if isinstance(image, np.ndarray):
if not isinstance(mean, np.ndarray):
mean = np.array(mean).astype(image.dtype)
if not isinstance(std, np.ndarray):
std = np.array(std).astype(image.dtype)
elif is_torch_tensor(image):
import torch
if not isinstance(mean, torch.Tensor):
if isinstance(mean, np.ndarray):
mean = torch.from_numpy(mean)
else:
mean = torch.tensor(mean)
if not isinstance(std, torch.Tensor):
if isinstance(std, np.ndarray):
std = torch.from_numpy(std)
else:
std = torch.tensor(std)
if image.ndim == 3 and image.shape[0] in [1, 3]:
return (image - mean[:, None, None]) / std[:, None, None]
else:
return (image - mean) / std
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
"""
Resizes `image`. Enforces conversion of input to PIL.Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to resize.
size (`int` or `Tuple[int, int]`):
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
matched to this.
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
The filter to user for resampling.
default_to_square (`bool`, *optional*, defaults to `True`):
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
square (`size`,`size`). If set to `False`, will replicate
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
with support for resizing only the smallest edge and providing an optional `max_size`.
max_size (`int`, *optional*, defaults to `None`):
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
greater than `max_size` after being resized according to `size`, then the image is resized again so
that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
edge may be shorter than `size`. Only used if `default_to_square` is `False`.
Returns:
image: A resized `PIL.Image.Image`.
"""
resample = resample if resample is not None else PILImageResampling.BILINEAR
self._ensure_format_supported(image)
if not isinstance(image, PIL.Image.Image):
image = self.to_pil_image(image)
if isinstance(size, list):
size = tuple(size)
if isinstance(size, int) or len(size) == 1:
if default_to_square:
size = (size, size) if isinstance(size, int) else (size[0], size[0])
else:
width, height = image.size
# specified size only for the smallest edge
short, long = (width, height) if width <= height else (height, width)
requested_new_short = size if isinstance(size, int) else size[0]
if short == requested_new_short:
return image
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
size = (new_short, new_long) if width <= height else (new_long, new_short)
return image.resize(size, resample=resample)
def center_crop(self, image, size):
"""
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
size given, it will be padded (so the returned result has the size asked).
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
The image to resize.
size (`int` or `Tuple[int, int]`):
The size to which crop the image.
Returns:
new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
height, width).
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
size = (size, size)
# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
if is_torch_tensor(image) or isinstance(image, np.ndarray):
if image.ndim == 2:
image = self.expand_dims(image)
image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
else:
image_shape = (image.size[1], image.size[0])
top = (image_shape[0] - size[0]) // 2
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
left = (image_shape[1] - size[1]) // 2
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
# For PIL Images we have a method to crop directly.
if isinstance(image, PIL.Image.Image):
return image.crop((left, top, right, bottom))
# Check if image is in (n_channels, height, width) or (height, width, n_channels) format
channel_first = True if image.shape[0] in [1, 3] else False
# Transpose (height, width, n_channels) format images
if not channel_first:
if isinstance(image, np.ndarray):
image = image.transpose(2, 0, 1)
if is_torch_tensor(image):
image = image.permute(2, 0, 1)
# Check if cropped area is within image boundaries
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
return image[..., top:bottom, left:right]
# Otherwise, we may need to pad if the image is too small. Oh joy...
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
if isinstance(image, np.ndarray):
new_image = np.zeros_like(image, shape=new_shape)
elif is_torch_tensor(image):
new_image = image.new_zeros(new_shape)
top_pad = (new_shape[-2] - image_shape[0]) // 2
bottom_pad = top_pad + image_shape[0]
left_pad = (new_shape[-1] - image_shape[1]) // 2
right_pad = left_pad + image_shape[1]
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
top += top_pad
bottom += top_pad
left += left_pad
right += left_pad
new_image = new_image[
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
]
return new_image
def flip_channel_order(self, image):
"""
Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
`image` to a NumPy array if it's a PIL Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
be first.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = self.to_numpy_array(image)
return image[::-1, :, :]
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
"""
Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
counter clockwise around its centre.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
rotating.
Returns:
image: A rotated `PIL.Image.Image`.
"""
resample = resample if resample is not None else PIL.Image.NEAREST
self._ensure_format_supported(image)
if not isinstance(image, PIL.Image.Image):
image = self.to_pil_image(image)
return image.rotate(
angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
)
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: tuple[AnnotationFormat, ...],
    annotations: list[dict],
) -> None:
    """
    Checks that `annotation_format` is supported and that `annotations` are valid for that format.

    Args:
        annotation_format (`AnnotationFormat`):
            The format the annotations are expected to follow.
        supported_annotation_formats (`tuple[AnnotationFormat, ...]`):
            The formats accepted by the caller.
        annotations (`list[dict]`):
            The annotations to check. They are only inspected for the COCO detection and COCO panoptic formats.

    Raises:
        ValueError: If `annotation_format` is not in `supported_annotation_formats`, or if the annotations are
            rejected by the validator of the selected COCO format.
    """
    if annotation_format not in supported_annotation_formats:
        # Report the offending value itself (the builtin `format` used to be interpolated here by mistake).
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
    """
    Logs a warning naming every entry of `captured_kwargs` that is not listed in `valid_processor_keys`.

    Args:
        valid_processor_keys (`list[str]`):
            The keyword names that are recognized.
        captured_kwargs (`list[str]`):
            The keyword names that were actually received.

    Nothing is logged when every captured name is recognized.
    """
    unrecognized = set(captured_kwargs) - set(valid_processor_keys)
    if not unrecognized:
        return

    # TODO raise a warning here instead of simply logging?
    joined_names = ", ".join(unrecognized)
    logger.warning(f"Unused or unrecognized kwargs: {joined_names}.")
@dataclass(frozen=True)
class SizeDict:
    """
    Hashable dictionary to store image size information.

    Every field is optional and defaults to `None`. Besides attribute access, the fields can be read with
    dictionary-style indexing (`size["height"]`).
    """

    height: Optional[int] = None
    width: Optional[int] = None
    longest_edge: Optional[int] = None
    shortest_edge: Optional[int] = None
    max_height: Optional[int] = None
    max_width: Optional[int] = None

    def __getitem__(self, key):
        """
        Returns the value of the field named `key`.

        Raises:
            KeyError: If `key` is not the name of one of the declared fields.
        """
        # Only declared fields are valid keys. A plain `hasattr` check would also expose methods and dunder
        # attributes (e.g. `size["__class__"]`) and raise `TypeError` instead of `KeyError` for non-string keys.
        if isinstance(key, str) and key in self.__dataclass_fields__:
            return getattr(self, key)
        raise KeyError(f"Key {key} not found in SizeDict.")