Skip to content

Commit

Permalink
allow partials to "call" validator (pydantic#546)
Browse files Browse the repository at this point in the history
* allow partials to call validator

* add function_name
  • Loading branch information
samuelcolvin committed Apr 13, 2023
1 parent 493bbb6 commit 2e2bd34
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pydantic_core/core_schema.py
Expand Up @@ -3065,6 +3065,7 @@ class CallSchema(TypedDict, total=False):
type: Required[Literal['call']]
arguments_schema: Required[CoreSchema]
function: Required[Callable[..., Any]]
function_name: str # default function.__name__
return_schema: CoreSchema
ref: str
metadata: Any
Expand All @@ -3075,6 +3076,7 @@ def call_schema(
arguments: CoreSchema,
function: Callable[..., Any],
*,
function_name: str | None = None,
return_schema: CoreSchema | None = None,
ref: str | None = None,
metadata: Any = None,
Expand Down Expand Up @@ -3106,6 +3108,7 @@ def call_schema(
Args:
arguments: The arguments to use for the arguments schema
function: The function to use for the call schema
function_name: The function name to use for the call schema, if not provided `function.__name__` is used
return_schema: The return schema to use for the call schema
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3115,6 +3118,7 @@ def call_schema(
type='call',
arguments_schema=arguments,
function=function,
function_name=function_name,
return_schema=return_schema,
ref=ref,
metadata=metadata,
Expand Down
17 changes: 16 additions & 1 deletion src/validators/call.rs
Expand Up @@ -37,7 +37,22 @@ impl BuildValidator for CallValidator {
None => None,
};
let function: &PyAny = schema.get_as_req(intern!(py, "function"))?;
let function_name: &str = function.getattr(intern!(py, "__name__"))?.extract()?;
let function_name: &str = match schema.get_as(intern!(py, "function_name"))? {
Some(name) => name,
None => {
match function.getattr(intern!(py, "__name__")) {
Ok(name) => name.extract()?,
Err(_) => {
// partials we use `function.func.__name__`
if let Ok(func) = function.getattr(intern!(py, "func")) {
func.getattr(intern!(py, "__name__"))?.extract()?
} else {
"<unknown>"
}
}
}
}
};
let name = format!("{}[{function_name}]", Self::EXPECTED_TYPE);

Ok(Self {
Expand Down
43 changes: 43 additions & 0 deletions tests/validators/test_call.py
@@ -1,6 +1,7 @@
import dataclasses
import re
from collections import namedtuple
from functools import partial

import pytest

Expand Down Expand Up @@ -178,3 +179,45 @@ def test_named_tuple():
assert isinstance(d, Point)
assert d.x == 1.1
assert d.y == 2.2


def test_function_call_partial():
def my_function(a, b, c):
return a + b + c

v = SchemaValidator(
{
'type': 'call',
'function': partial(my_function, c=3),
'arguments_schema': {
'type': 'arguments',
'arguments_schema': [
{'name': 'a', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}},
{'name': 'b', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}},
],
},
}
)
assert 'name:"call[my_function]"' in plain_repr(v)
assert v.validate_python((1, 2)) == 6
assert v.validate_python((1, '2')) == 6


def test_custom_name():
def my_function(a):
return a

v = SchemaValidator(
{
'type': 'call',
'function': my_function,
'function_name': 'foobar',
'arguments_schema': {
'type': 'arguments',
'arguments_schema': [{'name': 'a', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}}],
},
}
)
assert 'name:"call[foobar]"' in plain_repr(v)
assert v.validate_python((1,)) == 1
assert v.validate_python(('2',)) == 2

0 comments on commit 2e2bd34

Please sign in to comment.