Skip to content

Commit

Permalink
Require keyword args for data iterator. (#8327)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 10, 2022
1 parent e1f9f80 commit 5545c49
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 38 deletions.
89 changes: 51 additions & 38 deletions python-package/xgboost/core.py
Expand Up @@ -502,8 +502,8 @@ def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
pointer.
"""
@_deprecate_positional_args
def data_handle(
@require_pos_args(True)
def input_data(
data: Any,
*,
feature_names: Optional[FeatureNames] = None,
Expand All @@ -528,7 +528,7 @@ def data_handle(
**kwargs,
)
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)
return self._handle_exception(lambda: self.next(input_data), 0)

@abstractmethod
def reset(self) -> None:
Expand All @@ -554,7 +554,7 @@ def next(self, input_data: Callable) -> int:
raise NotImplementedError()


# Notice for `_deprecate_positional_args`
# Notice for `require_pos_args`
# Authors: Olivier Grisel
# Gael Varoquaux
# Andreas Mueller
Expand All @@ -563,50 +563,63 @@ def next(self, input_data: Callable) -> int:
# Nicolas Tresegnie
# Sylvain Marie
# License: BSD 3 clause
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""Decorator for methods that issues warnings for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
* will issue a warning or error when passed as a positional argument.
Modified from sklearn utils.validation.
Parameters
----------
f : function
function to check arguments on
error :
Whether to throw an error or raise a warning.
"""
sig = signature(f)
kwonly_args = []
all_args = []

for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)

@wraps(f)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}" for name, _ in zip(
kwonly_args[:extra_args], args[-extra_args:]
)
]
# pylint: disable=consider-using-f-string
warnings.warn(
"Pass `{}` as keyword args. Passing these as positional "
"arguments will be considered as error in future releases.".
format(", ".join(args_msg)), FutureWarning
)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return f(**kwargs)

return inner_f
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
"""Throw error/warning if there are positional arguments after the asterisk.
Parameters
----------
f :
function to check arguments on.
"""
sig = signature(func)
kwonly_args = []
all_args = []

for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)

@wraps(func)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}"
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
]
# pylint: disable=consider-using-f-string
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
if error:
raise TypeError(msg)
warnings.warn(msg, FutureWarning)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return func(**kwargs)

return inner_f

return throw_if


_deprecate_positional_args = require_pos_args(False)


class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
Expand Down
4 changes: 4 additions & 0 deletions tests/python/testing.py
Expand Up @@ -198,6 +198,10 @@ def __init__(
def next(self, input_data: Callable) -> int:
if self.it == len(self.X):
return 0

with pytest.raises(TypeError, match="keyword args"):
input_data(self.X[self.it], self.y[self.it], None)

# Use copy to make sure the iterator doesn't hold a reference to the data.
input_data(
data=self.X[self.it].copy(),
Expand Down

0 comments on commit 5545c49

Please sign in to comment.