Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New feature: enable builds to bind "meta" fields #102

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/hydra_zen/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

_T = _typing.TypeVar("_T")

__all__ = ["partial", "get_obj"]
__all__ = ["partial", "get_obj", "pass_it"]


def partial(
Expand All @@ -38,3 +38,12 @@ def get_obj(*, path: str) -> _typing.Union[type, _typing.Callable[..., _typing.A
except Exception as e: # pragma: no cover
_log.error(f"Error getting callable at {path} : {e}")
raise e


def pass_it(
*args, _true_target_: str, _excluded_: _typing.Sequence[str], **kwargs
) -> _typing.Any:
excluded_set = set(_excluded_)
kwargs = {k: v for k, v in kwargs.items() if k not in excluded_set}
target = get_obj(path=_true_target_)
return target(*args, **kwargs)
37 changes: 36 additions & 1 deletion src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from typing_extensions import Final, Literal

from hydra_zen.funcs import get_obj, partial
from hydra_zen.funcs import get_obj, partial, pass_it
from hydra_zen.structured_configs import _utils
from hydra_zen.typing import Builds, Importable, Just, PartialBuilds
from hydra_zen.typing._implementations import HasPartialTarget, HasTarget
Expand Down Expand Up @@ -352,6 +352,7 @@ def builds(
hydra_partial: Literal[False] = False,
hydra_recursive: Optional[bool] = None,
hydra_convert: Optional[Literal["none", "partial", "all"]] = None,
hydra_meta: Optional[Dict[str, Any]] = None,
dataclass_name: Optional[str] = None,
builds_bases: Tuple[Any, ...] = (),
frozen: bool = False,
Expand All @@ -369,6 +370,7 @@ def builds(
hydra_partial: Literal[True],
hydra_recursive: Optional[bool] = None,
hydra_convert: Optional[Literal["none", "partial", "all"]] = None,
hydra_meta: Optional[Dict[str, Any]] = None,
dataclass_name: Optional[str] = None,
builds_bases: Tuple[Any, ...] = (),
frozen: bool = False,
Expand All @@ -386,6 +388,7 @@ def builds(
hydra_partial: bool,
hydra_recursive: Optional[bool] = None,
hydra_convert: Optional[Literal["none", "partial", "all"]] = None,
hydra_meta: Optional[Dict[str, Any]] = None,
dataclass_name: Optional[str] = None,
builds_bases: Tuple[Any, ...] = (),
frozen: bool = False,
Expand All @@ -403,6 +406,7 @@ def builds(
hydra_partial: bool = False,
hydra_recursive: Optional[bool] = None,
hydra_convert: Optional[Literal["none", "partial", "all"]] = None,
hydra_meta: Optional[Dict[str, Any]] = None,
frozen: bool = False,
dataclass_name: Optional[str] = None,
builds_bases: Tuple[Any, ...] = (),
Expand Down Expand Up @@ -607,6 +611,24 @@ def builds(
_utils.field(default=just(target), init=False),
),
]
elif hydra_meta:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to handle the case where both hydra_partial is True and hydra_meta is not None?

Copy link
Contributor Author

@rsokl rsokl Sep 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! I definitely do need to handle this case. This is just a crude WIP right now.

I plan to refactor how we handle hydra_partial=True so that it also looks for _true_target_ instead of _partial_target_. I want to converge on a common interface for any/all post-processing functions that hydra-zen ships.

target_field = [
(
_TARGET_FIELD_NAME,
str,
_utils.field(default=_utils.get_obj_path(pass_it), init=False),
),
(
"_true_target_", # TODO: replace with global constant
str,
_utils.field(default=_utils.get_obj_path(target), init=False),
),
(
"_excluded_",
Tuple[str],
_utils.field(default=tuple(hydra_meta), init=False),
),
]
else:
target_field = [
(
Expand Down Expand Up @@ -807,6 +829,16 @@ def builds(
for name, value in kwargs_for_target.items()
}

# TODO: validate to ensure no collision with named arguments or
# with named arguments in signature. Also check inherited case
if hydra_meta:
user_specified_named_params.update(
{
name: (name, Any, sanitized_default_value(value))
for name, value in hydra_meta.items()
}
)

if populate_full_signature is True:
# Populate dataclass fields based on the target's signature.
#
Expand Down Expand Up @@ -986,6 +1018,9 @@ def get_target(obj: Union[HasTarget, HasPartialTarget]) -> Any: # pragma: no co
...


# TODO: add case for new namespace build flavor


def get_target(obj: Union[HasTarget, HasPartialTarget]) -> Any:
"""
Returns the target-object from a targeted structured config.
Expand Down