diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index fef7194f72..80f8536731 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3065,6 +3065,7 @@ class CallSchema(TypedDict, total=False): type: Required[Literal['call']] arguments_schema: Required[CoreSchema] function: Required[Callable[..., Any]] + function_name: str # default function.__name__ return_schema: CoreSchema ref: str metadata: Any @@ -3075,6 +3076,7 @@ def call_schema( arguments: CoreSchema, function: Callable[..., Any], *, + function_name: str | None = None, return_schema: CoreSchema | None = None, ref: str | None = None, metadata: Any = None, @@ -3106,6 +3108,7 @@ def call_schema( Args: arguments: The arguments to use for the arguments schema function: The function to use for the call schema + function_name: The function name to use for the call schema, if not provided `function.__name__` is used return_schema: The return schema to use for the call schema ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3115,6 +3118,7 @@ def call_schema( type='call', arguments_schema=arguments, function=function, + function_name=function_name, return_schema=return_schema, ref=ref, metadata=metadata, diff --git a/src/validators/call.rs b/src/validators/call.rs index 0632288a77..8e92f0b376 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -37,7 +37,22 @@ impl BuildValidator for CallValidator { None => None, }; let function: &PyAny = schema.get_as_req(intern!(py, "function"))?; - let function_name: &str = function.getattr(intern!(py, "__name__"))?.extract()?; + let function_name: &str = match schema.get_as(intern!(py, "function_name"))? { + Some(name) => name, + None => { + match function.getattr(intern!(py, "__name__")) { + Ok(name) => name.extract()?, + Err(_) => { + // partials we use `function.func.__name__` + if let Ok(func) = function.getattr(intern!(py, "func")) { + func.getattr(intern!(py, "__name__"))?.extract()? + } else { + "" + } + } + } + } + }; let name = format!("{}[{function_name}]", Self::EXPECTED_TYPE); Ok(Self { diff --git a/tests/validators/test_call.py b/tests/validators/test_call.py index dd30aca627..52ee7525d4 100644 --- a/tests/validators/test_call.py +++ b/tests/validators/test_call.py @@ -1,6 +1,7 @@ import dataclasses import re from collections import namedtuple +from functools import partial import pytest @@ -178,3 +179,45 @@ def test_named_tuple(): assert isinstance(d, Point) assert d.x == 1.1 assert d.y == 2.2 + + +def test_function_call_partial(): + def my_function(a, b, c): + return a + b + c + + v = SchemaValidator( + { + 'type': 'call', + 'function': partial(my_function, c=3), + 'arguments_schema': { + 'type': 'arguments', + 'arguments_schema': [ + {'name': 'a', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}}, + {'name': 'b', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}}, + ], + }, + } + ) + assert 'name:"call[my_function]"' in plain_repr(v) + assert v.validate_python((1, 2)) == 6 + assert v.validate_python((1, '2')) == 6 + + +def test_custom_name(): + def my_function(a): + return a + + v = SchemaValidator( + { + 'type': 'call', + 'function': my_function, + 'function_name': 'foobar', + 'arguments_schema': { + 'type': 'arguments', + 'arguments_schema': [{'name': 'a', 'mode': 'positional_or_keyword', 'schema': {'type': 'int'}}], + }, + } + ) + assert 'name:"call[foobar]"' in plain_repr(v) + assert v.validate_python((1,)) == 1 + assert v.validate_python(('2',)) == 2