Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Support for pydantic options #803

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Empty file.
17 changes: 17 additions & 0 deletions docs_src/parameter_types/pydantic/tutorial001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import typer

import pydantic


class User(pydantic.BaseModel):
id: int
name: str = "Jane Doe"


def main(num: int, user: User):
print(num, type(num))
print(user, type(user))


if __name__ == "__main__":
typer.run(main)
24 changes: 24 additions & 0 deletions docs_src/parameter_types/pydantic/tutorial002.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Optional

import typer

import pydantic


class Pet(pydantic.BaseModel):
name: str
species: str


class Person(pydantic.BaseModel):
name: str
age: Optional[float] = None
pet: Pet


def main(person: Person):
print(person, type(person))


if __name__ == "__main__":
typer.run(main)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ homepage = "https://github.com/tiangolo/typer"
standard = [
"shellingham >=1.3.0",
"rich >=10.11.0",
"pydantic >= 2.0.0",
]

[tool.pdm]
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import subprocess
import sys

import pytest
import typer
from typer.testing import CliRunner

from docs_src.parameter_types.pydantic import tutorial001 as mod

runner = CliRunner()

app = typer.Typer()
app.command()(mod.main)


def test_help():
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0


def test_parse_pydantic_model():
result = runner.invoke(app, ["1", "--user.id", "2", "--user.name", "John Doe"])
assert "1 <class 'int'>" in result.output
assert (
"id=2 name='John Doe' <class 'docs_src.parameter_types.pydantic.tutorial001.User'>"
in result.output
)


def test_script():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
capture_output=True,
encoding="utf-8",
)
assert "Usage" in result.stdout


def test_error_without_pydantic():
pydantic = typer.pydantic_extension.pydantic
typer.pydantic_extension.pydantic = None
with pytest.raises(
RuntimeError,
match="Type not yet supported: <class 'docs_src.parameter_types.pydantic.tutorial001.User'>",
):
runner.invoke(
app,
["1", "--user.id", "2", "--user.name", "John Doe"],
catch_exceptions=False,
)
typer.pydantic_extension.pydantic = pydantic
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import subprocess
import sys

import typer
from typer.testing import CliRunner

from docs_src.parameter_types.pydantic import tutorial002 as mod

runner = CliRunner()

app = typer.Typer()
app.command()(mod.main)


def test_help():
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0


def test_parse_pydantic_model():
result = runner.invoke(
app,
[
"--person.name",
"Jeff",
"--person.pet.name",
"Lassie",
"--person.pet.species",
"dog",
],
)
assert (
"name='Jeff' age=None pet=Pet(name='Lassie', species='dog') <class 'docs_src.parameter_types.pydantic.tutorial002.Person'>"
in result.output
)


def test_script():
result = subprocess.run(
[sys.executable, "-m", "coverage", "run", mod.__file__, "--help"],
capture_output=True,
encoding="utf-8",
)
assert "Usage" in result.stdout
15 changes: 5 additions & 10 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from .completion import get_completion_inspect_parameters
from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption
from .models import (
AnyType,
ArgumentInfo,
CommandFunctionType,
CommandInfo,
Expand All @@ -34,7 +33,8 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .pydantic_extension import wrap_pydantic_callback
from .utils import get_params_from_function, lenient_issubclass

try:
import rich
Expand Down Expand Up @@ -572,17 +572,18 @@ def get_command_from_info(
use_help = inspect.getdoc(command_info.callback)
else:
use_help = inspect.cleandoc(use_help)
callback = wrap_pydantic_callback(command_info.callback)
(
params,
convertors,
context_param_name,
) = get_params_convertors_ctx_param_name_from_function(command_info.callback)
) = get_params_convertors_ctx_param_name_from_function(callback)
cls = command_info.cls or TyperCommand
command = cls(
name=name,
context_settings=command_info.context_settings,
callback=get_callback(
callback=command_info.callback,
callback=callback,
params=params,
convertors=convertors,
context_param_name=context_param_name,
Expand Down Expand Up @@ -788,12 +789,6 @@ def get_click_type(
raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover


def lenient_issubclass(
cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]]
) -> bool:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)


def get_click_param(
param: ParamMeta,
) -> Tuple[Union[click.Argument, click.Option], Any]:
Expand Down
85 changes: 85 additions & 0 deletions typer/pydantic_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import inspect
from typing import Any, Callable, Dict, List

from typing_extensions import Annotated

from .params import Option
from .utils import deep_update, inspect_signature, lenient_issubclass

try:
import pydantic
except ImportError: # pragma: no cover
pydantic = None # type: ignore

PYDANTIC_FIELD_SEPARATOR = "."


def _flatten_pydantic_model(
model: "pydantic.BaseModel", ancestors: List[str]
) -> Dict[str, inspect.Parameter]:
# This function should only be called if pydantic is available
assert pydantic is not None
pydantic_parameters = {}
for field_name, field in model.model_fields.items():
qualifier = [*ancestors, field_name]
sub_name = f"_pydantic_{'_'.join(qualifier)}"
if lenient_issubclass(field.annotation, pydantic.BaseModel):
params = _flatten_pydantic_model(field.annotation, qualifier) # type: ignore
pydantic_parameters.update(params)
else:
default = (
field.default if field.default is not pydantic.fields._Unset else ...
)
typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}")
pydantic_parameters[sub_name] = inspect.Parameter(
sub_name,
inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated[field.annotation, typer_option, qualifier],
default=default,
)
return pydantic_parameters


def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]:
if pydantic is None:
return callback

original_signature = inspect_signature(callback)

pydantic_parameters = {}
pydantic_roots = {}
other_parameters = {}
for name, parameter in original_signature.parameters.items():
if lenient_issubclass(parameter.annotation, pydantic.BaseModel):
params = _flatten_pydantic_model(parameter.annotation, [name])
pydantic_parameters.update(params)
pydantic_roots[name] = parameter.annotation
else:
other_parameters[name] = parameter

extended_signature = inspect.Signature(
[*other_parameters.values(), *pydantic_parameters.values()],
return_annotation=original_signature.return_annotation,
)

def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
converted_kwargs = kwargs.copy()
raw_pydantic_objects: Dict[str, Any] = {}
for kwarg_name, kwarg_value in kwargs.items():
if kwarg_name in pydantic_parameters:
converted_kwargs.pop(kwarg_name)
annotation = pydantic_parameters[kwarg_name].annotation
_, qualifier = annotation.__metadata__
for part in reversed(qualifier):
kwarg_value = {part: kwarg_value}
raw_pydantic_objects = deep_update(raw_pydantic_objects, kwarg_value)
for root_name, value in raw_pydantic_objects.items():
converted_kwargs[root_name] = pydantic_roots[root_name](**value)
return callback(*args, **converted_kwargs)

wrapper.__signature__ = extended_signature # type: ignore
# Copy annotations to make forward references work in Python <= 3.9
wrapper.__annotations__ = {
k: v.annotation for k, v in extended_signature.parameters.items()
}
return wrapper
50 changes: 47 additions & 3 deletions typer/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import inspect
import sys
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints
from typing import (
Any,
Callable,
Dict,
List,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)

from typing_extensions import Annotated

from ._typing import get_args, get_origin
from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta
from .models import AnyType, ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
Expand Down Expand Up @@ -106,11 +117,44 @@ def _split_annotation_from_typer_annotations(
]


def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
def lenient_issubclass(
cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]]
) -> bool:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)


KeyType = TypeVar("KeyType")


def deep_update(
mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]
) -> Dict[KeyType, Any]:
# Copied from pydantic because they don't expose it publicly:
# https://github.com/pydantic/pydantic/blob/26129479a06960af9d02d3a948e51985fe59ed4b/pydantic/_internal/_utils.py#L103
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if (
k in updated_mapping
and isinstance(updated_mapping[k], dict)
and isinstance(v, dict)
):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping


def inspect_signature(func: Callable[..., Any]) -> inspect.Signature:
if sys.version_info >= (3, 10):
signature = inspect.signature(func, eval_str=True)
else:
signature = inspect.signature(func)
return signature


def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
signature = inspect_signature(func)

type_hints = get_type_hints(func)
params = {}
Expand Down