Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract attribute docstrings for FieldInfo.description #6563

Merged
merged 38 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
29bbf3e
WIP: Implement docs extraction
Viicos Jul 9, 2023
792da19
Implement attribute docstring extraction
Viicos Jul 10, 2023
d3e8402
Apply feedback
Viicos Jul 12, 2023
c9325d7
Implement `ast.NodeVisitor`
Viicos Jul 12, 2023
1f81724
Add `use_attributes_docstring` config value
Viicos Jul 12, 2023
cab7a07
WIP: First tests
Viicos Jul 12, 2023
e46890c
Fix tests to cope with `inspect.getsource`
Viicos Jul 13, 2023
af4f340
Add config docstring
Viicos Sep 20, 2023
d3a4761
Support `TypedDict`
Viicos Sep 20, 2023
3da7402
Walk back frames to get source code
Viicos Oct 22, 2023
4f73618
Handle the case where `f_lineno` is not an `int`
Viicos Oct 22, 2023
2102756
Fix usage of `getblock`
Viicos Oct 23, 2023
8ae16c0
Improve class detection when walking up frames
Viicos Oct 23, 2023
436e95a
Apply feedback and last fixes
Viicos Oct 23, 2023
85e4c0f
Add test with `create_model`
Viicos Oct 23, 2023
d8cbc3b
WIP: Feedback and additional tests
Viicos Oct 24, 2023
f8c56a9
More tests and small code refactors
Viicos Oct 24, 2023
44a4a73
Use same model names in tests
Viicos Oct 24, 2023
43c1c82
Fix rebase
Viicos Oct 24, 2023
626a19a
Use source lines from `frame` object
Viicos Oct 25, 2023
359846b
Improve config docstring, apply additional feedback, fix infinite loop
Viicos Oct 25, 2023
c12f020
`pyright`
Viicos Oct 25, 2023
b9ca191
lint again
Viicos Oct 25, 2023
05598d3
Remove unrelated file
Viicos Oct 25, 2023
a044bce
Fix handling of dataclasses with Python>=3.11
Viicos Oct 25, 2023
23c8695
Fix return value
Viicos Oct 25, 2023
70b5ec2
Parse the frame block using `ast`
Viicos Oct 26, 2023
7d9d77a
Dedent fix
Viicos Oct 26, 2023
0f1743b
Some optimizations
Viicos Oct 26, 2023
ab3dfc0
Comments reordering
Viicos Oct 26, 2023
d451205
Remove commented `breakpoint()`
Viicos Oct 27, 2023
146f915
Add additional generic assertion
Viicos Oct 27, 2023
23396b8
Fix failing test
Viicos Oct 27, 2023
73f1a94
Apply minor changes and fixes
Viicos Oct 28, 2023
f3e84a6
Update to latest `pyright`
Viicos Nov 10, 2023
d9150c6
fix rebase
Viicos Nov 15, 2023
abba7d2
Use latest pyright
Viicos Nov 15, 2023
caa5cbe
Merge branch 'main' into docstrings-description
Viicos Jan 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic/_internal/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ConfigWrapper:
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
use_attribute_docstrings: bool

def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
Expand Down Expand Up @@ -253,6 +254,7 @@ def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
use_attribute_docstrings=False,
)


Expand Down
11 changes: 8 additions & 3 deletions pydantic/_internal/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,20 @@ class PydanticDataclass(StandardDataclass, typing.Protocol):
DeprecationWarning = PydanticDeprecatedSince20


def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
def set_dataclass_fields(
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None = None,
config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
"""Collect and set `cls.__pydantic_fields__`.

Args:
cls: The class.
types_namespace: The types namespace, defaults to `None`.
config_wrapper: The config wrapper instance, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map, config_wrapper=config_wrapper)

cls.__pydantic_fields__ = fields # type: ignore

Expand Down Expand Up @@ -115,7 +120,7 @@ def complete_dataclass(
if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)

set_dataclass_fields(cls, types_namespace)
set_dataclass_fields(cls, types_namespace, config_wrapper=config_wrapper)

typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
Expand Down
103 changes: 103 additions & 0 deletions pydantic/_internal/_docs_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Utilities related to attribute docstring extraction."""
from __future__ import annotations

import ast
import inspect
import textwrap
from typing import Any


class DocstringVisitor(ast.NodeVisitor):
    """Collect attribute docstrings from a parsed piece of source code.

    An attribute docstring is a bare string literal statement that directly
    follows an annotated assignment (`name: type` or `name: type = value`).
    Once a tree has been visited, `attrs` maps every documented attribute name
    to its docstring, normalized with `inspect.cleandoc`.
    """

    def __init__(self) -> None:
        super().__init__()

        # Name bound by the latest annotated assignment (`None` when that
        # assignment did not target a plain name, or once it has been consumed).
        self.target: str | None = None
        # Attribute name -> cleaned docstring.
        self.attrs: dict[str, str] = {}
        # Type of the node whose visit finished most recently.
        self.previous_node_type: type[ast.AST] | None = None

    def visit(self, node: ast.AST) -> Any:
        """Visit `node`, then remember its type as the previously visited node."""
        node_result = super().visit(node)
        # Recorded only after the node has been handled, so that `visit_Expr`
        # can tell whether the statement just before it was an `AnnAssign`.
        self.previous_node_type = type(node)
        return node_result

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        """Remember the name declared by an annotated assignment."""
        # Reset the target for non-name targets (e.g. `obj.attr: int`): otherwise
        # a docstring following such a statement would be attached to an older,
        # unrelated attribute.
        self.target = node.target.id if isinstance(node.target, ast.Name) else None

    def visit_Expr(self, node: ast.Expr) -> Any:
        """Store a string literal statement as the docstring of the pending target."""
        value = node.value
        # `ast.Constant` replaces `ast.Str` (deprecated since Python 3.8 and
        # removed in Python 3.14); the `str` check excludes other constants.
        if (
            isinstance(value, ast.Constant)
            and isinstance(value.value, str)
            and self.previous_node_type is ast.AnnAssign
        ):
            if self.target:
                self.attrs[self.target] = inspect.cleandoc(value.value)
            self.target = None
Comment on lines +10 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.target: str | None = None
self.attrs: dict[str, str] = {}
self.previous_node_type: type[ast.AST] | None = None
def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node_type = type(node)
return node_result
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
if isinstance(node.target, ast.Name):
self.target = node.target.id
def visit_Expr(self, node: ast.Expr) -> Any:
if isinstance(node.value, ast.Str) and self.previous_node_type is ast.AnnAssign:
docstring = inspect.cleandoc(node.value.s)
if self.target:
self.attrs[self.target] = docstring
self.target = None
class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.attrs: dict[str, str] = {}
self.previous_node: ast.AST | None = None
def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node = node
return node_result
def visit_Expr(self, node: ast.Expr) -> Any:
if (
isinstance(node.value, ast.Str)
and isinstance(self.previous_node, ast.AnnAssign)
and isinstance(self.previous_node.target, ast.Name)
):
docstring = inspect.cleandoc(node.value.s)
self.attrs[self.previous_node.target.id] = docstring

Does this work?

Personally I'm not convinced that using a visitor is clearer or better than your first implementation. Visiting every node in the tree is a waste, and I don't like having to think about the order in which the visit methods get called.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before applying the suggestion/reverting to the old solution, I'll let maintainers decide what they prefer.



def _dedent_source_lines(source: list[str]) -> str:
# Required for nested class definitions, e.g. in a function block
dedent_source = textwrap.dedent(''.join(source))
if dedent_source.startswith((' ', '\t')):
# We are in the case where there's a dedented (usually multiline) string
# at a lower indentation level than the class itself. We wrap our class
# in a function as a workaround.
dedent_source = f'def dedent_workaround():\n{dedent_source}'
return dedent_source


def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
    """Find the source lines of `cls` by walking up the call stack.

    Starting from the current frame, every frame that belongs to the same
    module as `cls` is inspected: the block of code starting at the frame's
    current line is extracted and parsed, and it is returned if it is a class
    definition named like `cls`.

    Args:
        cls: The class to look for.

    Returns:
        The source lines of the matching block, or `None` if no frame matched.
    """
    # Loop invariant: resolve the module of `cls` once instead of once per frame.
    cls_module = inspect.getmodule(cls)
    frame = inspect.currentframe()

    try:
        while frame:
            if inspect.getmodule(frame) is cls_module:
                lnum = frame.f_lineno
                try:
                    lines, _ = inspect.findsource(frame)
                except OSError:
                    # Source can't be retrieved (maybe because running in an interactive terminal),
                    # we don't want to error here.
                    pass
                else:
                    block_lines = inspect.getblock(lines[lnum - 1 :])
                    dedent_source = _dedent_source_lines(block_lines)
                    try:
                        block_tree = ast.parse(dedent_source)
                    except SyntaxError:
                        # The block at this frame is not parsable on its own; try the next frame.
                        pass
                    else:
                        stmt = block_tree.body[0]
                        if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
                            # `_dedent_source_lines` wrapped the class in the workaround function
                            stmt = stmt.body[0]
                        if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
                            return block_lines

            frame = frame.f_back

        return None
    finally:
        # Keeping a frame object in a local variable creates a reference cycle
        # (the frame references its own locals); drop it explicitly, as
        # recommended by the `inspect` documentation.
        del frame


def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
    """Map model attributes and their corresponding docstring.

    Args:
        cls: The class of the Pydantic model to inspect.
        use_inspect: Whether to skip usage of frames to find the object and use
            the `inspect` module instead.

    Returns:
        A mapping containing attribute names and their corresponding docstring.
        The mapping is empty if the source code of the class can't be found.
    """
    if use_inspect:
        # Might not work as expected if two classes have the same name in the same source file.
        try:
            source, _ = inspect.getsourcelines(cls)
        except (OSError, TypeError):
            # `OSError`: the source code can't be retrieved.
            # `TypeError`: the class is built-in, or its module has no source file.
            # Docstring extraction is best effort, so neither should propagate.
            return {}
    else:
        source = _extract_source_from_frame(cls)

    if not source:
        return {}

    dedent_source = _dedent_source_lines(source)

    visitor = DocstringVisitor()
    visitor.visit(ast.parse(dedent_source))
    return visitor.attrs
21 changes: 20 additions & 1 deletion pydantic/_internal/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from . import _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar

Expand Down Expand Up @@ -85,6 +86,14 @@ def __init__(self, metadata: Any):
return _PydanticGeneralMetadata # type: ignore


def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
    """Fill in missing field descriptions from the attribute docstrings of `cls`.

    Does nothing unless `config_wrapper.use_attribute_docstrings` is enabled.

    Args:
        cls: The class whose attribute docstrings are extracted.
        fields: The collected fields, updated in place.
        config_wrapper: The config wrapper instance.
    """
    if not config_wrapper.use_attribute_docstrings:
        return

    docs_by_name = extract_docstrings_from_cls(cls)
    for name, info in fields.items():
        # A description that is already set always wins over the docstring.
        if info.description is not None:
            continue
        if name in docs_by_name:
            info.description = docs_by_name[name]


def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
Expand Down Expand Up @@ -229,6 +238,8 @@ def collect_model_fields( # noqa: C901
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)

_update_fields_from_docstrings(cls, fields, config_wrapper)

return fields, class_vars


Expand All @@ -246,14 +257,19 @@ def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:


def collect_dataclass_fields(
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
cls: type[StandardDataclass],
types_namespace: dict[str, Any] | None,
*,
typevars_map: dict[Any, Any] | None = None,
config_wrapper: ConfigWrapper | None = None,
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.

Args:
cls: dataclass.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
config_wrapper: The config wrapper instance.

Returns:
The dataclass fields.
Expand Down Expand Up @@ -299,6 +315,9 @@ def collect_dataclass_fields(
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)

if config_wrapper is not None:
_update_fields_from_docstrings(cls, fields, config_wrapper)

return fields


Expand Down
12 changes: 12 additions & 0 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
inspect_model_serializer,
inspect_validator,
)
from ._docs_extraction import extract_docstrings_from_cls
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
Expand Down Expand Up @@ -1206,6 +1207,11 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

decorators = DecoratorInfos.build(typed_dict_cls)

if self._config_wrapper.use_attribute_docstrings:
field_docstrings = extract_docstrings_from_cls(typed_dict_cls, use_inspect=True)
else:
field_docstrings = None

for field_name, annotation in get_type_hints_infer_globalns(
typed_dict_cls, localns=self._types_namespace, include_extras=True
).items():
Expand All @@ -1226,6 +1232,12 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
)[0]

field_info = FieldInfo.from_annotation(annotation)
if (
field_docstrings is not None
and field_info.description is None
and field_name in field_docstrings
):
field_info.description = field_docstrings[field_name]
Viicos marked this conversation as resolved.
Show resolved Hide resolved
fields[field_name] = self._generate_td_field_schema(
field_name, field_info, decorators, required=required
)
Expand Down
42 changes: 39 additions & 3 deletions pydantic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ class Model(BaseModel):

regex_engine: Literal['rust-regex', 'python-re']
"""
The regex engine to used for pattern validation
The regex engine to be used for pattern validation.
Defaults to `'rust-regex'`.

- `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate,
Expand Down Expand Up @@ -899,14 +899,50 @@ class Model(BaseModel):

validation_error_cause: bool
"""
If `True`, python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.
If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.

Note:
Python 3.10 and older don't support exception groups natively. On these versions, the backport must be installed: `pip install exceptiongroup`.

Note:
The structure of validation errors are likely to change in future pydantic versions. Pydantic offers no guarantees about the structure of validation errors. Should be used for visual traceback debugging only.
The structure of validation errors is likely to change in future Pydantic versions. Pydantic offers no guarantees about their structure. Should be used for visual traceback debugging only.
"""

use_attribute_docstrings: bool
'''
Whether docstrings of attributes (bare string literals immediately following the attribute declaration)
should be used for field descriptions.

```py
from pydantic import BaseModel, ConfigDict, Field


class Model(BaseModel):
model_config = ConfigDict(use_attribute_docstrings=True)

x: str
"""
Example of an attribute docstring
"""

y: int = Field(description="Description in Field")
"""
Description in Field overrides attribute docstring
"""


print(Model.model_fields["x"].description)
# > Example of an attribute docstring
print(Model.model_fields["y"].description)
# > Description in Field
```
This requires the source code of the class to be available at runtime.

!!! warning "Usage with `TypedDict`"
Due to current limitations, attribute docstring detection may not work as expected when using `TypedDict`
(in particular when multiple `TypedDict` classes have the same name in the same source file). The behavior
can be different depending on the Python version used.
'''


__getattr__ = getattr_migration(__name__)