-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
__init__.py
127 lines (105 loc) 路 4.99 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# Lint as: python3
from typing import Dict, List, Optional
from .. import config
from ..utils import logging
from .formatting import (
ArrowFormatter,
CustomFormatter,
Formatter,
PandasFormatter,
PythonFormatter,
format_table,
query_table,
)
from .np_formatter import NumpyFormatter
# Module-level logger, used to warn when a registered format type or alias gets overwritten.
logger = logging.get_logger(__name__)
# Registry of formatter classes keyed by main format type name (`None` is a valid key, see the registrations below).
_FORMAT_TYPES: Dict[Optional[str], type] = {}
# Maps every registered alias, and each main format type name itself, to the main format type name.
_FORMAT_TYPES_ALIASES: Dict[Optional[str], str] = {}
# Maps the names and aliases of unavailable formatters to the error raised when one of them is requested.
_FORMAT_TYPES_ALIASES_UNAVAILABLE: Dict[Optional[str], Exception] = {}
def _register_formatter(formatter_cls: type, format_type: Optional[str], aliases: Optional[List[str]] = None):
    """
    Record `formatter_cls` as the formatter for `format_type`, reachable through that name and every alias.

    A warning is logged whenever an already registered format type or alias gets replaced.
    This function must be used on a Formatter class.
    """
    if aliases is None:
        aliases = []

    if format_type in _FORMAT_TYPES:
        logger.warning(
            f"Overwriting format type '{format_type}' ({_FORMAT_TYPES[format_type].__name__} -> {formatter_cls.__name__})"
        )
    _FORMAT_TYPES[format_type] = formatter_cls

    # The main name is always registered as an alias of itself.
    names = set(aliases + [format_type])
    # Report every alias that is about to be replaced, before any of them is modified.
    for alias in names:
        if alias in _FORMAT_TYPES_ALIASES:
            logger.warning(
                f"Overwriting format type alias '{alias}' ({_FORMAT_TYPES_ALIASES[alias]} -> {format_type})"
            )
    _FORMAT_TYPES_ALIASES.update(dict.fromkeys(names, format_type))
def _register_unavailable_formatter(
    unavailable_error: Exception, format_type: Optional[str], aliases: Optional[List[str]] = None
):
    """
    Record that the formatter named `format_type`, and each of its aliases, is unavailable.

    `unavailable_error` must be an Exception object: it is the error raised when trying to get the unavailable formatter.
    """
    if aliases is None:
        aliases = []
    # The main name is treated like any other alias.
    unavailable_names = set(aliases + [format_type])
    _FORMAT_TYPES_ALIASES_UNAVAILABLE.update(dict.fromkeys(unavailable_names, unavailable_error))
# Formatters that are always registered; these are the format types that can be used by `Dataset.set_format`.
_register_formatter(formatter_cls=PythonFormatter, format_type=None, aliases=["python"])
_register_formatter(formatter_cls=ArrowFormatter, format_type="arrow", aliases=["pa", "pyarrow"])
_register_formatter(formatter_cls=NumpyFormatter, format_type="numpy", aliases=["np"])
_register_formatter(formatter_cls=PandasFormatter, format_type="pandas", aliases=["pd"])
_register_formatter(formatter_cls=CustomFormatter, format_type="custom")

# Optional backends: the formatter module is only imported when `config` reports the backend as available.
# Otherwise the names are registered as unavailable, together with the error to raise when they are requested.
if config.TORCH_AVAILABLE:
    from .torch_formatter import TorchFormatter

    _register_formatter(formatter_cls=TorchFormatter, format_type="torch", aliases=["pt", "pytorch"])
else:
    _torch_error = ValueError("PyTorch needs to be installed to be able to return PyTorch tensors.")
    _register_unavailable_formatter(unavailable_error=_torch_error, format_type="torch", aliases=["pt", "pytorch"])

if config.TF_AVAILABLE:
    from .tf_formatter import TFFormatter

    _register_formatter(formatter_cls=TFFormatter, format_type="tensorflow", aliases=["tf"])
else:
    _tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
    _register_unavailable_formatter(unavailable_error=_tf_error, format_type="tensorflow", aliases=["tf"])

if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter

    _register_formatter(formatter_cls=JaxFormatter, format_type="jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(unavailable_error=_jax_error, format_type="jax", aliases=[])
def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
    """Resolve a known alias to its main format type name. Any other value is returned with no change."""
    return _FORMAT_TYPES_ALIASES.get(format_type, format_type)
def get_formatter(format_type: Optional[str], **format_kwargs) -> Formatter:
    """
    Factory function to get a Formatter given its type name and keyword arguments.

    A formatter is an object that extracts and formats data from pyarrow table.
    It defines the formatting for rows, columns and batches.

    Args:
        format_type: Main format type name or one of its aliases. `None` selects the formatter
            registered under the `None` key.
        **format_kwargs: Keyword arguments passed unchanged to the formatter class constructor.

    Returns:
        A new instance of the formatter class registered for `format_type`.

    Raises:
        Exception: The error registered with `_register_unavailable_formatter` for `format_type`,
            when the formatter is known but not available.
        ValueError: If `format_type` is neither a registered format type nor a known unavailable one.
    """
    format_type = get_format_type_from_alias(format_type)
    if format_type in _FORMAT_TYPES:
        return _FORMAT_TYPES[format_type](**format_kwargs)
    if format_type in _FORMAT_TYPES_ALIASES_UNAVAILABLE:
        raise _FORMAT_TYPES_ALIASES_UNAVAILABLE[format_type]
    # `None` is a registered key too, but the message already mentions it, so only the named types are listed.
    known_types = [name for name in _FORMAT_TYPES if name is not None]
    raise ValueError(f"Return type should be None or selected in {known_types}, but got '{format_type}'")