diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx index 491ba7105f1..cdc6cb644c1 100644 --- a/docs/source/use_with_pytorch.mdx +++ b/docs/source/use_with_pytorch.mdx @@ -78,28 +78,54 @@ To get a single tensor, you must explicitly use the [`Array`] feature type and s ```py >>> from datasets import Dataset, Features, ClassLabel ->>> data = [0, 0, 1] ->>> features = Features({"data": ClassLabel(names=["negative", "positive"])}) ->>> ds = Dataset.from_dict({"data": data}, features=features) +>>> labels = [0, 0, 1] +>>> features = Features({"label": ClassLabel(names=["negative", "positive"])}) +>>> ds = Dataset.from_dict({"label": labels}, features=features) >>> ds = ds.with_format("torch") >>> ds[:3] -{'data': tensor([0, 0, 1])} +{'label': tensor([0, 0, 1])} ``` -However, since it's not possible to convert text data to PyTorch tensors, you can't format a `string` column to PyTorch. -Instead, you can explicitly format certain columns and leave the other columns unformatted: +String and binary objects are unchanged, since PyTorch only supports numbers. + +The [`Image`] and [`Audio`] feature types are also supported: ```py ->>> from datasets import Dataset, Features ->>> text = ["foo", "bar"] ->>> data = [0, 1] ->>> ds = Dataset.from_dict({"text": text, "data": data}) ->>> ds = ds.with_format("torch", columns=["data"], output_all_columns=True) ->>> ds[:2] -{'data': tensor([0, 1]), 'text': ['foo', 'bar']} +>>> from datasets import Dataset, Features, Audio, Image +>>> images = ["path/to/image.png"] * 10 +>>> features = Features({"image": Image()}) +>>> ds = Dataset.from_dict({"image": images}, features=features) +>>> ds = ds.with_format("torch") +>>> ds[0]["image"].shape +torch.Size([512, 512, 4]) +>>> ds[0] +{'image': tensor([[[255, 215, 106, 255], + [255, 215, 106, 255], + ..., + [255, 255, 255, 255], + [255, 255, 255, 255]]], dtype=torch.uint8)} +>>> ds[:2]["image"].shape +torch.Size([2, 512, 512, 4]) +>>> ds[:2] +{'image': tensor([[[[255, 215, 106, 255], + [255, 215, 106, 255], + ..., + [255, 255, 255, 255], + [255, 255, 255, 255]]]], dtype=torch.uint8)} ``` -The [`Image`] and [`Audio`] feature types are not supported yet. +```py +>>> from datasets import Dataset, Features, Audio, Image +>>> audio = ["path/to/audio.wav"] * 10 +>>> features = Features({"audio": Audio()}) +>>> ds = Dataset.from_dict({"audio": audio}, features=features) +>>> ds = ds.with_format("torch") +>>> ds[0]["audio"]["array"] +tensor([ 6.1035e-05, 1.5259e-05, 1.6785e-04, ..., -1.5259e-05, + -1.5259e-05, 1.5259e-05]) +>>> ds[0]["audio"]["sampling_rate"] +tensor(44100) +``` ## Data loading diff --git a/docs/source/use_with_tensorflow.mdx b/docs/source/use_with_tensorflow.mdx index 64c5b07d6f4..0e80dbea5f9 100644 --- a/docs/source/use_with_tensorflow.mdx +++ b/docs/source/use_with_tensorflow.mdx @@ -81,15 +81,15 @@ To get a single tensor, you must explicitly use the Array feature type and speci ```py >>> from datasets import Dataset, Features, ClassLabel ->>> data = [0, 0, 1] ->>> features = Features({"data": ClassLabel(names=["negative", "positive"])}) ->>> ds = Dataset.from_dict({"data": data}, features=features) +>>> labels = [0, 0, 1] +>>> features = Features({"label": ClassLabel(names=["negative", "positive"])}) +>>> ds = Dataset.from_dict({"label": labels}, features=features) >>> ds = ds.with_format("tf") >>> ds[:3] -{'data': +{'label': } ``` -Strings are also supported: +Strings and binary objects are also supported: ```py >>> from datasets import Dataset, Features @@ -111,7 +111,45 @@ You can also explicitly format certain columns and leave the other columns unfor 'text': ['foo', 'bar']} ``` -The [`Image`] and [`Audio`] feature types are not supported yet. +String and binary objects are unchanged, since PyTorch only supports numbers. + +The [`Image`] and [`Audio`] feature types are also supported: + +```py +>>> from datasets import Dataset, Features, Audio, Image +>>> images = ["path/to/image.png"] * 10 +>>> features = Features({"image": Image()}) +>>> ds = Dataset.from_dict({"image": images}, features=features) +>>> ds = ds.with_format("tf") +>>> ds[0] +{'image': } +>>> ds[:2] +{'image': } +``` + +```py +>>> from datasets import Dataset, Features, Audio, Image +>>> audio = ["path/to/audio.wav"] * 10 +>>> features = Features({"audio": Audio()}) +>>> ds = Dataset.from_dict({"audio": audio}, features=features) +>>> ds = ds.with_format("tf") +>>> ds[0]["audio"]["array"] + +>>> ds[0]["audio"]["sampling_rate"] + +``` ## Data loading diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eda0b318128..2aac44f9a11 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3978,9 +3978,14 @@ def _int64_feature(values): """Returns an int64_list from a list of bool / enum / int / uint.""" return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) - def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature": + def _feature(values: Union[float, int, str, np.ndarray, list]) -> "tf.train.Feature": """Typechecks `values` and returns the corresponding tf.train.Feature.""" - if isinstance(values, np.ndarray): + if isinstance(values, list): + if values and isinstance(values[0], str): + return _bytes_feature([v.encode() for v in values]) + else: + raise ValueError(f"values={values} is empty or contains items that cannot be serialized") + elif isinstance(values, np.ndarray): if values.dtype == np.dtype(float): return _float_feature(values) elif values.dtype == np.int64: @@ -3991,9 +3996,9 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature": return _bytes_feature([v.encode() for v in values]) else: raise ValueError( - f"values={values} is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized" + f"values={values} is empty or is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized" ) - if hasattr(values, "dtype"): + elif hasattr(values, "dtype"): if np.issubdtype(values.dtype, np.floating): return _float_feature([values.item()]) elif np.issubdtype(values.dtype, np.integer): @@ -4003,7 +4008,7 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature": else: raise ValueError(f"values={values} has dtype {values.dtype}, which cannot be serialized") else: - raise ValueError(f"values={values} are not numpy objects, and so cannot be serialized") + raise ValueError(f"values={values} are not numpy objects or strings, and so cannot be serialized") def serialize_example(ex): feature = {key: _feature(value) for key, value in ex.items()} diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 2b329b8a427..5ee17f469d9 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -81,6 +81,9 @@ def encode_example(self, value: Union[str, dict, np.ndarray, "PIL.Image.Image"]) else: raise ImportError("To support encoding images, please install 'Pillow'.") + if isinstance(value, list): + value = np.array(value) + if isinstance(value, str): return {"path": value, "bytes": None} elif isinstance(value, np.ndarray): diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py index d9b0f44f950..77a20f5c0dc 100644 --- a/src/datasets/formatting/__init__.py +++ b/src/datasets/formatting/__init__.py @@ -23,12 +23,12 @@ ArrowFormatter, CustomFormatter, Formatter, - NumpyFormatter, PandasFormatter, PythonFormatter, format_table, query_table, ) +from .np_formatter import NumpyFormatter logger = logging.get_logger(__name__) diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index ecc5b5e6606..095c6ff6d6d 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -21,6 +21,7 @@ import pandas as pd import pyarrow as pa +from ..features import Features from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper from ..table import Table from ..utils.py_utils import no_op_if_value_is_null @@ -198,8 +199,8 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: or (isinstance(x, float) and np.isnan(x)) for x in array ): - return np.array(array, copy=False, **{**self.np_array_kwargs, "dtype": object}) - return np.array(array, copy=False, **self.np_array_kwargs) + return np.array(array, copy=False, dtype=object) + return np.array(array, copy=False) class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]): @@ -214,7 +215,7 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame: class PythonFeaturesDecoder: - def __init__(self, features): + def __init__(self, features: Features): self.features = features def decode_row(self, row: dict) -> dict: @@ -228,7 +229,7 @@ def decode_batch(self, batch: dict) -> dict: class PandasFeaturesDecoder: - def __init__(self, features): + def __init__(self, features: Features): self.features = features def decode_row(self, row: pd.DataFrame) -> pd.DataFrame: @@ -325,30 +326,6 @@ def format_batch(self, pa_table: pa.Table) -> dict: return batch -class NumpyFormatter(Formatter[dict, np.ndarray, dict]): - def __init__(self, features=None, decoded=True, **np_array_kwargs): - super().__init__(features=features, decoded=decoded) - self.np_array_kwargs = np_array_kwargs - - def format_row(self, pa_table: pa.Table) -> dict: - row = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_row(pa_table) - if self.decoded: - row = self.python_features_decoder.decode_row(row) - return row - - def format_column(self, pa_table: pa.Table) -> np.ndarray: - column = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_column(pa_table) - if self.decoded: - column = self.python_features_decoder.decode_column(column, pa_table.column_names[0]) - return column - - def format_batch(self, pa_table: pa.Table) -> dict: - batch = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_batch(pa_table) - if self.decoded: - batch = self.python_features_decoder.decode_batch(batch) - return batch - - class PandasFormatter(Formatter): def format_row(self, pa_table: pa.Table) -> pd.DataFrame: row = self.pandas_arrow_extractor().extract_row(pa_table) diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py index 0a554203be5..19b656d7ab1 100644 --- a/src/datasets/formatting/jax_formatter.py +++ b/src/datasets/formatting/jax_formatter.py @@ -13,11 +13,13 @@ # limitations under the License. # Lint as: python3 +import sys from typing import TYPE_CHECKING import numpy as np import pyarrow as pa +from .. import config from ..utils.py_utils import map_nested from .formatting import Formatter @@ -28,23 +30,46 @@ class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]): def __init__(self, features=None, decoded=True, **jnp_array_kwargs): + super().__init__(features=features, decoded=decoded) self.jnp_array_kwargs = jnp_array_kwargs import jax.numpy as jnp # noqa import jax at initialization + def _consolidate(self, column): + import jax.numpy as jnp + + if isinstance(column, list) and column: + if all( + isinstance(x, jnp.ndarray) and x.shape == column[0].shape and x.dtype == column[0].dtype + for x in column + ): + return jnp.stack(column) + return column + def _tensorize(self, value): import jax import jax.numpy as jnp + if isinstance(value, (str, bytes, type(None))): + return value + elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character): + return value.tolist() + default_dtype = {} - if np.issubdtype(value.dtype, np.integer): + + if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer): # the default int precision depends on the jax config # see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision if jax.config.jax_enable_x64: default_dtype = {"dtype": jnp.int64} else: default_dtype = {"dtype": jnp.int32} - elif np.issubdtype(value.dtype, np.floating): + elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": jnp.float32} + elif config.PIL_AVAILABLE and "PIL" in sys.modules: + import PIL.Image + + if isinstance(value, PIL.Image.Image): + value = np.asarray(value) # calling jnp.array on a np.ndarray does copy the data # see https://github.com/google/jax/issues/4486 @@ -52,23 +77,33 @@ def _tensorize(self, value): def _recursive_tensorize(self, data_struct: dict): # support for nested types like struct of list of struct - if isinstance(data_struct, (list, np.ndarray)): - data_struct = np.array(data_struct, copy=False) + if isinstance(data_struct, np.ndarray): if data_struct.dtype == object: # jax arrays cannot be instantied from an array of objects - return [self.recursive_tensorize(substruct) for substruct in data_struct] + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) return self._tensorize(data_struct) def recursive_tensorize(self, data_struct: dict): - return map_nested(self._recursive_tensorize, data_struct, map_list=False) + return map_nested(self._recursive_tensorize, data_struct) def format_row(self, pa_table: pa.Table) -> dict: row = self.numpy_arrow_extractor().extract_row(pa_table) + if self.decoded: + row = self.python_features_decoder.decode_row(row) return self.recursive_tensorize(row) def format_column(self, pa_table: pa.Table) -> "jnp.ndarray": - col = self.numpy_arrow_extractor().extract_column(pa_table) - return self.recursive_tensorize(col) + column = self.numpy_arrow_extractor().extract_column(pa_table) + if self.decoded: + column = self.python_features_decoder.decode_column(column, pa_table.column_names[0]) + column = self.recursive_tensorize(column) + column = self._consolidate(column) + return column def format_batch(self, pa_table: pa.Table) -> dict: batch = self.numpy_arrow_extractor().extract_batch(pa_table) - return self.recursive_tensorize(batch) + if self.decoded: + batch = self.python_features_decoder.decode_batch(batch) + batch = self.recursive_tensorize(batch) + for column_name in batch: + batch[column_name] = self._consolidate(batch[column_name]) + return batch diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py new file mode 100644 index 00000000000..f3280a65156 --- /dev/null +++ b/src/datasets/formatting/np_formatter.py @@ -0,0 +1,85 @@ +import sys + +import numpy as np +import pyarrow as pa + +from .. import config +from ..utils.py_utils import map_nested +from .formatting import Formatter + + +class NumpyFormatter(Formatter[dict, np.ndarray, dict]): + def __init__(self, features=None, decoded=True, **np_array_kwargs): + super().__init__(features=features, decoded=decoded) + self.np_array_kwargs = np_array_kwargs + + def _consolidate(self, column): + if isinstance(column, list): + if column and all( + isinstance(x, np.ndarray) and x.shape == column[0].shape and x.dtype == column[0].dtype for x in column + ): + return np.stack(column) + else: + # don't use np.array(column, dtype=object) + # since it fails in certain cases + # see https://stackoverflow.com/q/51005699 + out = np.empty(len(column), dtype=object) + out[:] = column + return out + return column + + def _tensorize(self, value): + + if isinstance(value, (str, bytes, type(None))): + return value + elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character): + return value + elif isinstance(value, np.number): + return value + + default_dtype = {} + + if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer): + default_dtype = {"dtype": np.int64} + elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating): + default_dtype = {"dtype": np.float32} + elif config.PIL_AVAILABLE and "PIL" in sys.modules: + import PIL.Image + + if isinstance(value, PIL.Image.Image): + return np.asarray(value, **self.np_array_kwargs) + + return np.array(value, **{**default_dtype, **self.np_array_kwargs}) + + def _recursive_tensorize(self, data_struct: dict): + # support for nested types like struct of list of struct + if isinstance(data_struct, np.ndarray): + if data_struct.dtype == object: # torch tensors cannot be instantied from an array of objects + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) + return self._tensorize(data_struct) + + def recursive_tensorize(self, data_struct: dict): + return map_nested(self._recursive_tensorize, data_struct) + + def format_row(self, pa_table: pa.Table) -> dict: + row = self.numpy_arrow_extractor().extract_row(pa_table) + if self.decoded: + row = self.python_features_decoder.decode_row(row) + return self.recursive_tensorize(row) + + def format_column(self, pa_table: pa.Table) -> np.ndarray: + column = self.numpy_arrow_extractor().extract_column(pa_table) + if self.decoded: + column = self.python_features_decoder.decode_column(column, pa_table.column_names[0]) + column = self.recursive_tensorize(column) + column = self._consolidate(column) + return column + + def format_batch(self, pa_table: pa.Table) -> dict: + batch = self.numpy_arrow_extractor().extract_batch(pa_table) + if self.decoded: + batch = self.python_features_decoder.decode_batch(batch) + batch = self.recursive_tensorize(batch) + for column_name in batch: + batch[column_name] = self._consolidate(batch[column_name]) + return batch diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py index d07f1f636cc..800efcced67 100644 --- a/src/datasets/formatting/tf_formatter.py +++ b/src/datasets/formatting/tf_formatter.py @@ -13,11 +13,13 @@ # limitations under the License. # Lint as: python3 +import sys from typing import TYPE_CHECKING import numpy as np import pyarrow as pa +from .. import config from ..utils.py_utils import map_nested from .formatting import Formatter @@ -28,61 +30,76 @@ class TFFormatter(Formatter[dict, "tf.Tensor", dict]): def __init__(self, features=None, decoded=True, **tf_tensor_kwargs): + super().__init__(features=features, decoded=decoded) self.tf_tensor_kwargs = tf_tensor_kwargs import tensorflow as tf # noqa: import tf at initialization + def _consolidate(self, column): + import tensorflow as tf + + if isinstance(column, list) and column: + if all( + isinstance(x, tf.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype for x in column + ): + return tf.stack(column) + elif all( + isinstance(x, (tf.Tensor, tf.RaggedTensor)) and x.ndim == 1 and x.dtype == column[0].dtype + for x in column + ): + # only rag 1-D tensors, otherwise some dimensions become ragged even though they were consolidated + return tf.ragged.stack(column) + + return column + def _tensorize(self, value): import tensorflow as tf - if "dtype" not in self.tf_tensor_kwargs: - if np.issubdtype(value.dtype, np.integer): - np_dtype = np.int64 - tf_dtype = tf.int64 - default_dtype = {"dtype": tf_dtype} - elif np.issubdtype(value.dtype, np.floating): - np_dtype = np.float32 - tf_dtype = tf.float32 - default_dtype = {"dtype": tf_dtype} - else: - np_dtype = None - tf_dtype = None - default_dtype = {} - else: - tf_dtype = self.tf_tensor_kwargs["dtype"] - np_dtype = tf_dtype.as_numpy_dtype - default_dtype = {} - - # Saving the most expensive methods for last - try: - return tf.convert_to_tensor(value, dtype=tf_dtype) - except ValueError: - try: - return tf.ragged.stack([np.array(subarr, dtype=np_dtype) for subarr in value]) - except ValueError: - # tf.ragged.constant is orders of magnitude slower than tf.ragged.stack - return tf.ragged.constant(value, **{**default_dtype, **self.tf_tensor_kwargs}) + if value is None: + return value + + default_dtype = {} + + if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer): + default_dtype = {"dtype": tf.int64} + elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): + default_dtype = {"dtype": tf.float32} + elif config.PIL_AVAILABLE and "PIL" in sys.modules: + import PIL.Image + + if isinstance(value, PIL.Image.Image): + value = np.asarray(value) + + return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs}) def _recursive_tensorize(self, data_struct: dict): # support for nested types like struct of list of struct - if isinstance(data_struct, (list, np.ndarray)): - if data_struct.dtype == object: # tensorflow tensors can sometimes be instantied from an array of objects - try: - return self._tensorize(data_struct) - except ValueError: - return [self.recursive_tensorize(substruct) for substruct in data_struct] + if isinstance(data_struct, np.ndarray): + if data_struct.dtype == object: # tf tensors cannot be instantied from an array of objects + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) return self._tensorize(data_struct) def recursive_tensorize(self, data_struct: dict): - return map_nested(self._recursive_tensorize, data_struct, map_list=False) + return map_nested(self._recursive_tensorize, data_struct) def format_row(self, pa_table: pa.Table) -> dict: row = self.numpy_arrow_extractor().extract_row(pa_table) + if self.decoded: + row = self.python_features_decoder.decode_row(row) return self.recursive_tensorize(row) def format_column(self, pa_table: pa.Table) -> "tf.Tensor": - col = self.numpy_arrow_extractor().extract_column(pa_table) - return self.recursive_tensorize(col) + column = self.numpy_arrow_extractor().extract_column(pa_table) + if self.decoded: + column = self.python_features_decoder.decode_column(column, pa_table.column_names[0]) + column = self.recursive_tensorize(column) + column = self._consolidate(column) + return column def format_batch(self, pa_table: pa.Table) -> dict: batch = self.numpy_arrow_extractor().extract_batch(pa_table) - return self.recursive_tensorize(batch) + if self.decoded: + batch = self.python_features_decoder.decode_batch(batch) + batch = self.recursive_tensorize(batch) + for column_name in batch: + batch[column_name] = self._consolidate(batch[column_name]) + return batch diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py index c5a7d3c214f..953aa7b968e 100644 --- a/src/datasets/formatting/torch_formatter.py +++ b/src/datasets/formatting/torch_formatter.py @@ -13,11 +13,13 @@ # limitations under the License. # Lint as: python3 +import sys from typing import TYPE_CHECKING import numpy as np import pyarrow as pa +from .. import config from ..utils.py_utils import map_nested from .formatting import Formatter @@ -28,39 +30,71 @@ class TorchFormatter(Formatter[dict, "torch.Tensor", dict]): def __init__(self, features=None, decoded=True, **torch_tensor_kwargs): + super().__init__(features=features, decoded=decoded) self.torch_tensor_kwargs = torch_tensor_kwargs import torch # noqa import torch at initialization + def _consolidate(self, column): + import torch + + if isinstance(column, list) and column: + if all( + isinstance(x, torch.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype + for x in column + ): + return torch.stack(column) + return column + def _tensorize(self, value): import torch + if isinstance(value, (str, bytes, type(None))): + return value + elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character): + return value.tolist() + default_dtype = {} - if np.issubdtype(value.dtype, np.integer): + + if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer): default_dtype = {"dtype": torch.int64} - elif np.issubdtype(value.dtype, np.floating): + elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": torch.float32} + elif config.PIL_AVAILABLE and "PIL" in sys.modules: + import PIL.Image + if isinstance(value, PIL.Image.Image): + value = np.asarray(value) return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs}) def _recursive_tensorize(self, data_struct: dict): # support for nested types like struct of list of struct - if isinstance(data_struct, (list, np.ndarray)): - data_struct = np.array(data_struct, copy=False) - if data_struct.dtype == object: # pytorch tensors cannot be instantied from an array of objects - return [self.recursive_tensorize(substruct) for substruct in data_struct] + if isinstance(data_struct, np.ndarray): + if data_struct.dtype == object: # torch tensors cannot be instantied from an array of objects + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) return self._tensorize(data_struct) def recursive_tensorize(self, data_struct: dict): - return map_nested(self._recursive_tensorize, data_struct, map_list=False) + return map_nested(self._recursive_tensorize, data_struct) def format_row(self, pa_table: pa.Table) -> dict: row = self.numpy_arrow_extractor().extract_row(pa_table) + if self.decoded: + row = self.python_features_decoder.decode_row(row) return self.recursive_tensorize(row) def format_column(self, pa_table: pa.Table) -> "torch.Tensor": - col = self.numpy_arrow_extractor().extract_column(pa_table) - return self.recursive_tensorize(col) + column = self.numpy_arrow_extractor().extract_column(pa_table) + if self.decoded: + column = self.python_features_decoder.decode_column(column, pa_table.column_names[0]) + column = self.recursive_tensorize(column) + column = self._consolidate(column) + return column def format_batch(self, pa_table: pa.Table) -> dict: batch = self.numpy_arrow_extractor().extract_batch(pa_table) - return self.recursive_tensorize(batch) + if self.decoded: + batch = self.python_features_decoder.decode_batch(batch) + batch = self.recursive_tensorize(batch) + for column_name in batch: + batch[column_name] = self._consolidate(batch[column_name]) + return batch diff --git a/tests/features/data/test_image_rgba.png b/tests/features/data/test_image_rgba.png new file mode 100644 index 00000000000..b689ea49772 Binary files /dev/null and b/tests/features/data/test_image_rgba.png differ diff --git a/tests/features/test_image.py b/tests/features/test_image.py index 469b6b82ae3..2a066f3cbc1 100644 --- a/tests/features/test_image.py +++ b/tests/features/test_image.py @@ -445,26 +445,17 @@ def test_formatted_dataset_with_image_feature(shared_datadir): with dset.formatted_as("numpy"): item = dset[0] assert item.keys() == {"image"} - assert isinstance(item["image"], PIL.Image.Image) - assert os.path.samefile(item["image"].filename, image_path) - assert item["image"].format == "JPEG" - assert item["image"].size == (640, 480) - assert item["image"].mode == "RGB" + assert isinstance(item["image"], np.ndarray) + assert item["image"].shape == (480, 640, 3) batch = dset[:1] assert batch.keys() == {"image"} assert len(batch) == 1 - assert isinstance(batch["image"], list) and all(isinstance(item, PIL.Image.Image) for item in batch["image"]) - assert os.path.samefile(batch["image"][0].filename, image_path) - assert batch["image"][0].format == "JPEG" - assert batch["image"][0].size == (640, 480) - assert batch["image"][0].mode == "RGB" + assert isinstance(batch["image"], np.ndarray) + assert batch["image"].shape == (1, 480, 640, 3) column = dset["image"] assert len(column) == 2 - assert isinstance(column, list) and all(isinstance(item, PIL.Image.Image) for item in column) - assert os.path.samefile(column[0].filename, image_path) - assert column[0].format == "JPEG" - assert column[0].size == (640, 480) - assert column[0].mode == "RGB" + assert isinstance(column, np.ndarray) + assert column.shape == (2, 480, 640, 3) with dset.formatted_as("pandas"): item = dset[0] diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..523483e8b98 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -373,9 +373,17 @@ def test_set_format_torch(self, in_memory): self.assertIsInstance(dset[0]["col_2"], str) self.assertEqual(dset[0]["col_2"], "a") - dset.set_format(type="torch", columns=["col_1", "col_2"]) - with self.assertRaises(TypeError): - dset[0] + dset.set_format(type="torch") + self.assertEqual(len(dset[0]), 3) + self.assertIsInstance(dset[0]["col_1"], torch.Tensor) + self.assertIsInstance(dset["col_1"], torch.Tensor) + self.assertListEqual(list(dset[0]["col_1"].shape), []) + self.assertEqual(dset[0]["col_1"].item(), 3) + self.assertIsInstance(dset[0]["col_2"], str) + self.assertEqual(dset[0]["col_2"], "a") + self.assertIsInstance(dset[0]["col_3"], torch.Tensor) + self.assertIsInstance(dset["col_3"], torch.Tensor) + self.assertListEqual(list(dset[0]["col_3"].shape), []) @require_tf def test_set_format_tf(self, in_memory): @@ -2354,9 +2362,9 @@ def test_format_ragged_vectors(self, in_memory): self.assertIsNotNone(dset[0]) self.assertIsNotNone(dset[:2]) for col in columns: - self.assertIsInstance(dset[0][col], (tf.Tensor, tf.RaggedTensor)) - self.assertIsInstance(dset[:2][col], (tf.Tensor, tf.RaggedTensor)) - self.assertIsInstance(dset[col], (tf.Tensor, tf.RaggedTensor)) + self.assertIsInstance(dset[0][col], tf.Tensor) + self.assertIsInstance(dset[:2][col], tf.RaggedTensor if col == "vec" else tf.Tensor) + self.assertIsInstance(dset[col], tf.RaggedTensor if col == "vec" else tf.Tensor) # dim is None for ragged vectors in tensorflow self.assertListEqual(dset[:2]["vec"].shape.as_list(), [2, None]) self.assertListEqual(dset["vec"][:2].shape.as_list(), [2, None]) @@ -2374,16 +2382,20 @@ def test_format_ragged_vectors(self, in_memory): self.assertTupleEqual(dset[:2]["vec"].shape, (2,)) self.assertTupleEqual(dset["vec"][:2].shape, (2,)) - dset.set_format("torch", columns=["vec"]) + dset.set_format("torch") self.assertIsNotNone(dset[0]) self.assertIsNotNone(dset[:2]) - # torch.Tensor is only for numerical columns + self.assertIsInstance(dset[0]["filename"], str) + self.assertIsInstance(dset[:2]["filename"], list) + self.assertIsInstance(dset["filename"], list) self.assertIsInstance(dset[0]["vec"], torch.Tensor) self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor) self.assertIsInstance(dset["vec"][0], torch.Tensor) # pytorch doesn't support ragged tensors, so we should have lists self.assertIsInstance(dset[:2]["vec"], list) + self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor) self.assertIsInstance(dset["vec"][:2], list) + self.assertIsInstance(dset["vec"][0], torch.Tensor) @require_tf @require_torch diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index 2902cd799cb..175b5785e71 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -106,10 +106,14 @@ def test_set_format_torch(self): self.assertIsInstance(dset_split[0]["col_2"], str) self.assertEqual(dset_split[0]["col_2"], "a") - dset.set_format(type="torch", columns=["col_1", "col_2"]) + dset.set_format(type="torch") for dset_split in dset.values(): - with self.assertRaises(TypeError): - dset_split[0] + self.assertEqual(len(dset_split[0]), 2) + self.assertIsInstance(dset_split[0]["col_1"], torch.Tensor) + self.assertListEqual(list(dset_split[0]["col_1"].shape), []) + self.assertEqual(dset_split[0]["col_1"].item(), 3) + self.assertIsInstance(dset_split[0]["col_2"], str) + self.assertEqual(dset_split[0]["col_2"], "a") del dset @require_tf diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 3c50cbd98cf..1e6bdbe69d1 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest import TestCase import numpy as np @@ -5,6 +6,7 @@ import pyarrow as pa import pytest +from datasets import Audio, Features, Image from datasets.formatting import NumpyFormatter, PandasFormatter, PythonFormatter, query_table from datasets.formatting.formatting import NumpyArrowExtractor, PandasArrowExtractor, PythonArrowExtractor from datasets.table import InMemoryTable @@ -14,10 +16,14 @@ _COL_A = [0, 1, 2] _COL_B = ["foo", "bar", "foobar"] -_COL_C = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] +_COL_C = [[[1.0, 0.0, 0.0]] * 2, [[0.0, 1.0, 0.0]] * 2, [[0.0, 0.0, 1.0]] * 2] _INDICES = [1, 0] +IMAGE_PATH_1 = Path(__file__).parent / "features" / "data" / "test_image_rgb.jpg" +IMAGE_PATH_2 = Path(__file__).parent / "features" / "data" / "test_image_rgba.png" +AUDIO_PATH_1 = Path(__file__).parent / "features" / "data" / "test_audio_44100.wav" + class ArrowExtractorTest(TestCase): def _create_dummy_table(self): @@ -34,25 +40,29 @@ def test_python_extractor(self): self.assertEqual(batch, {"a": _COL_A, "b": _COL_B, "c": _COL_C}) def test_numpy_extractor(self): - pa_table = self._create_dummy_table() + pa_table = self._create_dummy_table().drop(["c"]) extractor = NumpyArrowExtractor() row = extractor.extract_row(pa_table) - np.testing.assert_equal(row, {"a": _COL_A[0], "b": _COL_B[0], "c": np.array(_COL_C[0])}) + np.testing.assert_equal(row, {"a": _COL_A[0], "b": _COL_B[0]}) col = extractor.extract_column(pa_table) np.testing.assert_equal(col, np.array(_COL_A)) batch = extractor.extract_batch(pa_table) - np.testing.assert_equal(batch, {"a": np.array(_COL_A), "b": np.array(_COL_B), "c": np.array(_COL_C)}) + np.testing.assert_equal(batch, {"a": np.array(_COL_A), "b": np.array(_COL_B)}) - def test_numpy_extractor_np_array_kwargs(self): - pa_table = self._create_dummy_table().drop(["b"]) - extractor = NumpyArrowExtractor(dtype=np.float16) + def test_numpy_extractor_nested(self): + pa_table = self._create_dummy_table().drop(["a", "b"]) + extractor = NumpyArrowExtractor() row = extractor.extract_row(pa_table) - self.assertEqual(row["c"].dtype, np.dtype(np.float16)) + self.assertEqual(row["c"][0].dtype, np.float64) + self.assertEqual(row["c"].dtype, object) col = extractor.extract_column(pa_table) - self.assertEqual(col.dtype, np.float16) + self.assertEqual(col[0][0].dtype, np.float64) + self.assertEqual(col[0].dtype, object) + self.assertEqual(col.dtype, object) batch = extractor.extract_batch(pa_table) - self.assertEqual(batch["a"].dtype, np.dtype(np.float16)) - self.assertEqual(batch["c"].dtype, np.dtype(np.float16)) + self.assertEqual(batch["c"][0][0].dtype, np.float64) + self.assertEqual(batch["c"][0].dtype, object) + self.assertEqual(batch["c"].dtype, object) def test_pandas_extractor(self): pa_table = self._create_dummy_table() @@ -94,6 +104,7 @@ def test_numpy_formatter(self): np.testing.assert_equal(col, np.array(_COL_A)) batch = formatter.format_batch(pa_table) np.testing.assert_equal(batch, {"a": np.array(_COL_A), "b": np.array(_COL_B), "c": np.array(_COL_C)}) + assert batch["c"].shape == np.array(_COL_C).shape def test_numpy_formatter_np_array_kwargs(self): pa_table = self._create_dummy_table().drop(["b"]) @@ -106,6 +117,49 @@ def test_numpy_formatter_np_array_kwargs(self): self.assertEqual(batch["a"].dtype, np.dtype(np.float16)) self.assertEqual(batch["c"].dtype, np.dtype(np.float16)) + def test_numpy_formatter_image(self): + # same dimensions + pa_table = pa.table({"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}] * 2}) + formatter = NumpyFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, np.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertEqual(col.dtype, np.uint8) + self.assertEqual(col.shape, (2, 480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["image"].dtype, np.uint8) + self.assertEqual(batch["image"].shape, (2, 480, 640, 3)) + + # different dimensions + pa_table = pa.table( + {"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}, {"bytes": None, "path": str(IMAGE_PATH_2)}]} + ) + formatter = NumpyFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, np.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertIsInstance(col, np.ndarray) + self.assertEqual(col.dtype, object) + self.assertEqual(col[0].dtype, np.uint8) + self.assertEqual(col[0].shape, (480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertIsInstance(batch["image"], np.ndarray) + self.assertEqual(batch["image"].dtype, object) + self.assertEqual(batch["image"][0].dtype, np.uint8) + self.assertEqual(batch["image"][0].shape, (480, 640, 3)) + + def test_numpy_formatter_audio(self): + pa_table = pa.table({"audio": [{"bytes": None, "path": str(AUDIO_PATH_1)}]}) + formatter = NumpyFormatter(features=Features({"audio": Audio()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["audio"]["array"].dtype, np.dtype(np.float32)) + col = formatter.format_column(pa_table) + self.assertEqual(col[0]["array"].dtype, np.float32) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["audio"][0]["array"].dtype, np.dtype(np.float32)) + def test_pandas_formatter(self): pa_table = self._create_dummy_table() formatter = PandasFormatter() @@ -128,19 +182,22 @@ def test_torch_formatter(self): from datasets.formatting import TorchFormatter - pa_table = self._create_dummy_table().drop(["b"]) + pa_table = self._create_dummy_table() formatter = TorchFormatter() row = formatter.format_row(pa_table) - torch.testing.assert_allclose(row["a"], torch.tensor(_COL_A, dtype=torch.int64)[0]) - torch.testing.assert_allclose(row["c"], torch.tensor(_COL_C, dtype=torch.float32)[0]) + torch.testing.assert_close(row["a"], torch.tensor(_COL_A, dtype=torch.int64)[0]) + assert row["b"] == _COL_B[0] + torch.testing.assert_close(row["c"], torch.tensor(_COL_C, dtype=torch.float32)[0]) col = formatter.format_column(pa_table) - torch.testing.assert_allclose(col, torch.tensor(_COL_A, dtype=torch.int64)) + torch.testing.assert_close(col, torch.tensor(_COL_A, dtype=torch.int64)) batch = formatter.format_batch(pa_table) - torch.testing.assert_allclose(batch["a"], torch.tensor(_COL_A, dtype=torch.int64)) - torch.testing.assert_allclose(batch["c"], torch.tensor(_COL_C, dtype=torch.float32)) + torch.testing.assert_close(batch["a"], torch.tensor(_COL_A, dtype=torch.int64)) + assert batch["b"] == _COL_B + torch.testing.assert_close(batch["c"], torch.tensor(_COL_C, dtype=torch.float32)) + assert batch["c"].shape == np.array(_COL_C).shape @require_torch - def test_torch_formatter_np_array_kwargs(self): + def test_torch_formatter_torch_tensor_kwargs(self): import torch from datasets.formatting import TorchFormatter @@ -155,6 +212,57 @@ def test_torch_formatter_np_array_kwargs(self): self.assertEqual(batch["a"].dtype, torch.float16) self.assertEqual(batch["c"].dtype, torch.float16) + @require_torch + def test_torch_formatter_image(self): + import torch + + from datasets.formatting import TorchFormatter + + # same dimensions + pa_table = pa.table({"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}] * 2}) + formatter = TorchFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, torch.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertEqual(col.dtype, torch.uint8) + self.assertEqual(col.shape, (2, 480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["image"].dtype, torch.uint8) + self.assertEqual(batch["image"].shape, (2, 480, 640, 3)) + + # different dimensions + pa_table = pa.table( + {"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}, {"bytes": None, "path": str(IMAGE_PATH_2)}]} + ) + formatter = TorchFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, torch.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertIsInstance(col, list) + self.assertEqual(col[0].dtype, torch.uint8) + self.assertEqual(col[0].shape, (480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertIsInstance(batch["image"], list) + self.assertEqual(batch["image"][0].dtype, torch.uint8) + self.assertEqual(batch["image"][0].shape, (480, 640, 3)) + + @require_torch + def test_torch_formatter_audio(self): + import torch + + from datasets.formatting import TorchFormatter + + pa_table = pa.table({"audio": [{"bytes": None, "path": str(AUDIO_PATH_1)}]}) + formatter = TorchFormatter(features=Features({"audio": Audio()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["audio"]["array"].dtype, torch.float32) + col = formatter.format_column(pa_table) + self.assertEqual(col[0]["array"].dtype, torch.float32) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["audio"][0]["array"].dtype, torch.float32) + @require_tf def test_tf_formatter(self): import tensorflow as tf @@ -180,7 +288,7 @@ def test_tf_formatter(self): tf.debugging.assert_equal(tf.convert_to_tensor(batch["c"]), tf.convert_to_tensor(_COL_C, dtype=tf.float32)) @require_tf - def test_tf_formatter_np_array_kwargs(self): + def test_tf_formatter_tf_tensor_kwargs(self): import tensorflow as tf from datasets.formatting import TFFormatter @@ -195,25 +303,80 @@ def test_tf_formatter_np_array_kwargs(self): self.assertEqual(batch["a"].dtype, tf.float16) self.assertEqual(batch["c"].dtype, tf.float16) + @require_tf + def test_tf_formatter_image(self): + import tensorflow as tf + + from datasets.formatting import TFFormatter + + # same dimensions + pa_table = pa.table({"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}] * 2}) + formatter = TFFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, tf.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertEqual(col.dtype, tf.uint8) + self.assertEqual(col.shape, (2, 480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["image"][0].dtype, tf.uint8) + self.assertEqual(batch["image"].shape, (2, 480, 640, 3)) + + # different dimensions + pa_table = pa.table( + {"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}, {"bytes": None, "path": str(IMAGE_PATH_2)}]} + ) + formatter = TFFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, tf.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertIsInstance(col, list) + self.assertEqual(col[0].dtype, tf.uint8) + self.assertEqual(col[0].shape, (480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertIsInstance(batch["image"], list) + self.assertEqual(batch["image"][0].dtype, tf.uint8) + self.assertEqual(batch["image"][0].shape, (480, 640, 3)) + + @require_tf + def test_tf_formatter_audio(self): + import tensorflow as tf + + from datasets.formatting import TFFormatter + + pa_table = pa.table({"audio": [{"bytes": None, "path": str(AUDIO_PATH_1)}]}) + formatter = TFFormatter(features=Features({"audio": Audio()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["audio"]["array"].dtype, tf.float32) + col = formatter.format_column(pa_table) + self.assertEqual(col[0]["array"].dtype, tf.float32) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["audio"][0]["array"].dtype, tf.float32) + @require_jax def test_jax_formatter(self): + import jax import jax.numpy as jnp from datasets.formatting import JaxFormatter - pa_table = self._create_dummy_table().drop(["b"]) + pa_table = self._create_dummy_table() formatter = JaxFormatter() row = formatter.format_row(pa_table) - jnp.allclose(row["a"], jnp.array(_COL_A, dtype=jnp.int64)[0]) + jnp.allclose(row["a"], jnp.array(_COL_A, dtype=jnp.int64 if jax.config.jax_enable_x64 else jnp.int32)[0]) + assert row["b"] == _COL_B[0] jnp.allclose(row["c"], jnp.array(_COL_C, dtype=jnp.float32)[0]) col = formatter.format_column(pa_table) - jnp.allclose(col, jnp.array(_COL_A, dtype=jnp.int64)) + jnp.allclose(col, jnp.array(_COL_A, dtype=jnp.int64 if jax.config.jax_enable_x64 else jnp.int32)) batch = formatter.format_batch(pa_table) - jnp.allclose(batch["a"], jnp.array(_COL_A, dtype=jnp.int64)) + jnp.allclose(batch["a"], jnp.array(_COL_A, dtype=jnp.int64 if jax.config.jax_enable_x64 else jnp.int32)) + assert batch["b"] == _COL_B jnp.allclose(batch["c"], jnp.array(_COL_C, dtype=jnp.float32)) + assert batch["c"].shape == np.array(_COL_C).shape @require_jax - def test_jax_formatter_np_array_kwargs(self): + def test_jax_formatter_jnp_array_kwargs(self): import jax.numpy as jnp from datasets.formatting import JaxFormatter @@ -228,6 +391,57 @@ def test_jax_formatter_np_array_kwargs(self): self.assertEqual(batch["a"].dtype, jnp.float16) self.assertEqual(batch["c"].dtype, jnp.float16) + @require_jax + def test_jax_formatter_image(self): + import jax.numpy as jnp + + from datasets.formatting import JaxFormatter + + # same dimensions + pa_table = pa.table({"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}] * 2}) + formatter = JaxFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, jnp.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertEqual(col.dtype, jnp.uint8) + self.assertEqual(col.shape, (2, 480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["image"].dtype, jnp.uint8) + self.assertEqual(batch["image"].shape, (2, 480, 640, 3)) + + # different dimensions + pa_table = pa.table( + {"image": [{"bytes": None, "path": str(IMAGE_PATH_1)}, {"bytes": None, "path": str(IMAGE_PATH_2)}]} + ) + formatter = JaxFormatter(features=Features({"image": Image()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["image"].dtype, jnp.uint8) + self.assertEqual(row["image"].shape, (480, 640, 3)) + col = formatter.format_column(pa_table) + self.assertIsInstance(col, list) + self.assertEqual(col[0].dtype, jnp.uint8) + self.assertEqual(col[0].shape, (480, 640, 3)) + batch = formatter.format_batch(pa_table) + self.assertIsInstance(batch["image"], list) + self.assertEqual(batch["image"][0].dtype, jnp.uint8) + self.assertEqual(batch["image"][0].shape, (480, 640, 3)) + + @require_jax + def test_jax_formatter_audio(self): + import jax.numpy as jnp + + from datasets.formatting import JaxFormatter + + pa_table = pa.table({"audio": [{"bytes": None, "path": str(AUDIO_PATH_1)}]}) + formatter = JaxFormatter(features=Features({"audio": Audio()})) + row = formatter.format_row(pa_table) + self.assertEqual(row["audio"]["array"].dtype, jnp.float32) + col = formatter.format_column(pa_table) + self.assertEqual(col[0]["array"].dtype, jnp.float32) + batch = formatter.format_batch(pa_table) + self.assertEqual(batch["audio"][0]["array"].dtype, jnp.float32) + class QueryTest(TestCase): def _create_dummy_table(self): @@ -553,12 +767,12 @@ def test_torch_formatter_sets_default_dtypes(cast_schema, arrow_table): formatter = TorchFormatter() row = formatter.format_row(arrow_table) - torch.testing.assert_allclose(row["col_int"], torch.tensor(list_int, dtype=torch.int64)[0]) - torch.testing.assert_allclose(row["col_float"], torch.tensor(list_float, dtype=torch.float32)[0]) + torch.testing.assert_close(row["col_int"], torch.tensor(list_int, dtype=torch.int64)[0]) + torch.testing.assert_close(row["col_float"], torch.tensor(list_float, dtype=torch.float32)[0]) col = formatter.format_column(arrow_table) - torch.testing.assert_allclose(col, torch.tensor(list_int, dtype=torch.int64)) + torch.testing.assert_close(col, torch.tensor(list_int, dtype=torch.int64)) batch = formatter.format_batch(arrow_table) - torch.testing.assert_allclose(batch["col_int"], torch.tensor(list_int, dtype=torch.int64)) - torch.testing.assert_allclose(batch["col_float"], torch.tensor(list_float, dtype=torch.float32)) + torch.testing.assert_close(batch["col_int"], torch.tensor(list_int, dtype=torch.int64)) + torch.testing.assert_close(batch["col_float"], torch.tensor(list_float, dtype=torch.float32))