Skip to content

Commit

Permalink
🎬 : Improved file uploader and camera input to call its on_change han…
Browse files Browse the repository at this point in the history
…dler only when necessary (#4270)

* Add __eq__ method to UploadedFile, to be able to compare different UploadedFile instances with same FileRec.
We need that change to prevent `on_change` call for file_uploader on rerun.
  • Loading branch information
kajarenc committed Jan 25, 2022
1 parent 2c153aa commit ef14d2f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 17 deletions.
5 changes: 5 additions & 0 deletions lib/streamlit/uploaded_file_manager.py
Expand Up @@ -50,6 +50,11 @@ def __init__(self, record: UploadedFileRec):
self.type = record.type
self.size = len(record.data)

def __eq__(self, other: object) -> bool:
if not isinstance(other, UploadedFile):
return NotImplemented
return self.id == other.id

def __repr__(self) -> str:
return util.repr_(self)

Expand Down
2 changes: 1 addition & 1 deletion lib/tests/streamlit/file_uploader_test.py
Expand Up @@ -123,7 +123,7 @@ def test_unique_uploaded_file_instance(self, get_file_recs_patch):
file1: UploadedFile = st.file_uploader("a", accept_multiple_files=False)
file2: UploadedFile = st.file_uploader("b", accept_multiple_files=False)

self.assertNotEqual(file1, file2)
self.assertNotEqual(id(file1), id(file2))

# Seeking in one instance should not impact the position in the other.
file1.seek(2)
Expand Down
17 changes: 1 addition & 16 deletions lib/tests/streamlit/state/session_state_test.py
Expand Up @@ -305,22 +305,7 @@ def test_file_uploader_serde(self, get_file_recs_patch):
get_file_recs_patch.return_value = file_recs

uploaded_file = st.file_uploader("file_uploader", key="file_uploader")

# We can't use check_roundtrip here as the return_value of a
# file_uploader widget isn't a primitive value, so comparing them
# using == checks for reference equality.
session_state = get_session_state()
metadata = session_state.get_metadata_by_key("file_uploader")
serializer = metadata.serializer
deserializer = metadata.deserializer

file_after_serde = deserializer(serializer(uploaded_file), "")

assert uploaded_file.id == file_after_serde.id
assert uploaded_file.name == file_after_serde.name
assert uploaded_file.type == file_after_serde.type
assert uploaded_file.size == file_after_serde.size
assert uploaded_file.read() == file_after_serde.read()
check_roundtrip("file_uploader", uploaded_file)

def test_multiselect_serde(self):
multiselect = st.multiselect(
Expand Down

0 comments on commit ef14d2f

Please sign in to comment.