Skip to content

Commit

Permalink
Extract attribute docstrings for FieldInfo.description (#6563)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
  • Loading branch information
Viicos and alexmojaki committed Jan 25, 2024
1 parent d67eff2 commit 6e59619
Show file tree
Hide file tree
Showing 7 changed files with 521 additions and 7 deletions.
2 changes: 2 additions & 0 deletions pydantic/_internal/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ConfigWrapper:
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
use_attribute_docstrings: bool

def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
Expand Down Expand Up @@ -253,6 +254,7 @@ def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
use_attribute_docstrings=False,
)


Expand Down
11 changes: 8 additions & 3 deletions pydantic/_internal/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,20 @@ class PydanticDataclass(StandardDataclass, typing.Protocol):
DeprecationWarning = PydanticDeprecatedSince20


def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
def set_dataclass_fields(
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None = None,
config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
"""Collect and set `cls.__pydantic_fields__`.
Args:
cls: The class.
types_namespace: The types namespace, defaults to `None`.
config_wrapper: The config wrapper instance, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map, config_wrapper=config_wrapper)

cls.__pydantic_fields__ = fields # type: ignore

Expand Down Expand Up @@ -111,7 +116,7 @@ def complete_dataclass(
if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)

set_dataclass_fields(cls, types_namespace)
set_dataclass_fields(cls, types_namespace, config_wrapper=config_wrapper)

typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
Expand Down
103 changes: 103 additions & 0 deletions pydantic/_internal/_docs_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Utilities related to attribute docstring extraction."""
from __future__ import annotations

import ast
import inspect
import textwrap
from typing import Any


class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()

self.target: str | None = None
self.attrs: dict[str, str] = {}
self.previous_node_type: type[ast.AST] | None = None

def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node_type = type(node)
return node_result

def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
if isinstance(node.target, ast.Name):
self.target = node.target.id

def visit_Expr(self, node: ast.Expr) -> Any:
if isinstance(node.value, ast.Str) and self.previous_node_type is ast.AnnAssign:
docstring = inspect.cleandoc(node.value.s)
if self.target:
self.attrs[self.target] = docstring
self.target = None


def _dedent_source_lines(source: list[str]) -> str:
# Required for nested class definitions, e.g. in a function block
dedent_source = textwrap.dedent(''.join(source))
if dedent_source.startswith((' ', '\t')):
# We are in the case where there's a dedented (usually multiline) string
# at a lower indentation level than the class itself. We wrap our class
# in a function as a workaround.
dedent_source = f'def dedent_workaround():\n{dedent_source}'
return dedent_source


def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
frame = inspect.currentframe()

while frame:
if inspect.getmodule(frame) is inspect.getmodule(cls):
lnum = frame.f_lineno
try:
lines, _ = inspect.findsource(frame)
except OSError:
# Source can't be retrieved (maybe because running in an interactive terminal),
# we don't want to error here.
pass
else:
block_lines = inspect.getblock(lines[lnum - 1 :])
dedent_source = _dedent_source_lines(block_lines)
try:
block_tree = ast.parse(dedent_source)
except SyntaxError:
pass
else:
stmt = block_tree.body[0]
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
# `_dedent_source_lines` wrapped the class around the workaround function
stmt = stmt.body[0]
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
return block_lines

frame = frame.f_back


def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
"""Map model attributes and their corresponding docstring.
Args:
cls: The class of the Pydantic model to inspect.
use_inspect: Whether to skip usage of frames to find the object and use
the `inspect` module instead.
Returns:
A mapping containing attribute names and their corresponding docstring.
"""
if use_inspect:
# Might not work as expected if two classes have the same name in the same source file.
try:
source, _ = inspect.getsourcelines(cls)
except OSError:
return {}
else:
source = _extract_source_from_frame(cls)

if not source:
return {}

dedent_source = _dedent_source_lines(source)

visitor = DocstringVisitor()
visitor.visit(ast.parse(dedent_source))
return visitor.attrs
21 changes: 20 additions & 1 deletion pydantic/_internal/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from . import _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar

Expand Down Expand Up @@ -87,6 +88,14 @@ def __init__(self, metadata: Any):
return _PydanticGeneralMetadata # type: ignore


def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
if config_wrapper.use_attribute_docstrings:
fields_docs = extract_docstrings_from_cls(cls)
for ann_name, field_info in fields.items():
if field_info.description is None and ann_name in fields_docs:
field_info.description = fields_docs[ann_name]


def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
Expand Down Expand Up @@ -231,6 +240,8 @@ def collect_model_fields( # noqa: C901
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)

_update_fields_from_docstrings(cls, fields, config_wrapper)

return fields, class_vars


Expand All @@ -248,14 +259,19 @@ def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:


def collect_dataclass_fields(
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None,
*,
typevars_map: dict[Any, Any] | None = None,
config_wrapper: ConfigWrapper | None = None,
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.
Args:
cls: dataclass.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
config_wrapper: The config wrapper instance.
Returns:
The dataclass fields.
Expand Down Expand Up @@ -308,6 +324,9 @@ def collect_dataclass_fields(
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)

if config_wrapper is not None:
_update_fields_from_docstrings(cls, fields, config_wrapper)

return fields


Expand Down
12 changes: 12 additions & 0 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
inspect_model_serializer,
inspect_validator,
)
from ._docs_extraction import extract_docstrings_from_cls
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
Expand Down Expand Up @@ -1210,6 +1211,11 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

decorators = DecoratorInfos.build(typed_dict_cls)

if self._config_wrapper.use_attribute_docstrings:
field_docstrings = extract_docstrings_from_cls(typed_dict_cls, use_inspect=True)
else:
field_docstrings = None

for field_name, annotation in get_type_hints_infer_globalns(
typed_dict_cls, localns=self._types_namespace, include_extras=True
).items():
Expand All @@ -1230,6 +1236,12 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
)[0]

field_info = FieldInfo.from_annotation(annotation)
if (
field_docstrings is not None
and field_info.description is None
and field_name in field_docstrings
):
field_info.description = field_docstrings[field_name]
fields[field_name] = self._generate_td_field_schema(
field_name, field_info, decorators, required=required
)
Expand Down
42 changes: 39 additions & 3 deletions pydantic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ class Model(BaseModel):

regex_engine: Literal['rust-regex', 'python-re']
"""
The regex engine to used for pattern validation
The regex engine to be used for pattern validation.
Defaults to `'rust-regex'`.
- `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate,
Expand Down Expand Up @@ -899,14 +899,50 @@ class Model(BaseModel):

validation_error_cause: bool
"""
If `True`, python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.
If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.
Note:
Python 3.10 and older don't support exception groups natively. <=3.10, backport must be installed: `pip install exceptiongroup`.
Note:
The structure of validation errors are likely to change in future pydantic versions. Pydantic offers no guarantees about the structure of validation errors. Should be used for visual traceback debugging only.
The structure of validation errors are likely to change in future Pydantic versions. Pydantic offers no guarantees about their structure. Should be used for visual traceback debugging only.
"""

use_attribute_docstrings: bool
'''
Whether docstrings of attributes (bare string literals immediately following the attribute declaration)
should be used for field descriptions.
```py
from pydantic import BaseModel, ConfigDict, Field
class Model(BaseModel):
model_config = ConfigDict(use_attribute_docstrings=True)
x: str
"""
Example of an attribute docstring
"""
y: int = Field(description="Description in Field")
"""
Description in Field overrides attribute docstring
"""
print(Model.model_fields["x"].description)
# > Example of an attribute docstring
print(Model.model_fields["y"].description)
# > Description in Field
```
This requires the source code of the class to be available at runtime.
!!! warning "Usage with `TypedDict`"
Due to current limitations, attribute docstrings detection may not work as expected when using `TypedDict`
(in particular when multiple `TypedDict` classes have the same name in the same source file). The behavior
can be different depending on the Python version used.
'''


__getattr__ = getattr_migration(__name__)

0 comments on commit 6e59619

Please sign in to comment.