Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require keyword args for data iterator. #8327

Merged
merged 4 commits into from Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 51 additions & 38 deletions python-package/xgboost/core.py
Expand Up @@ -502,8 +502,8 @@ def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
pointer.

"""
@_deprecate_positional_args
def data_handle(
@require_pos_args(True)
def input_data(
data: Any,
*,
feature_names: Optional[FeatureNames] = None,
Expand All @@ -528,7 +528,7 @@ def data_handle(
**kwargs,
)
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)
return self._handle_exception(lambda: self.next(input_data), 0)

@abstractmethod
def reset(self) -> None:
Expand All @@ -554,7 +554,7 @@ def next(self, input_data: Callable) -> int:
raise NotImplementedError()


# Notice for `_deprecate_positional_args`
# Notice for `require_pos_args`
# Authors: Olivier Grisel
# Gael Varoquaux
# Andreas Mueller
Expand All @@ -563,50 +563,63 @@ def next(self, input_data: Callable) -> int:
# Nicolas Tresegnie
# Sylvain Marie
# License: BSD 3 clause
def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
def require_pos_args(error: bool) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""Decorator for methods that issues warnings for positional arguments

Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
* will issue a warning or error when passed as a positional argument.

Modified from sklearn utils.validation.

Parameters
----------
f : function
function to check arguments on
error :
Whether to throw an error or raise a warning.
"""
sig = signature(f)
kwonly_args = []
all_args = []

for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)

@wraps(f)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}" for name, _ in zip(
kwonly_args[:extra_args], args[-extra_args:]
)
]
# pylint: disable=consider-using-f-string
warnings.warn(
"Pass `{}` as keyword args. Passing these as positional "
"arguments will be considered as error in future releases.".
format(", ".join(args_msg)), FutureWarning
)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return f(**kwargs)

return inner_f
def throw_if(func: Callable[..., _T]) -> Callable[..., _T]:
"""Throw error/warning if there are positional arguments after the asterisk.

Parameters
----------
f :
function to check arguments on.

"""
sig = signature(func)
kwonly_args = []
all_args = []

for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)

@wraps(func)
def inner_f(*args: Any, **kwargs: Any) -> _T:
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}"
for name, _ in zip(kwonly_args[:extra_args], args[-extra_args:])
]
# pylint: disable=consider-using-f-string
msg = "Pass `{}` as keyword args.".format(", ".join(args_msg))
if error:
raise TypeError(msg)
warnings.warn(msg, FutureWarning)
for k, arg in zip(sig.parameters, args):
kwargs[k] = arg
return func(**kwargs)

return inner_f

return throw_if


_deprecate_positional_args = require_pos_args(False)


class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
Expand Down
4 changes: 4 additions & 0 deletions tests/python/testing.py
Expand Up @@ -198,6 +198,10 @@ def __init__(
def next(self, input_data: Callable) -> int:
if self.it == len(self.X):
return 0

with pytest.raises(TypeError, match="keyword args"):
input_data(self.X[self.it], self.y[self.it], None)

# Use copy to make sure the iterator doesn't hold a reference to the data.
input_data(
data=self.X[self.it].copy(),
Expand Down