diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index 1c099cf14e2d44..ba9d3c0962e3f6 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -38,387 +38,9 @@ class BatchFeature(BaseBatchFeature):
"""
-class ImageProcessorMixin(PushToHubMixin):
- """
- Image processor mixin used to provide saving/loading functionality
- """
-
- _auto_class = None
-
- def __init__(self, **kwargs):
- """Set elements of `kwargs` as attributes."""
- # Pop "processor_class" as it should be saved as private attribute
- self._processor_class = kwargs.pop("processor_class", None)
- # Additional attributes without default values
- for key, value in kwargs.items():
- try:
- setattr(self, key, value)
- except AttributeError as err:
- logger.error(f"Can't set {key} with value {value} for {self}")
- raise err
-
- def _set_processor_class(self, processor_class: str):
- """Sets processor class as an attribute."""
- self._processor_class = processor_class
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
- r"""
- Instantiate a type of [`~image_processing_utils.ImageProcessorMixin`] from a image processor, *e.g.* a derived
- class of [`BaseImageProcessor`].
-
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- This can be either:
-
- - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
- huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
- namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
- - a path to a *directory* containing a image processor file saved using the
- [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
- `./my_model_directory/`.
- - a path or url to a saved image processor JSON *file*, e.g.,
- `./my_model_directory/preprocessor_config.json`.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model image processor should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force to (re-)download the image processor files and override the cached versions if
- they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file
- exists.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- use_auth_token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- If `False`, then this function returns just the final image processor object. If `True`, then this
- functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
- consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
- `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
- kwargs (`Dict[str, Any]`, *optional*):
- The values in kwargs of any keys which are image processor attributes will be used to override the
- loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
- controlled by the `return_unused_kwargs` keyword parameter.
-
-
-
- Passing `use_auth_token=True` is required when you want to use a private model.
-
-
-
- Returns:
- An image processor of type [`~image_processing_utils.ImageProcessorMixin`].
-
- Examples: FIXME
-
- """
- image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
-
- return cls.from_dict(image_processor_dict, **kwargs)
-
- def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
- """
- Save an image_processor object to the directory `save_directory`, so that it can be re-loaded using the
- [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
-
- Args:
- save_directory (`str` or `os.PathLike`):
- Directory where the image processor JSON file will be saved (will be created if it does not exist).
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your image processor to the Hugging Face model hub after saving it.
-
-
-
- Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
- which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
- folder. Pass along `temp_dir=True` to use a temporary directory instead.
-
-
-
- kwargs:
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- """
- if os.path.isfile(save_directory):
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
-
- if push_to_hub:
- commit_message = kwargs.pop("commit_message", None)
- repo = self._create_or_get_repo(save_directory, **kwargs)
-
- # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
- # loaded from the Hub.
- if self._auto_class is not None:
- custom_object_save(self, save_directory, config=self)
-
- os.makedirs(save_directory, exist_ok=True)
- # If we save using the predefined names, we can load using `from_pretrained`
- output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
-
- self.to_json_file(output_image_processor_file)
- logger.info(f"Image processor saved in {output_image_processor_file}")
-
- if push_to_hub:
- url = self._push_to_hub(repo, commit_message=commit_message)
- logger.info(f"Image processor pushed to the hub in this commit: {url}")
-
- return [output_image_processor_file]
-
- @classmethod
- def get_image_processor_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """
- From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
- image processor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
-
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
-
- Returns:
- `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
- """
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
- proxies = kwargs.pop("proxies", None)
- use_auth_token = kwargs.pop("use_auth_token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
-
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
-
- user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- if os.path.isdir(pretrained_model_name_or_path):
- image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
- image_processor_file = pretrained_model_name_or_path
- else:
- image_processor_file = hf_bucket_url(
- pretrained_model_name_or_path, filename=IMAGE_PROCESSOR_NAME, revision=revision, mirror=None
- )
-
- try:
- # Load from URL or cache if already cached
- resolved_image_processor_file = cached_path(
- image_processor_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
- "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
- "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
- "`use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
- f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
- "available revisions."
- )
- except EntryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {IMAGE_PROCESSOR_NAME}."
- )
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
- f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
- f" containing a {IMAGE_PROCESSOR_NAME} file.\nCheckout your internet connection or see how to run"
- " the library in offline mode at"
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load it "
- "from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a {IMAGE_PROCESSOR_NAME} file"
- )
-
- try:
- # Load image_processor dict
- with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
- text = reader.read()
- image_processor_dict = json.loads(text)
-
- except json.JSONDecodeError:
- raise EnvironmentError(
- f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
- )
-
- if resolved_image_processor_file == image_processor_file:
- logger.info(f"loading image processor configuration file {image_processor_file}")
- else:
- logger.info(
- f"loading image processor configuration file {image_processor_file} from cache at"
- f" {resolved_image_processor_file}"
- )
-
- return image_processor_dict, kwargs
-
- @classmethod
- def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
- """
- Instantiates a type of [`~image_processing_utils.ImageProcessorMixin`] from a Python dictionary of parameters.
-
- Args:
- image_processor_dict (`Dict[str, Any]`):
- Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
- retrieved from a pretrained checkpoint by leveraging the
- [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
- kwargs (`Dict[str, Any]`):
- Additional parameters from which to initialize the image processor object.
-
- Returns:
- [`~feature_extraction_utils.FeatureExtractionMixin`]: The image processor object instantiated from those
- parameters.
- """
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-
- image_processor = cls(**image_processor_dict)
-
- # Update image_processor with kwargs if needed
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(image_processor, key):
- setattr(image_processor, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
-
- logger.info(f"image processor {image_processor}")
- if return_unused_kwargs:
- return image_processor, kwargs
- else:
- return image_processor
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["image_processor_type"] = self.__class__.__name__
-
- return output
-
- @classmethod
- def from_json_file(cls, json_file: Union[str, os.PathLike]):
- """
- Instantiates an image processor of type [`~image_processing_utils.ImageProcessorMixin`] from the path to a JSON
- file of parameters.
-
- Args:
- json_file (`str` or `os.PathLike`):
- Path to the JSON file containing the parameters.
-
- Returns:
- A image processor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The image_processor object
- instantiated from that JSON file.
- """
- with open(json_file, "r", encoding="utf-8") as reader:
- text = reader.read()
- image_processor_dict = json.loads(text)
- return cls(**image_processor_dict)
-
- def to_json_string(self) -> str:
- """
- Serializes this instance to a JSON string.
-
- Returns:
- `str`: String containing all the attributes that make up this image_processor instance in JSON format.
- """
- dictionary = self.to_dict()
-
- for key, value in dictionary.items():
- if isinstance(value, np.ndarray):
- dictionary[key] = value.tolist()
-
- # make sure private name "_processor_class" is correctly
- # saved as "processor_class"
- _processor_class = dictionary.pop("_processor_class", None)
- if _processor_class is not None:
- dictionary["processor_class"] = _processor_class
-
- return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
-
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
- """
- Save this instance to a JSON file.
-
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this image_processor instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- writer.write(self.to_json_string())
-
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
-
- @classmethod
- def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
- """
- Register this class with a given auto class. This should only be used for custom image processors as the ones
- in the library are already mapped with `AutoImageProcessor`.
-
-
-
- This API is experimental and may have some slight breaking changes in the next releases.
-
-
-
- Args:
- auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
- The auto class to register this new image processor with.
- """
- if not isinstance(auto_class, str):
- auto_class = auto_class.__name__
-
- import transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
- cls._auto_class = auto_class
-
-
-ImageProcessorMixin.push_to_hub = copy_func(ImageProcessorMixin.push_to_hub)
-ImageProcessorMixin.push_to_hub.__doc__ = ImageProcessorMixin.push_to_hub.__doc__.format(
- object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
-)
+# We use aliasing while we phase out the old API. Once feature extractors for vision models
+# are deprecated, a dedicated ImageProcessorMixin will be implemented and shared logic abstracted out.
+ImageProcessorMixin = FeatureExtractionMixin
class BaseImageProcessor(ImageProcessorMixin):
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index adbbbd7f14e555..024b46911a750a 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
-import PIL
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
@@ -26,7 +25,6 @@
from .image_utils import (
ChannelDimension,
- get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
is_jax_tensor,
@@ -217,6 +215,7 @@ def resize(
size: Tuple[int, int],
resample=PIL.Image.BILINEAR,
data_format: Optional[ChannelDimension] = None,
+ return_numpy: bool = True,
) -> np.ndarray:
"""
Resizes `image` to (h, w) specified by `size` using the PIL library.
@@ -258,123 +257,3 @@ def resize(
resized_image = np.array(resized_image)
resized_image = to_channel_dimension_format(resized_image, data_format)
return resized_image
-
-
-def normalize(
- image,
- mean: Union[float, Iterable[float]],
- std: Union[float, Iterable[float]],
- data_format: Optional[ChannelDimension] = None,
-) -> np.ndarray:
- """
- Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
-
- image = (image - mean) / std
-
- Args:
- image (`np.ndarray`):
- The image to normalize.
- mean (`float` or `Iterable[float]`):
- The mean to use for normalization.
- std (`float` or `Iterable[float]`):
- The standard deviation to use for normalization.
- data_format (`ChannelDimension`, *optional*, defaults to `None`):
- The channel dimension format of the output image. If `None`, will use the inferred format from the input.
- """
- if not isinstance(image, np.ndarray):
- raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
-
- input_data_format = infer_channel_dimension_format(image)
- channel_axis = get_channel_dimension_axis(image)
- num_channels = image.shape[channel_axis]
-
- if isinstance(mean, Iterable):
- if len(mean) != num_channels:
- raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
- else:
- mean = [mean] * num_channels
-
- if isinstance(std, Iterable):
- if len(std) != num_channels:
- raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
- else:
- std = [std] * num_channels
-
- if input_data_format == ChannelDimension.LAST:
- image = (image - mean) / std
- else:
- image = ((image.T - mean) / std).T
-
- image = to_channel_dimension_format(image, data_format) if data_format is not None else image
- return image
-
-
-def center_crop(
- image: np.ndarray,
- size: Tuple[int, int],
- data_format: Optional[Union[str, ChannelDimension]] = None,
-) -> np.ndarray:
- """
- Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped
- to the size given, it will be padded (so the returned result will always be of size `size`).
-
- Args:
- image (`np.ndarray`):
- The image to crop.
- size (`Tuple[int, int]`):
- The target size for the cropped image.
- data_format (`str` or `ChannelDimension`, *optional*, defaults to `None`):
- The channel dimension format for the output image. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-
- Returns:
- `np.ndarray`: The cropped image.
- """
- if not isinstance(image, np.ndarray):
- raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
-
- if not isinstance(size, Iterable) or len(size) != 2:
- raise ValueError("size must have 2 elements representing the height and width of the output image")
-
- input_data_format = infer_channel_dimension_format(image)
- output_data_format = data_format if data_format is not None else input_data_format
-
- # We perform the crop in (C, H, W) format and then convert to the output format
- image = to_channel_dimension_format(image, ChannelDimension.FIRST)
-
- orig_height, orig_width = get_image_size(image)
- crop_height, crop_width = size
-
- top = (orig_height - crop_height) // 2
- bottom = top + crop_height # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
- left = (orig_width - crop_width) // 2
- right = left + crop_width # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
-
- # Check if cropped area is within image boundaries
- if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
- image = image[..., top:bottom, left:right]
- image = to_channel_dimension_format(image, output_data_format)
- return image
-
- # Otherwise, we may need to pad if the image is too small. Oh joy...
- new_height = max(crop_height, orig_height)
- new_width = max(crop_width, orig_width)
- new_shape = image.shape[:-2] + (new_height, new_width)
- new_image = np.zeros_like(image, shape=new_shape)
-
- # If the image is too small, pad it with zeros
- top_pad = (new_height - orig_height) // 2
- bottom_pad = top_pad + orig_height
- left_pad = (new_width - orig_width) // 2
- right_pad = left_pad + orig_width
- new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
-
- top += top_pad
- bottom += top_pad
- left += left_pad
- right += left_pad
-
- new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
- new_image = to_channel_dimension_format(new_image, output_data_format)
- return new_image
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index b0c128d86aeaae..0ba86d14b7975d 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -20,8 +20,14 @@
import requests
-from .utils import is_flax_available, is_tf_available, is_torch_available
-from .utils.generic import _is_jax, _is_tensorflow, _is_torch, to_numpy
+from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
+from .utils.constants import ( # noqa: F401
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+)
+from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy
if is_vision_available():
@@ -39,9 +45,9 @@
] # noqa
-class ChannelDimension(enum.Enum):
- FIRST = 1
- LAST = 3
+class ChannelDimension(ExplicitEnum):
+ FIRST = "channels_first"
+ LAST = "channels_last"
def is_torch_tensor(obj):
@@ -106,25 +112,6 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
raise ValueError("Unable to infer channel dimension format")
-def get_channel_dimension_axis(image: np.ndarray) -> int:
- """
- Returns the channel dimension axis of the image.
-
- Args:
- image (`np.ndarray`):
- The image to get the channel dimension axis of.
-
- Returns:
- The channel dimension axis of the image.
- """
- channel_dim = infer_channel_dimension_format(image)
- if channel_dim == ChannelDimension.FIRST:
- return image.ndim - 3
- elif channel_dim == ChannelDimension.LAST:
- return image.ndim - 1
- raise ValueError(f"Unsupported data format: {channel_dim}")
-
-
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
"""
Returns the (height, width) dimensions of the image.
diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py
index ffc7efe16ea978..98ae1d53f73d60 100644
--- a/src/transformers/models/glpn/image_processing_glpn.py
+++ b/src/transformers/models/glpn/image_processing_glpn.py
@@ -16,8 +16,8 @@
from typing import List, Optional, Union
-import PIL.Image
import numpy as np
+import PIL.Image
from transformers.utils.generic import TensorType
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index f63555bacb485f..69e6de1587b8d6 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -35,9 +35,7 @@
import PIL.Image
from transformers.image_transforms import (
- center_crop,
get_resize_output_image_size,
- normalize,
resize,
to_channel_dimension_format,
to_pil_image,
@@ -174,53 +172,3 @@ def test_resize(self):
self.assertIsInstance(resized_image, PIL.Image.Image)
# PIL size is in (width, height) order
self.assertEqual(resized_image.size, (40, 30))
-
- def test_normalize(self):
- image = np.random.randint(0, 256, (224, 224, 3)) / 255
-
- # Test that exception is raised if inputs are incorrect
- # Not a numpy array image
- with self.assertRaises(ValueError):
- normalize(5, 5, 5)
-
- # Number of mean values != number of channels
- with self.assertRaises(ValueError):
- normalize(image, mean=(0.5, 0.6), std=1)
-
- # Number of std values != number of channels
- with self.assertRaises(ValueError):
- normalize(image, mean=1, std=(0.5, 0.6))
-
- # Test result is correct - output data format is channels_first and normalization
- # correctly computed
- mean = (0.5, 0.6, 0.7)
- std = (0.1, 0.2, 0.3)
- expected_image = ((image - mean) / std).transpose((2, 0, 1))
-
- normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
- self.assertIsInstance(normalized_image, np.ndarray)
- self.assertEqual(normalized_image.shape, (3, 224, 224))
- self.assertTrue(np.allclose(normalized_image, expected_image))
-
- def test_center_crop(self):
- image = np.random.randint(0, 256, (3, 224, 224))
-
- # Test that exception is raised if inputs are incorrect
- with self.assertRaises(ValueError):
- center_crop(image, 10)
-
- # Test result is correct - output data format is channels_first and center crop
- # correctly computed
- expected_image = image[:, 52:172, 82:142].transpose(1, 2, 0)
- cropped_image = center_crop(image, (120, 60), data_format="channels_last")
- self.assertIsInstance(cropped_image, np.ndarray)
- self.assertEqual(cropped_image.shape, (120, 60, 3))
- self.assertTrue(np.allclose(cropped_image, expected_image))
-
- # Test that image is padded with zeros if crop size is larger than image size
- expected_image = np.zeros((300, 260, 3))
- expected_image[38:262, 18:242, :] = image.transpose((1, 2, 0))
- cropped_image = center_crop(image, (300, 260), data_format="channels_last")
- self.assertIsInstance(cropped_image, np.ndarray)
- self.assertEqual(cropped_image.shape, (300, 260, 3))
- self.assertTrue(np.allclose(cropped_image, expected_image))
diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py
index 6868e117c4c386..0ae5d78fb2dc0a 100644
--- a/tests/utils/test_image_utils.py
+++ b/tests/utils/test_image_utils.py
@@ -20,7 +20,7 @@
import pytest
from transformers import is_torch_available, is_vision_available
-from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
+from transformers.image_utils import ChannelDimension
from transformers.testing_utils import require_torch, require_vision
@@ -535,26 +535,3 @@ def test_infer_channel_dimension(self):
image = np.random.randint(0, 256, (1, 3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
-
- def test_get_channel_dimension_axis(self):
- # Test we correctly identify the channel dimension
- image = np.random.randint(0, 256, (3, 4, 5))
- inferred_axis = get_channel_dimension_axis(image)
- self.assertEqual(inferred_axis, 0)
-
- image = np.random.randint(0, 256, (1, 4, 5))
- inferred_axis = get_channel_dimension_axis(image)
- self.assertEqual(inferred_axis, 0)
-
- image = np.random.randint(0, 256, (4, 5, 3))
- inferred_axis = get_channel_dimension_axis(image)
- self.assertEqual(inferred_axis, 2)
-
- image = np.random.randint(0, 256, (4, 5, 1))
- inferred_axis = get_channel_dimension_axis(image)
- self.assertEqual(inferred_axis, 2)
-
- # We can take a batched array of images and find the dimension
- image = np.random.randint(0, 256, (1, 3, 4, 5))
- inferred_axis = get_channel_dimension_axis(image)
- self.assertEqual(inferred_axis, 1)