diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx
index 1ec890e9e1f78..857d48f0fe6e9 100644
--- a/docs/source/en/internal/image_processing_utils.mdx
+++ b/docs/source/en/internal/image_processing_utils.mdx
@@ -19,6 +19,8 @@ Most of those are only useful if you are studying the code of the image processo
 
 ## Image Transformations
 
+[[autodoc]] image_transforms.normalize
+
 [[autodoc]] image_transforms.rescale
 
 [[autodoc]] image_transforms.resize
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 024b46911a750..04d8332be11cc 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+import warnings
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -25,11 +26,13 @@
 
 from .image_utils import (
     ChannelDimension,
+    get_channel_dimension_axis,
     get_image_size,
     infer_channel_dimension_format,
     is_jax_tensor,
     is_tf_tensor,
     is_torch_tensor,
+    to_numpy_array,
 )
 
 
@@ -257,3 +260,59 @@ def resize(
     resized_image = np.array(resized_image)
     resized_image = to_channel_dimension_format(resized_image, data_format)
     return resized_image
+
+
+def normalize(
+    image: np.ndarray,
+    mean: Union[float, Iterable[float]],
+    std: Union[float, Iterable[float]],
+    data_format: Optional[ChannelDimension] = None,
+) -> np.ndarray:
+    """
+    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+
+    image = (image - mean) / std
+
+    Args:
+        image (`np.ndarray`):
+            The image to normalize.
+        mean (`float` or `Iterable[float]`):
+            The mean to use for normalization.
+        std (`float` or `Iterable[float]`):
+            The standard deviation to use for normalization.
+        data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the output image. If `None`, will use the inferred format from the input.
+    """
+    if isinstance(image, PIL.Image.Image):
+        warnings.warn(
+            "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
+            FutureWarning,
+        )
+        # Convert PIL image to numpy array with the same logic as in the previous feature extractor normalize -
+        # casting to numpy array and dividing by 255.
+        image = to_numpy_array(image)
+        image = rescale(image, scale=1 / 255)
+
+    input_data_format = infer_channel_dimension_format(image)
+    channel_axis = get_channel_dimension_axis(image)
+    num_channels = image.shape[channel_axis]
+
+    if isinstance(mean, Iterable):
+        if len(mean) != num_channels:
+            raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
+    else:
+        mean = [mean] * num_channels
+
+    if isinstance(std, Iterable):
+        if len(std) != num_channels:
+            raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
+    else:
+        std = [std] * num_channels
+
+    if input_data_format == ChannelDimension.LAST:
+        image = (image - mean) / std
+    else:
+        image = ((image.T - mean) / std).T
+
+    image = to_channel_dimension_format(image, data_format) if data_format is not None else image
+    return image
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index 0ba86d14b7975..fdba17dc824ff 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -112,6 +112,25 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
     raise ValueError("Unable to infer channel dimension format")
 
 
+def get_channel_dimension_axis(image: np.ndarray) -> int:
+    """
+    Returns the channel dimension axis of the image.
+
+    Args:
+        image (`np.ndarray`):
+            The image to get the channel dimension axis of.
+
+    Returns:
+        The channel dimension axis of the image.
+    """
+    channel_dim = infer_channel_dimension_format(image)
+    if channel_dim == ChannelDimension.FIRST:
+        return image.ndim - 3
+    elif channel_dim == ChannelDimension.LAST:
+        return image.ndim - 1
+    raise ValueError(f"Unsupported data format: {channel_dim}")
+
+
 def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
     """
     Returns the (height, width) dimensions of the image.
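
The channels-first branch of `normalize` above relies on a NumPy transpose trick: `.T` reverses the axis order, so for a `(C, H, W)` array the channel axis becomes the trailing one, the per-channel `mean` and `std` broadcast against it, and the final `.T` restores the original layout. A minimal self-contained sketch of that arithmetic in plain NumPy (variable names here are illustrative, not part of the PR):

```python
import numpy as np

# Per-channel statistics for a 3-channel image.
mean = np.array([0.5, 0.6, 0.7])
std = np.array([0.1, 0.2, 0.3])

# Channels-last (H, W, C): broadcasting aligns mean/std with the trailing axis.
image_last = np.random.rand(4, 5, 3)
norm_last = (image_last - mean) / std

# Channels-first (C, H, W): .T reverses the axes to (W, H, C), the statistics
# broadcast over the now-trailing channel axis, and the final .T restores (C, H, W).
image_first = image_last.transpose(2, 0, 1)
norm_first = ((image_first.T - mean) / std).T

# Both layouts yield the same per-pixel values.
assert np.allclose(norm_first, norm_last.transpose(2, 0, 1))
```
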
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index 69e6de1587b8d..ee51bd358f40b 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -36,6 +36,7 @@
 
 from transformers.image_transforms import (
     get_resize_output_image_size,
+    normalize,
     resize,
     to_channel_dimension_format,
     to_pil_image,
@@ -172,3 +173,25 @@ def test_resize(self):
         self.assertIsInstance(resized_image, PIL.Image.Image)
         # PIL size is in (width, height) order
         self.assertEqual(resized_image.size, (40, 30))
+
+    def test_normalize(self):
+        image = np.random.randint(0, 256, (224, 224, 3)) / 255
+
+        # Number of mean values != number of channels
+        with self.assertRaises(ValueError):
+            normalize(image, mean=(0.5, 0.6), std=1)
+
+        # Number of std values != number of channels
+        with self.assertRaises(ValueError):
+            normalize(image, mean=1, std=(0.5, 0.6))
+
+        # Test result is correct - output data format is channels_first and normalization
+        # correctly computed
+        mean = (0.5, 0.6, 0.7)
+        std = (0.1, 0.2, 0.3)
+        expected_image = ((image - mean) / std).transpose((2, 0, 1))
+
+        normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
+        self.assertIsInstance(normalized_image, np.ndarray)
+        self.assertEqual(normalized_image.shape, (3, 224, 224))
+        self.assertTrue(np.allclose(normalized_image, expected_image))
diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py
index 0ae5d78fb2dc0..6868e117c4c38 100644
--- a/tests/utils/test_image_utils.py
+++ b/tests/utils/test_image_utils.py
@@ -20,7 +20,7 @@
 import pytest
 
 from transformers import is_torch_available, is_vision_available
-from transformers.image_utils import ChannelDimension
+from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
 from transformers.testing_utils import require_torch, require_vision
 
 
@@ -535,3 +535,26 @@ def test_infer_channel_dimension(self):
         image = np.random.randint(0, 256, (1, 3, 4, 5))
         inferred_dim = infer_channel_dimension_format(image)
         self.assertEqual(inferred_dim, ChannelDimension.FIRST)
+
+    def test_get_channel_dimension_axis(self):
+        # Test we correctly identify the channel dimension
+        image = np.random.randint(0, 256, (3, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 0)
+
+        image = np.random.randint(0, 256, (1, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 0)
+
+        image = np.random.randint(0, 256, (4, 5, 3))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 2)
+
+        image = np.random.randint(0, 256, (4, 5, 1))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 2)
+
+        # We can take a batched array of images and find the dimension
+        image = np.random.randint(0, 256, (1, 3, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 1)
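
For reviewers who want to exercise the new API end to end, a small usage sketch mirroring the tests above (assumes a build that includes this PR; the per-channel mean/std values are illustrative ImageNet statistics, not defaults of the function):

```python
import numpy as np
from transformers.image_transforms import normalize
from transformers.image_utils import get_channel_dimension_axis

image = np.random.rand(224, 224, 3)  # channels-last float image in [0, 1]

# The channel axis is inferred from the shape: axis 2 for an (H, W, C) array.
assert get_channel_dimension_axis(image) == 2

# A scalar mean/std is broadcast to every channel...
out = normalize(image, mean=0.5, std=0.5)
assert out.shape == (224, 224, 3)

# ...while per-channel statistics must match the channel count, and the output
# layout can be converted in the same call via `data_format`.
out = normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), data_format="channels_first")
assert out.shape == (3, 224, 224)
```
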