Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Union Discriminator functioanlity with Pydantic 1.10+ #19

Merged
merged 2 commits into from Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 22 additions & 12 deletions examples/union_field.py
Expand Up @@ -22,19 +22,29 @@ class ContactMethod(BaseModel):
text: str


st.header("Form inputs from model")
input_data = sp.pydantic_input(key="union_input", model=ContactMethod)
if input_data:
st.json(input_data)
from_model_tab, from_instance_tab = st.tabs(
["Form inputs from model", "Form inputs from instance"]
)

with from_model_tab:
input_data = sp.pydantic_input(key="union_input", model=ContactMethod)
if input_data:
st.json(input_data)


with from_instance_tab:
instance = ContactMethod(
contact=EmailAddress(email="instance@example.com", send_news=True),
text="instance text",
)

instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance)

if instance_input_data:
st.json(instance_input_data)

st.header("Form inputs from instance")
instance = ContactMethod(
contact=EmailAddress(email="instance@example.com", send_news=True),
text="instance text",
)

instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance)
st.markdown("---")

if instance_input_data:
st.json(instance_input_data)
with st.expander("Session State", expanded=False):
st.write(st.session_state)
39 changes: 24 additions & 15 deletions examples/union_field_discriminator.py
Expand Up @@ -26,21 +26,30 @@ class ContactMethod(BaseModel):
text: str


st.header("Form inputs from model")
input_data = sp.pydantic_input(key="union_input", model=ContactMethod)
if input_data:
st.json(input_data)


st.header("Form inputs from instance")
instance = ContactMethod(
contact=EmailAddress(
contact_type="email", email="instance@example.com", send_news=True
),
text="instance text",
from_model_tab, from_instance_tab = st.tabs(
["Form inputs from model", "Form inputs from instance"]
)

instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance)
with from_model_tab:
input_data = sp.pydantic_input(key="union_input", model=ContactMethod)
if input_data:
st.json(input_data)

if instance_input_data:
st.json(instance_input_data)

with from_instance_tab:
instance = ContactMethod(
contact=EmailAddress(
contact_type="email", email="instance@example.com", send_news=True
),
text="instance text",
)

instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance)

if instance_input_data:
st.json(instance_input_data)

st.markdown("---")

with st.expander("Session State", expanded=False):
st.write(st.session_state)
12 changes: 8 additions & 4 deletions src/streamlit_pydantic/schema_utils.py
Expand Up @@ -20,7 +20,8 @@ def get_single_reference_item(property: Dict, references: Dict) -> Dict:

def get_union_references(property: Dict, references: Dict) -> List[Dict]:
# Ref can either be directly in the properties or the first element of allOf
union_references = property.get("anyOf")
# anyOf is used for union property prior to pydantic < 1.10
union_references = property.get("oneOf", property.get("anyOf"))
resolved_references: List[Dict] = []
for reference in union_references: # type: ignore
resolved_references.append(resolve_reference(reference["$ref"], references))
Expand Down Expand Up @@ -124,13 +125,16 @@ def is_single_object(property: Dict, references: Dict) -> bool:


def is_union_property(property: Dict) -> bool:
if property.get("anyOf") is None:
# anyOf is used for union property prior to pydantic < 1.10
union_prop = property.get("oneOf", property.get("anyOf"))

if union_prop is None:
return False

if len(property.get("anyOf")) == 0: # type: ignore
if len(union_prop) == 0: # type: ignore
return False

for reference in property.get("anyOf"): # type: ignore
for reference in union_prop: # type: ignore
if not is_single_reference(reference):
return False

Expand Down