Skip to content

Commit

Permalink
Allow non-scalar values as tagged union keys (pydantic#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jun 8, 2023
1 parent d2f4226 commit 3b1a6a7
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 303 deletions.
13 changes: 3 additions & 10 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def tagged_union(std_union_schema: Dict[str, Any], discriminator_key: str, ref:
first, *rest = literal
tagged_choices[first] = choice
for arg in rest:
tagged_choices[arg] = first
tagged_choices[arg] = choice
s = {'type': 'tagged-union', 'discriminator': discriminator_key, 'choices': tagged_choices}
if ref is not None:
s['ref'] = ref
Expand Down Expand Up @@ -129,15 +129,8 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
schema = {'type': 'list', 'items_schema': schema_ref_validator}
elif fr_arg == 'Dict[str, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'Dict[Union[str, int], Union[str, int, CoreSchema]]':
schema = {
'type': 'dict',
'keys_schema': {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'int'}]},
'values_schema': {
'type': 'union',
'choices': [{'type': 'str'}, {'type': 'int'}, schema_ref_validator],
},
}
elif fr_arg == 'Dict[Hashable, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
else:
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
else:
Expand Down
6 changes: 3 additions & 3 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import Protocol, Required, TypeAlias
Expand Down Expand Up @@ -2361,7 +2361,7 @@ def union_schema(

class TaggedUnionSchema(TypedDict, total=False):
type: Required[Literal['tagged-union']]
choices: Required[Dict[Union[str, int], Union[str, int, CoreSchema]]]
choices: Required[Dict[Hashable, CoreSchema]]
discriminator: Required[
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[Union[str, int]]]]
]
Expand All @@ -2376,7 +2376,7 @@ class TaggedUnionSchema(TypedDict, total=False):


def tagged_union_schema(
choices: Dict[Union[int, str], int | str | CoreSchema],
choices: Dict[Hashable, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | int | None],
*,
custom_error_type: str | None = None,
Expand Down
12 changes: 12 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_int()
}

/// Extract an `EitherInt` from the input, only allowing exact
/// matches for an int (no subclasses).
///
/// This default implementation simply delegates to `strict_int`.
/// NOTE(review): presumably input types whose `strict_int` would still admit
/// subclasses are expected to override this method - confirm.
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}

/// Extract an `EitherString` from the input, only allowing exact
/// matches for a string (no subclasses).
///
/// This default implementation simply delegates to `strict_str`.
/// NOTE(review): presumably input types whose `strict_str` would still admit
/// subclasses are expected to override this method - confirm.
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
self.strict_str()
}

fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
if ultra_strict {
self.ultra_strict_float()
Expand Down
8 changes: 8 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ impl<'a> Input<'a> for PyAny {
}
}

/// Accept the object as an int only when `PyInt::is_exact_type_of` holds for it;
/// anything else is rejected with an `IntType` validation error.
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
    // Guard clause: bail out early for anything that is not exactly an int.
    if !PyInt::is_exact_type_of(self) {
        return Err(ValError::new(ErrorType::IntType, self));
    }
    Ok(EitherInt::Py(self))
}

fn lax_str(&'a self) -> ValResult<EitherString<'a>> {
if let Ok(py_str) = <PyString as PyTryFrom>::try_from_exact(self) {
Ok(py_str.into())
Expand Down
4 changes: 2 additions & 2 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::errors::{ErrorType, ValLineError};
use crate::input::{Input, JsonInput, JsonObject};
use crate::tools::{extract_i64, py_err};

/// Used got getting items from python dicts, python objects, or JSON objects, in different ways
/// Used for getting items from python dicts, python objects, or JSON objects, in different ways
#[derive(Debug, Clone)]
pub(crate) enum LookupKey {
/// simply look up a key in a dict, equivalent to `d.get(key)`
Expand All @@ -29,7 +29,7 @@ pub(crate) enum LookupKey {
py_key2: Py<PyString>,
path2: LookupPath,
},
/// look up keys buy one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
/// look up keys by one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
/// ints are also supported to index arrays/lists/tuples and dicts with int keys
/// we reuse Location as the enum is the same, and the meaning is the same
PathChoices(Vec<LookupPath>),
Expand Down
147 changes: 96 additions & 51 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Validator for things inside of a typing.Literal[]
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
use core::fmt::Debug;

use ahash::AHashSet;
use ahash::AHashMap;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
Expand All @@ -15,15 +16,96 @@ use crate::tools::SchemaDict;
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};

#[derive(Debug, Clone)]
pub struct LiteralValidator {
pub struct LiteralLookup<T: Clone + Debug> {
// Specialized lookups for ints and strings because they
// (1) are easy to convert between Rust and Python
// (2) hashing them in Rust is very fast
// (3) are the most commonly used things in Literal[...]
expected_int: Option<AHashSet<i64>>,
expected_str: Option<AHashSet<String>>,
expected_int: Option<AHashMap<i64, usize>>,
expected_str: Option<AHashMap<String, usize>>,
// Catch all for Enum and bytes (the latter only because it is seldom used)
expected_py: Option<Py<PyDict>>,
pub values: Vec<T>,
}

impl<T: Clone + Debug> LiteralLookup<T> {
    /// Build a lookup from an iterator of `(key, value)` pairs.
    ///
    /// Every value is appended to `values`; its key is then registered, together with the
    /// index of that value, in exactly one of three tables:
    ///
    /// * `expected_int` when `exact_int` accepts the key,
    /// * `expected_str` when `exact_str` accepts the key,
    /// * `expected_py` (a Python dict) for everything else.
    ///
    /// Tables that end up empty are stored as `None` so `validate` can skip them.
    ///
    /// # Errors
    ///
    /// Returns a schema error when an accepted int/str key cannot be converted to
    /// `i64`/`str`, and propagates any error raised while inserting into the Python dict.
    pub fn new<'py>(py: Python<'py>, expected: impl Iterator<Item = (&'py PyAny, T)>) -> PyResult<Self> {
        let mut expected_int = AHashMap::new();
        let mut expected_str = AHashMap::new();
        let expected_py = PyDict::new(py);
        let mut values = Vec::new();
        for (k, v) in expected {
            // The index this value is about to occupy in `values`.
            let id = values.len();
            values.push(v);
            if let Ok(either_int) = k.exact_int() {
                let int = either_int
                    .into_i64(py)
                    .map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?;
                expected_int.insert(int, id);
            } else if let Ok(either_str) = k.exact_str() {
                let str = either_str
                    .as_cow()
                    .map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
                expected_str.insert(str.to_string(), id);
            } else {
                expected_py.set_item(k, id)?;
            }
        }

        Ok(Self {
            expected_int: (!expected_int.is_empty()).then_some(expected_int),
            expected_str: (!expected_str.is_empty()).then_some(expected_str),
            expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
            values,
        })
    }

    /// Look `input` up in the tables built by [`LiteralLookup::new`].
    ///
    /// The int table is consulted first, then the string table, then the Python dict.
    /// Returns `Ok(Some((input, value)))` for the first table containing the input and
    /// `Ok(None)` when none does.
    ///
    /// # Errors
    ///
    /// Propagates failures from converting an accepted int/str input (`into_i64`, `as_cow`).
    ///
    /// # Panics
    ///
    /// Panics if a value stored in `expected_py` is not a `usize`; `new` only ever stores
    /// `usize` indices there, so this would indicate a broken invariant.
    pub fn validate<'data, I: Input<'data>>(
        &self,
        py: Python<'data>,
        input: &'data I,
    ) -> ValResult<'data, Option<(&'data I, &T)>> {
        if let Some(expected_ints) = &self.expected_int {
            if let Ok(either_int) = input.exact_int() {
                let int = either_int.into_i64(py)?;
                if let Some(id) = expected_ints.get(&int) {
                    return Ok(Some((input, &self.values[*id])));
                }
            }
        }
        if let Some(expected_strings) = &self.expected_str {
            if let Ok(either_str) = input.exact_str() {
                let cow = either_str.as_cow()?;
                if let Some(id) = expected_strings.get(cow.as_ref()) {
                    return Ok(Some((input, &self.values[*id])));
                }
            }
        }
        // Catch-all table: keys that were neither exact ints nor exact strings
        // (e.g. enum members or bytes).
        if let Some(expected_py) = &self.expected_py {
            if let Some(v) = expected_py.as_ref(py).get_item(input) {
                let id: usize = v
                    .extract()
                    .expect("`expected_py` only holds `usize` indices inserted by `LiteralLookup::new`");
                return Ok(Some((input, &self.values[id])));
            }
        }
        Ok(None)
    }
}

/// Validator for `Literal[...]` values, backed by a `LiteralLookup`.
#[derive(Debug, Clone)]
pub struct LiteralValidator {
// Maps each expected literal to the Python object handed back when the input matches it.
lookup: LiteralLookup<PyObject>,
// Rendered description of the expected values; reported as `expected` in `LiteralError`.
expected_repr: String,
// Produced together with `expected_repr` by `expected_repr_name(.., "literal")`.
// NOTE(review): presumably the validator's display name - confirm against its use.
name: String,
}
Expand All @@ -41,32 +123,14 @@ impl BuildValidator for LiteralValidator {
return py_schema_err!("`expected` should have length > 0");
}
let py = expected.py();
// Literal[...] only supports int, str, bytes or enums, all of which can be hashed
let mut expected_int = AHashSet::new();
let mut expected_str = AHashSet::new();
let expected_py = PyDict::new(py);
let mut repr_args: Vec<String> = Vec::new();
for item in expected.iter() {
repr_args.push(item.repr()?.extract()?);
if let Ok(either_int) = item.strict_int() {
let int = either_int
.into_i64(py)
.map_err(|_| py_schema_error_type!("error extracting int {:?}", item))?;
expected_int.insert(int);
} else if let Ok(either_str) = item.strict_str() {
let str = either_str
.as_cow()
.map_err(|_| py_schema_error_type!("error extracting str {:?}", item))?;
expected_str.insert(str.to_string());
} else {
expected_py.set_item(item, item)?;
}
}
let (expected_repr, name) = expected_repr_name(repr_args, "literal");
let lookup = LiteralLookup::new(py, expected.iter().map(|v| (v, v.to_object(py))))?;
Ok(CombinedValidator::Literal(Self {
expected_int: (!expected_int.is_empty()).then_some(expected_int),
expected_str: (!expected_str.is_empty()).then_some(expected_str),
expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
lookup,
expected_repr,
name,
}))
Expand All @@ -82,34 +146,15 @@ impl Validator for LiteralValidator {
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
if let Some(expected_ints) = &self.expected_int {
if let Ok(either_int) = input.strict_int() {
let int = either_int.into_i64(py)?;
if expected_ints.contains(&int) {
return Ok(input.to_object(py));
}
}
match self.lookup.validate(py, input)? {
Some((_, v)) => Ok(v.clone()),
None => Err(ValError::new(
ErrorType::LiteralError {
expected: self.expected_repr.clone(),
},
input,
)),
}
if let Some(expected_strings) = &self.expected_str {
if let Ok(either_str) = input.strict_str() {
let cow = either_str.as_cow()?;
if expected_strings.contains(cow.as_ref()) {
return Ok(input.to_object(py));
}
}
}
// must be an enum or bytes
if let Some(expected_py) = &self.expected_py {
if let Some(v) = expected_py.as_ref(py).get_item(input) {
return Ok(v.into());
}
};
Err(ValError::new(
ErrorType::LiteralError {
expected: self.expected_repr.clone(),
},
input,
))
}

fn different_strict_behavior(
Expand Down

0 comments on commit 3b1a6a7

Please sign in to comment.