diff --git a/examples/union_field.py b/examples/union_field.py index bf14d3e..8c8c41b 100644 --- a/examples/union_field.py +++ b/examples/union_field.py @@ -22,19 +22,29 @@ class ContactMethod(BaseModel): text: str -st.header("Form inputs from model") -input_data = sp.pydantic_input(key="union_input", model=ContactMethod) -if input_data: - st.json(input_data) +from_model_tab, from_instance_tab = st.tabs( + ["Form inputs from model", "Form inputs from instance"] +) +with from_model_tab: + input_data = sp.pydantic_input(key="union_input", model=ContactMethod) + if input_data: + st.json(input_data) + + +with from_instance_tab: + instance = ContactMethod( + contact=EmailAddress(email="instance@example.com", send_news=True), + text="instance text", + ) + + instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance) + + if instance_input_data: + st.json(instance_input_data) -st.header("Form inputs from instance") -instance = ContactMethod( - contact=EmailAddress(email="instance@example.com", send_news=True), - text="instance text", -) -instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance) +st.markdown("---") -if instance_input_data: - st.json(instance_input_data) +with st.expander("Session State", expanded=False): + st.write(st.session_state) diff --git a/examples/union_field_discriminator.py b/examples/union_field_discriminator.py index 4297b52..340ce11 100644 --- a/examples/union_field_discriminator.py +++ b/examples/union_field_discriminator.py @@ -26,21 +26,30 @@ class ContactMethod(BaseModel): text: str -st.header("Form inputs from model") -input_data = sp.pydantic_input(key="union_input", model=ContactMethod) -if input_data: - st.json(input_data) - - -st.header("Form inputs from instance") -instance = ContactMethod( - contact=EmailAddress( - contact_type="email", email="instance@example.com", send_news=True - ), - text="instance text", +from_model_tab, from_instance_tab = st.tabs( + ["Form inputs from model", "Form inputs from instance"] ) -instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance) +with from_model_tab: + input_data = sp.pydantic_input(key="union_input", model=ContactMethod) + if input_data: + st.json(input_data) -if instance_input_data: - st.json(instance_input_data) + +with from_instance_tab: + instance = ContactMethod( + contact=EmailAddress( + contact_type="email", email="instance@example.com", send_news=True + ), + text="instance text", + ) + + instance_input_data = sp.pydantic_input(key="union_input_instance", model=instance) + + if instance_input_data: + st.json(instance_input_data) + +st.markdown("---") + +with st.expander("Session State", expanded=False): + st.write(st.session_state) diff --git a/src/streamlit_pydantic/schema_utils.py b/src/streamlit_pydantic/schema_utils.py index 3deeee4..baf280f 100644 --- a/src/streamlit_pydantic/schema_utils.py +++ b/src/streamlit_pydantic/schema_utils.py @@ -20,7 +20,8 @@ def get_single_reference_item(property: Dict, references: Dict) -> Dict: def get_union_references(property: Dict, references: Dict) -> List[Dict]: # Ref can either be directly in the properties or the first element of allOf - union_references = property.get("anyOf") + # anyOf is used for union property prior to pydantic < 1.10 + union_references = property.get("oneOf", property.get("anyOf")) resolved_references: List[Dict] = [] for reference in union_references: # type: ignore resolved_references.append(resolve_reference(reference["$ref"], references)) @@ -124,13 +125,16 @@ def is_single_object(property: Dict, references: Dict) -> bool: def is_union_property(property: Dict) -> bool: - if property.get("anyOf") is None: + # anyOf is used for union property prior to pydantic < 1.10 + union_prop = property.get("oneOf", property.get("anyOf")) + + if union_prop is None: return False - if len(property.get("anyOf")) == 0: # type: ignore + if len(union_prop) == 0: # type: ignore return False - for reference in property.get("anyOf"): # type: ignore + for reference in union_prop: # type: ignore if not is_single_reference(reference): return False