Support discriminated union #2336

Merged
merged 57 commits into from Dec 18, 2021

Conversation

PrettyWood
Member

@PrettyWood PrettyWood commented Feb 9, 2021

Change Summary

Add support for discriminated unions, including the way the OpenAPI spec describes them

Related issue number

closes #619
closes #3113

Checklist

  • Unit tests for the changes exist
  • Tests pass on CI and coverage remains at 100%
  • Documentation reflects the changes where applicable
  • changes/<pull request or issue id>-<github username>.md file added describing change
    (see changes/README.md for details)

@codecov

codecov bot commented Feb 9, 2021

Codecov Report

Merging #2336 (c015f0f) into master (7b7e705) will not change coverage.
The diff coverage is 100.00%.

❗ Current head c015f0f differs from pull request most recent head b974edb. Consider uploading reports for the commit b974edb to get more accurate results

@@            Coverage Diff            @@
##            master     #2336   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           25        25           
  Lines         5109      5179   +70     
  Branches      1050      1065   +15     
=========================================
+ Hits          5109      5179   +70     
Impacted Files Coverage Δ
pydantic/__init__.py 100.00% <ø> (ø)
pydantic/fields.py 100.00% <100.00%> (ø)
pydantic/schema.py 100.00% <100.00%> (ø)
pydantic/tools.py 100.00% <100.00%> (ø)
pydantic/typing.py 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update afcd155...b974edb.

@PrettyWood PrettyWood force-pushed the f/discriminated-union branch 2 times, most recently from 78248a1 to 876b2ab Compare February 9, 2021 23:26
@PrettyWood PrettyWood changed the title [wip] Support discriminated union Support discriminated union Feb 9, 2021
@PrettyWood PrettyWood force-pushed the f/discriminated-union branch 2 times, most recently from fc200f6 to 74d4355 Compare February 11, 2021 00:11
@levrik

levrik commented Feb 12, 2021

OpenAPI 3.x doesn't actually support the const keyword. An enum with a single option would be correct here.
JSON Schema does support const.
I was playing around with this a bit and kept hitting format errors on the const keyword,
but I can also see that this doesn't come from this PR specifically.
Even the JSON Schema spec describes const as a shortcut for an enum with a single option (https://json-schema.org/draft/2019-09/json-schema-validation.html#rfc.section.6.1.3)
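
To illustrate the equivalence described in this comment, here is a minimal sketch of rewriting `const` into a single-value `enum` in an already generated schema dict. The helper name `const_to_enum` is hypothetical and is not part of pydantic. The rewrite is deliberately naive: it would also rewrite a model property that happens to be named `const`.

```python
from typing import Any


def const_to_enum(schema: Any) -> Any:
    """Return a copy of `schema` with every `const: x` replaced by `enum: [x]`."""
    if isinstance(schema, list):
        return [const_to_enum(item) for item in schema]
    if not isinstance(schema, dict):
        return schema
    converted = {}
    for key, value in schema.items():
        if key == "const":
            # A single-option enum accepts exactly the same values as const.
            converted["enum"] = [value]
        else:
            converted[key] = const_to_enum(value)
    return converted


literal_field = {"title": "Type", "const": "Group", "type": "string"}
print(const_to_enum(literal_field))
# {'title': 'Type', 'enum': ['Group'], 'type': 'string'}
```

The input dict is left untouched, so the same schema could be emitted in both flavours.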

@PrettyWood
Member Author

@levrik That's a separate issue, already opened, regarding proper support of Literal.
I made a separate PR, #2348, for it.

@levrik

levrik commented Feb 12, 2021

@PrettyWood That's great news!

@amagee

amagee commented Feb 16, 2021

I noticed this does not work if the types I am unioning over are forward references.

I.e., this works:

from pydantic import BaseModel, Field
from typing import Literal, Union

class Group(BaseModel):
    type: Literal["Group"]

class Layer(BaseModel):
    type: Literal["Layer"]

class GroupOrLayer(BaseModel):
    __root__: Union["Group", "Layer"] = Field(..., discriminator="type")

but if I move the Group and Layer classes beneath GroupOrLayer, I get this:

Traceback (most recent call last):
  File "mcve.py", line 4, in <module>
    class GroupOrLayer(BaseModel):
  File "/path/pydantic/main.py", line 291, in __new__
    fields[ann_name] = ModelField.infer(
  File "/path/pydantic/fields.py", line 390, in infer
    return cls(
  File "/path/pydantic/fields.py", line 325, in __init__
    self.prepare()
  File "/path/pydantic/fields.py", line 430, in prepare
    self._type_analysis()
  File "/path/pydantic/fields.py", line 539, in _type_analysis
    t_discriminator_type = t.__fields__[discriminator_key].outer_type_
AttributeError: 'ForwardRef' object has no attribute '__fields__'

(Just on initialisation of the class; no need to actually do anything with it.)
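
The traceback can be reproduced in spirit with the standard library alone. The sketch below uses plain classes as stand-ins for the models (an assumption, with `root` standing in for `__root__`) and shows that the union members are `ForwardRef` placeholders until the names are resolved, which is why nothing model-specific can be read from them at class-creation time.

```python
from typing import ForwardRef, Union, get_args, get_type_hints


class GroupOrLayer:
    # Group and Layer are not defined yet, so they are referenced by name.
    root: Union["Group", "Layer"]


# Before resolution the union members are ForwardRef placeholders, not classes.
pending = get_args(GroupOrLayer.__annotations__["root"])
all_pending = all(isinstance(member, ForwardRef) for member in pending)


class Group:
    type = "Group"


class Layer:
    type = "Layer"


# Once the classes exist, the names can be resolved to the real classes.
hints = get_type_hints(GroupOrLayer, localns={"Group": Group, "Layer": Layer})
resolved = get_args(hints["root"])
print(all_pending, [cls.__name__ for cls in resolved])
```

Any analysis of the union members therefore has to be deferred until the references are resolved.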

@PrettyWood
Member Author

Thanks @amagee. I had to rewrite the logic a bit to properly support ForwardRef.
Feedback welcome

@vdwees
Contributor

vdwees commented Feb 22, 2021

I gave this branch a try. Definitely a step in the right direction. The script below demonstrates a common situation I encounter, where there is more than one discriminating field. Do you think this is worth supporting in a discriminated union? Would it even be compatible with OpenAPI?

from typing import Literal, Union
from pydantic import BaseModel, Field


class DomainA(BaseModel):
    domain: Literal["A"]


class DomainB(BaseModel):
    domain: Literal["B"]


class FooDomainA(DomainA):
    identifier: Literal["foo"]


class BarDomainA(DomainA):
    identifier: Literal["bar"]


class FooDomainB(DomainB):
    identifier: Literal["foo"]


class BarDomainB(DomainB):
    identifier: Literal["bar"]


class MyClass(BaseModel):
    __root__: Union[FooDomainA, BarDomainA, FooDomainB, BarDomainB] = Field(
        ..., discriminator="domain"
    )


for domain in ("A", "B"):
    for identifier in ("foo", "bar"):
        print(domain, identifier)
        try:
            print(MyClass.parse_obj({"domain": domain, "identifier": identifier}))
        except ValueError as e:
            print(e)
        print()

Output:

A foo
1 validation error for MyClass
__root__ -> identifier
  unexpected value; permitted: 'bar' (type=value_error.const; given=foo; permitted=('bar',))

A bar
__root__=BarDomainA(domain='A', identifier='bar')

B foo
1 validation error for MyClass
__root__ -> identifier
  unexpected value; permitted: 'bar' (type=value_error.const; given=foo; permitted=('bar',))

B bar
__root__=BarDomainB(domain='B', identifier='bar')
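
One plausible reading of this output: if the lookup is keyed on the `domain` value alone, two members sharing a value cannot both be kept. The sketch below is an assumption about that behaviour, not the branch's actual code. `members` and `by_domain` are illustrative names.

```python
# (model name, literal value of each field) for the four union members above
members = [
    ("FooDomainA", {"domain": "A", "identifier": "foo"}),
    ("BarDomainA", {"domain": "A", "identifier": "bar"}),
    ("FooDomainB", {"domain": "B", "identifier": "foo"}),
    ("BarDomainB", {"domain": "B", "identifier": "bar"}),
]

# Hypothetical lookup table keyed by the discriminator value only.
by_domain = {}
for name, literals in members:
    # A later member with the same value replaces an earlier one.
    by_domain[literals["domain"]] = name

print(by_domain)  # {'A': 'BarDomainA', 'B': 'BarDomainB'}
```

Under that assumption, any payload with `identifier="foo"` is validated against a `Bar…` model and fails exactly as shown.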

@PrettyWood
Member Author

PrettyWood commented Feb 22, 2021

@vdwees Thanks a lot for trying it out!
What would the configuration be? An array of discriminators?
And what about the generated json schema?

@vdwees
Contributor

vdwees commented Feb 23, 2021

@PrettyWood I looked into it a bit more and, based on my reading of the OpenAPI schema, I think the discriminator can only be a single field. So as much as I would like the discriminator kwarg to accept a sequence in addition to a single value, I guess it wouldn't make sense to support that here 😢

But I was thinking: even if it is only a single value, it might still be reasonable/possible to support similar functionality if the discriminator were nestable, e.g. something like the following:

class FooDomainA(BaseModel):
    domain: Literal["A"]
    identifier: Literal["foo"]

class BarDomainA(BaseModel):
    domain: Literal["A"]
    identifier: Literal["bar"]

class FooDomainB(BaseModel):
    domain: Literal["B"]
    identifier: Literal["foo"]

class BarDomainB(BaseModel):
    domain: Literal["B"]
    identifier: Literal["bar"]

class DomainA(BaseModel):
    __root__: Union[FooDomainA, BarDomainA] = Field(..., discriminator="identifier")

class DomainB(BaseModel):
    __root__: Union[FooDomainB, BarDomainB] = Field(..., discriminator="identifier")

class PolymorphClass(BaseModel):
    __root__: Union[DomainA, DomainB] = Field(..., discriminator="domain")
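
The nesting idea can be sketched without pydantic as a chain of lookups, one per discriminator. This is only an illustration of the selection logic under assumed names (`select_model`, `POLYMORPH`). It returns model names rather than validating anything.

```python
from typing import Any, Dict

# Each node is (discriminator field, table). A table value is either a leaf
# model name or another node that discriminates further.
DOMAIN_A = ("identifier", {"foo": "FooDomainA", "bar": "BarDomainA"})
DOMAIN_B = ("identifier", {"foo": "FooDomainB", "bar": "BarDomainB"})
POLYMORPH = ("domain", {"A": DOMAIN_A, "B": DOMAIN_B})


def select_model(node: Any, data: Dict[str, Any]) -> str:
    """Follow discriminators from `node` down to a leaf model name."""
    while isinstance(node, tuple):
        field, table = node
        if field not in data:
            raise KeyError(f"missing discriminator field {field!r}")
        value = data[field]
        if value not in table:
            raise ValueError(f"no match for {field}={value!r}; allowed: {sorted(table)}")
        node = table[value]
    return node


print(select_model(POLYMORPH, {"domain": "A", "identifier": "foo"}))  # FooDomainA
```

Each level only needs a single discriminating field, which matches the single-`propertyName` constraint mentioned above.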

@vdwees
Contributor

vdwees commented Feb 23, 2021

Actually, it seems valid from the OpenAPI perspective to do it this way, based on the specs I found here

Example OpenAPI schemas

components:
  schemas:
    PolymorphClass:
      title: PolymorphClass
      discriminator:
        propertyName: domain
        mapping:
          A: '#/components/schemas/DomainA'
          B: '#/components/schemas/DomainB'
      anyOf:
        - $ref: '#/components/schemas/DomainA'
        - $ref: '#/components/schemas/DomainB'
    DomainA:
      title: DomainA
      discriminator:
        propertyName: identifier
        mapping:
          foo: '#/components/schemas/FooDomainA'
          bar: '#/components/schemas/BarDomainA'
      anyOf:
        - $ref: '#/components/schemas/FooDomainA'
        - $ref: '#/components/schemas/BarDomainA'
    DomainB:
      title: DomainB
      discriminator:
        propertyName: identifier
        mapping:
          foo: '#/components/schemas/FooDomainB'
          bar: '#/components/schemas/BarDomainB'
      anyOf:
        - $ref: '#/components/schemas/FooDomainB'
        - $ref: '#/components/schemas/BarDomainB'
    FooDomainA:
      title: FooDomainA
      type: object
      properties:
        domain:
          title: Domain
          enum:
            - A
          type: string
        identifier:
          title: Identifier
          enum:
            - foo
          type: string
      required:
        - domain
        - identifier
    BarDomainA:
      title: BarDomainA
      type: object
      properties:
        domain:
          title: Domain
          enum:
            - A
          type: string
        identifier:
          title: Identifier
          enum:
            - bar
          type: string
      required:
        - domain
        - identifier
    FooDomainB:
      title: FooDomainB
      type: object
      properties:
        domain:
          title: Domain
          enum:
            - B
          type: string
        identifier:
          title: Identifier
          enum:
            - foo
          type: string
      required:
        - domain
        - identifier
    BarDomainB:
      title: BarDomainB
      type: object
      properties:
        domain:
          title: Domain
          enum:
            - B
          type: string
        identifier:
          title: Identifier
          enum:
            - bar
          type: string
      required:
        - domain
        - identifier


Example PolymorphClass.schema()

{
  "title": "PolymorphClass",
  "discriminator": {
    "propertyName": "domain",
    "mapping": {
      "A": "#/definitions/DomainA",
      "B": "#/definitions/DomainB"
    }
  },
  "anyOf": [
    {
      "$ref": "#/definitions/DomainA"
    },
    {
      "$ref": "#/definitions/DomainB"
    }
  ],
  "definitions": {
    "DomainA": {
      "title": "DomainA",
      "discriminator": {
        "propertyName": "identifier",
        "mapping": {
          "foo": "#/definitions/FooDomainA",
          "bar": "#/definitions/BarDomainA"
        }
      },
      "anyOf": [
        {
          "$ref": "#/definitions/FooDomainA"
        },
        {
          "$ref": "#/definitions/BarDomainA"
        }
      ]
    },
    "DomainB": {
      "title": "DomainB",
      "discriminator": {
        "propertyName": "identifier",
        "mapping": {
          "foo": "#/definitions/FooDomainB",
          "bar": "#/definitions/BarDomainB"
        }
      },
      "anyOf": [
        {
          "$ref": "#/definitions/FooDomainB"
        },
        {
          "$ref": "#/definitions/BarDomainB"
        }
      ]
    },
    "FooDomainA": {
      "title": "FooDomainA",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "const": "A",
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "const": "foo",
          "type": "string"
        }
      },
      "required": ["domain", "identifier"]
    },
    "BarDomainA": {
      "title": "BarDomainA",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "const": "A",
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "const": "bar",
          "type": "string"
        }
      },
      "required": ["domain", "identifier"]
    },
    "FooDomainB": {
      "title": "FooDomainB",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "const": "B",
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "const": "foo",
          "type": "string"
        }
      },
      "required": ["domain", "identifier"]
    },
    "BarDomainB": {
      "title": "BarDomainB",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "const": "B",
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "const": "bar",
          "type": "string"
        }
      },
      "required": ["domain", "identifier"]
    }
  }
}

The pydantic API should also support this part of the spec (the explicit mapping), maybe using a tuple like this:

class PolymorphClass(BaseModel):
    __root__: Union[DomainA, DomainB] = Field(
        ..., discriminator=("domain", {"A": DomainA, "B": DomainB})
    )
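
As a rough sketch of how such a tuple could translate into the `discriminator` object shown in the schemas above: the helper `discriminator_schema` and its `ref_prefix` parameter are hypothetical, and the two classes are empty stand-ins for the real models.

```python
from typing import Any, Dict, Tuple


class DomainA:  # stand-in for the pydantic model of the same name
    pass


class DomainB:  # stand-in for the pydantic model of the same name
    pass


def discriminator_schema(
    spec: Tuple[str, Dict[str, type]], ref_prefix: str = "#/definitions/"
) -> Dict[str, Any]:
    """Turn (property name, {value: model}) into a discriminator object."""
    property_name, models = spec
    return {
        "propertyName": property_name,
        "mapping": {
            value: f"{ref_prefix}{model.__name__}" for value, model in models.items()
        },
    }


spec = ("domain", {"A": DomainA, "B": DomainB})
print(discriminator_schema(spec))
print(discriminator_schema(spec, ref_prefix="#/components/schemas/"))
```

The two prefixes correspond to the JSON Schema and OpenAPI variants of the examples above.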

@PrettyWood PrettyWood force-pushed the f/discriminated-union branch 3 times, most recently from 98b268a to 3398459 Compare February 23, 2021 23:55
@samuelcolvin
Member

Just to say thank you so much for working on this.

Giving this a proper review is weighing heavily on my conscience, but I can't devote the time before v1.8.

@PrettyWood
Member Author

PrettyWood commented Feb 25, 2021

@samuelcolvin No worries! I want to improve it further to support nested discriminated unions. As I said in the 1.8 discussion, it's better to wait for v1.9 for the union PRs!

Contributor

@antdking antdking left a comment

This is looking good.
I've put in some cleanup suggestions to reduce some duplication, and simplify calling update_forward_refs (tested locally).

I think an assumption has been made that will cause some headaches though.
It's possible to have duplicated discriminator values.

class Spaniel(BaseModel):
  pet_type: Literal['dog']
  breed: Literal['spaniel']

class JackRussel(BaseModel):
  pet_type: Literal['dog']
  breed: Literal['Jack Russel']

class Cat(BaseModel):
  pet_type: Literal['cat']

class Model(BaseModel):
  pet: Union[Spaniel, JackRussel, Cat] = Field(discriminator='pet_type')

I don't think this is an edge case either; we have something to this effect running in production.

In this case, I think the code should be defensive:

  • Don't generate the mapping in the schema (or exclude mappings that are duplicates)
  • Have sub_field_mappings be defined as Dict[str, List[ModelField]], so it can be used to narrow down validation
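
Both suggestions can be sketched with plain data. Model names stand in for `ModelField` objects here, and `members`, `candidates` and `schema_mapping` are illustrative names, not identifiers from the PR.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical input: (model name, literal value of the discriminator field)
members = [("Spaniel", "dog"), ("JackRussel", "dog"), ("Cat", "cat")]

# Keep every candidate per discriminator value instead of a single one,
# so validation can be narrowed to that short list.
candidates: Dict[str, List[str]] = defaultdict(list)
for model_name, value in members:
    candidates[value].append(model_name)

# Only unambiguous values are published in the schema mapping.
schema_mapping = {
    value: f"#/definitions/{names[0]}"
    for value, names in candidates.items()
    if len(names) == 1
}

print(dict(candidates))  # {'dog': ['Spaniel', 'JackRussel'], 'cat': ['Cat']}
print(schema_mapping)    # {'cat': '#/definitions/Cat'}
```

With this shape, a payload with `pet_type="dog"` would be tried against `Spaniel` and `JackRussel` only, never against `Cat`.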

@antdking
Contributor

While looking into nested discriminators (as a true fix for the above), you end up with a schema akin to this:

definitions:
  Pet:
    anyOf:
      - $ref: '#/definitions/Cat'
      - $ref: '#/definitions/Dog'
    discriminator:
      propertyName: pet_type
      mapping:
        cat: '#/definitions/Cat'
        dog: '#/definitions/Dog'
  Cat:
    properties:
      pet_type:
        type: string
        const: cat
  Dog:
    anyOf:
      - $ref: '#/definitions/Spaniel'
      - $ref: '#/definitions/JackRussel'
    discriminator:
      propertyName: breed
      mapping:
        spaniel: '#/definitions/Spaniel'
        jack-russel: '#/definitions/JackRussel'
  Spaniel:
    properties:
      pet_type:
        type: string
        const: dog
      breed:
        type: string
        const: spaniel
      s:
        type: string
  JackRussel:
    properties:
      pet_type:
        type: string
        const: dog
      breed:
        type: string
        const: jack-russel
      j:
        type: string

Note how we have Dog as an intermediary object, wrapping our breeds.

To express this, we need to define Dog as:

class Dog(BaseModel):
  __root__: Union[Spaniel, JackRussel] = Field(..., discriminator='breed')

class Model(BaseModel):
  pet: Union[Cat, Dog] = Field(..., discriminator='pet_type')
  number: int

However, this isn't supported by the current code, which raises a KeyError:
KeyError: "Model 'Dog' needs a discriminator field for key 'pet_type'"


An example of using in Redoc: https://redocly.github.io/redoc/?url=https://gist.githubusercontent.com/cybojenix/5258ebe39cd100e99be76e5175f9b41f/raw/6b1708c71cad50ef00fda6bbf1d35ea439d6e2fa/nested-discriminator.yml

This actually breaks the interface when selecting dog as the pet_type, though it looks like others are looking for similar functionality.

@PrettyWood
Member Author

Thanks a lot @cybojenix for the review 🙏 It was still a work in progress, but I'm glad to have some constructive feedback on the changes we can make. I'll probably have a look at it next week!

@PrettyWood
Member Author

PrettyWood commented Mar 14, 2021

Thanks @vdwees and @cybojenix for the examples.
I made the changes and started working on nested unions.
Currently I have this:

from typing import Literal, Union

from pydantic import BaseModel, Field


class FooDomainAA(BaseModel):
    domain: Literal["A"]
    identifier: Literal["foo"]
    type: Literal["a"]

class FooDomainAB(BaseModel):
    domain: Literal["A"]
    identifier: Literal["foo"]
    type: Literal["b"]

class BarDomainA(BaseModel):
    domain: Literal["A"]
    identifier: Literal["bar"]

class FooDomainB(BaseModel):
    domain: Literal["B"]
    identifier: Literal["foo"]

class BarDomainB(BaseModel):
    domain: Literal["B"]
    identifier: Literal["bar"]

class FooDomainA(BaseModel):
    __root__: Union[FooDomainAA, FooDomainAB] = Field(..., discriminator="type")

class DomainA(BaseModel):
    __root__: Union[FooDomainA, BarDomainA] = Field(..., discriminator="identifier")

class DomainB(BaseModel):
    __root__: Union[FooDomainB, BarDomainB] = Field(..., discriminator="identifier")

class PolymorphClass(BaseModel):
    __root__: Union[DomainA, DomainB] = Field(..., discriminator="domain")

assert PolymorphClass.schema() == {
  "title": "PolymorphClass",
  "discriminator": {
    "propertyName": "domain",
    "mapping": {
      "A": "#/definitions/DomainA",
      "B": "#/definitions/DomainB"
    }
  },
  "anyOf": [
    {
      "$ref": "#/definitions/DomainA"
    },
    {
      "$ref": "#/definitions/DomainB"
    }
  ],
  "definitions": {
    "FooDomainAA": {
      "title": "FooDomainAA",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "enum": [
            "A"
          ],
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "enum": [
            "foo"
          ],
          "type": "string"
        },
        "type": {
          "title": "Type",
          "enum": [
            "a"
          ],
          "type": "string"
        }
      },
      "required": [
        "domain",
        "identifier",
        "type"
      ]
    },
    "FooDomainAB": {
      "title": "FooDomainAB",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "enum": [
            "A"
          ],
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "enum": [
            "foo"
          ],
          "type": "string"
        },
        "type": {
          "title": "Type",
          "enum": [
            "b"
          ],
          "type": "string"
        }
      },
      "required": [
        "domain",
        "identifier",
        "type"
      ]
    },
    "FooDomainA": {
      "title": "FooDomainA",
      "discriminator": {
        "propertyName": "type",
        "mapping": {
          "a": "#/definitions/FooDomainAA",
          "b": "#/definitions/FooDomainAB"
        }
      },
      "anyOf": [
        {
          "$ref": "#/definitions/FooDomainAA"
        },
        {
          "$ref": "#/definitions/FooDomainAB"
        }
      ]
    },
    "BarDomainA": {
      "title": "BarDomainA",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "enum": [
            "A"
          ],
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "enum": [
            "bar"
          ],
          "type": "string"
        }
      },
      "required": [
        "domain",
        "identifier"
      ]
    },
    "DomainA": {
      "title": "DomainA",
      "discriminator": {
        "propertyName": "identifier",
        "mapping": {
          "foo": "#/definitions/FooDomainA",
          "bar": "#/definitions/BarDomainA"
        }
      },
      "anyOf": [
        {
          "$ref": "#/definitions/FooDomainA"
        },
        {
          "$ref": "#/definitions/BarDomainA"
        }
      ]
    },
    "FooDomainB": {
      "title": "FooDomainB",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "enum": [
            "B"
          ],
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "enum": [
            "foo"
          ],
          "type": "string"
        }
      },
      "required": [
        "domain",
        "identifier"
      ]
    },
    "BarDomainB": {
      "title": "BarDomainB",
      "type": "object",
      "properties": {
        "domain": {
          "title": "Domain",
          "enum": [
            "B"
          ],
          "type": "string"
        },
        "identifier": {
          "title": "Identifier",
          "enum": [
            "bar"
          ],
          "type": "string"
        }
      },
      "required": [
        "domain",
        "identifier"
      ]
    },
    "DomainB": {
      "title": "DomainB",
      "discriminator": {
        "propertyName": "identifier",
        "mapping": {
          "foo": "#/definitions/FooDomainB",
          "bar": "#/definitions/BarDomainB"
        }
      },
      "anyOf": [
        {
          "$ref": "#/definitions/FooDomainB"
        },
        {
          "$ref": "#/definitions/BarDomainB"
        }
      ]
    }
  }
}

It looks good to me, but maybe I'm mistaken.
Feedback is more than welcome again :)
I'll try to take some time tonight to come back to it and add tests and extra documentation.
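
For anyone wanting to sanity-check such a nested schema by hand, here is a small sketch that walks the `discriminator` mappings down to a leaf definition. The function `resolve_definition` and the helper `union` are hypothetical. The schema is a trimmed copy of the one above, with leaf models left empty on purpose.

```python
from typing import Any, Dict


def resolve_definition(schema: Dict[str, Any], data: Dict[str, Any]) -> str:
    """Follow `discriminator` mappings until a definition without one is reached."""
    definitions = schema["definitions"]
    node, name = schema, schema["title"]
    while "discriminator" in node:
        discriminator = node["discriminator"]
        value = data[discriminator["propertyName"]]
        ref = discriminator["mapping"][value]  # e.g. "#/definitions/DomainA"
        name = ref.rsplit("/", 1)[-1]
        node = definitions[name]
    return name


def union(property_name: str, **mapping: str) -> Dict[str, Any]:
    """Build a trimmed union definition that keeps only the discriminator."""
    return {
        "discriminator": {
            "propertyName": property_name,
            "mapping": {k: f"#/definitions/{v}" for k, v in mapping.items()},
        }
    }


schema = {
    "title": "PolymorphClass",
    **union("domain", A="DomainA", B="DomainB"),
    "definitions": {
        "DomainA": union("identifier", foo="FooDomainA", bar="BarDomainA"),
        "DomainB": union("identifier", foo="FooDomainB", bar="BarDomainB"),
        "FooDomainA": union("type", a="FooDomainAA", b="FooDomainAB"),
        "FooDomainAA": {}, "FooDomainAB": {}, "BarDomainA": {},
        "FooDomainB": {}, "BarDomainB": {},
    },
}

payload = {"domain": "A", "identifier": "foo", "type": "a"}
print(resolve_definition(schema, payload))  # FooDomainAA
```

Three lookups are needed for this payload, one per nesting level, while a `domain="B"` payload resolves in two.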

@ryukinix

ryukinix commented Dec 7, 2021

Is there an estimated date for the next pydantic release that includes this feature?

@PrettyWood
Member Author

@ryukinix Probably before the end of the year. I'll come back to this PR ASAP.

@PrettyWood
Member Author

PrettyWood commented Dec 11, 2021

@samuelcolvin I took most of your remarks into consideration. I also added alias support (see #2336 (comment)).
TBH I'm not super happy with the implementation, which doesn't feel very robust... But handling all the different cases (root models, annotated unions, same or different discriminator...) is quite annoying with the current codebase.
Feel free to modify the PR directly. I'll try to take some time on the two other big PRs (dataclass and computed fields) tonight or tomorrow.

@PrettyWood
Member Author

Please review

@samuelcolvin samuelcolvin merged commit c834f34 into pydantic:master Dec 18, 2021
@samuelcolvin
Member

🎉

🙏

@foarsitter

Thanks for all your hard work folks!
