/
image.py
425 lines (358 loc) · 14.7 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Some casts in this file are only occasionally necessary depending on the
# user's Python version, and mypy doesn't have a good way of toggling this
# specific config option at a per-line level.
# mypy: no-warn-unused-ignores
"""Image marshalling."""
import imghdr
import io
import mimetypes
from typing import cast, List, Optional, Sequence, TYPE_CHECKING, Tuple, Union
from typing_extensions import Final, Literal, TypeAlias
from urllib.parse import urlparse
import re
import numpy as np
from PIL import Image, ImageFile
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Image_pb2 import ImageList as ImageListProto
if TYPE_CHECKING:
import numpy.typing as npt
from typing import Any
from streamlit.delta_generator import DeltaGenerator
# Module-level logger, named after this module.
LOGGER: Final = get_logger(__name__)

# Widest image, in pixels, that is produced when the caller does not request
# an explicit positive width (see `_normalize_to_bytes`, which downsizes wider
# images to this value).
# This constant is related to the frontend maximum content width specified
# in the App.jsx main container: 730 is the max width of element-container in
# the frontend, and the 2x factor is for high-DPI displays.
MAXIMUM_CONTENT_WIDTH: Final[int] = 2 * 730

# Either of the two PIL image classes accepted by this module.
PILImage: TypeAlias = Union[ImageFile.ImageFile, Image.Image]
# A single image input: a PIL image, a numpy array, an in-memory byte buffer,
# or a string (per the `ImageMixin.image` docstring: a URL, a local file path,
# or an SVG XML string).
AtomicImage: TypeAlias = Union[PILImage, "npt.NDArray[Any]", io.BytesIO, str]
# One image, or a list of images to display together.
ImageOrImageList: TypeAlias = Union[AtomicImage, List[AtomicImage]]
# Accepted values for the `use_column_width` parameter of `ImageMixin.image`.
UseColumnWith: TypeAlias = Optional[Union[Literal["auto", "always", "never"], bool]]
# Order of the colour channels in numpy array input.
Channels: TypeAlias = Literal["RGB", "BGR"]
# Format used when transferring image data; "auto" lets the module decide.
OutputFormat: TypeAlias = Literal["JPEG", "PNG", "auto"]
class ImageMixin:
    """Mixin providing the `image` element command."""

    def image(
        self,
        image: ImageOrImageList,
        # TODO: Narrow type of caption, dependent on type of image,
        #  by way of overload
        caption: Optional[Union[str, List[str]]] = None,
        width: Optional[int] = None,
        use_column_width: UseColumnWith = None,
        clamp: bool = False,
        channels: Channels = "RGB",
        output_format: OutputFormat = "auto",
    ) -> "DeltaGenerator":
        """Display an image or list of images.

        Parameters
        ----------
        image : numpy.ndarray, [numpy.ndarray], BytesIO, str, or [str]
            Monochrome image of shape (w,h) or (w,h,1)
            OR a color image of shape (w,h,3)
            OR an RGBA image of shape (w,h,4)
            OR a URL to fetch the image from
            OR a path of a local image file
            OR an SVG XML string like `<svg xmlns=...</svg>`
            OR a list of one of the above, to display multiple images.
        caption : str or list of str
            Image caption. If displaying multiple images, caption should be a
            list of captions (one for each image).
        width : int or None
            Image width. None means use the image width,
            but do not exceed the width of the column.
            Should be set for SVG images, as they have no default image width.
        use_column_width : 'auto' or 'always' or 'never' or bool
            If 'auto', set the image's width to its natural size,
            but do not exceed the width of the column.
            If 'always' or True, set the image's width to the column width.
            If 'never' or False, set the image's width to its natural size.
            Note: if set, `use_column_width` takes precedence over the `width`
            parameter.
        clamp : bool
            Clamp image pixel values to a valid range ([0-255] per channel).
            This is only meaningful for byte array images; the parameter is
            ignored for image URLs. If this is not set, and an image has an
            out-of-range value, an error will be thrown.
        channels : 'RGB' or 'BGR'
            If image is a numpy.ndarray, this parameter denotes the format
            used to represent color information. Defaults to 'RGB', meaning
            `image[:, :, 0]` is the red channel, `image[:, :, 1]` is green, and
            `image[:, :, 2]` is blue. For images coming from libraries like
            OpenCV you should set this to 'BGR', instead.
        output_format : 'JPEG', 'PNG', or 'auto'
            This parameter specifies the format to use when transferring the
            image data. Photos should use the JPEG format for lossy compression
            while diagrams should use the PNG format for lossless compression.
            Defaults to 'auto' which identifies the compression type based
            on the type and format of the image argument.

        Raises
        ------
        StreamlitAPIException
            If an explicit `width` is used (i.e. `use_column_width` does not
            override it) and it is not positive.

        Example
        -------
        >>> from PIL import Image
        >>> image = Image.open('sunrise.jpg')
        >>>
        >>> st.image(image, caption='Sunrise by the mountains')

        .. output::
           https://doc-image.streamlitapp.com/
           height: 710px

        """
        # The width sent in the proto doubles as a layout flag: negative
        # values are sentinels (-3: auto, -2: column width, -1: natural size),
        # a positive value is an explicit pixel width.
        if use_column_width == "auto" or (
            width is None and use_column_width is None
        ):
            proto_width = -3
        elif use_column_width in ("always", True):
            proto_width = -2
        elif width is None:
            proto_width = -1
        elif width <= 0:
            raise StreamlitAPIException("Image width must be positive.")
        else:
            proto_width = width

        image_list_proto = ImageListProto()
        delta_path = self.dg._get_delta_path_str()
        marshall_images(
            delta_path,
            image,
            caption,
            proto_width,
            image_list_proto,
            clamp,
            channels,
            output_format,
        )
        return self.dg._enqueue("imgs", image_list_proto)

    @property
    def dg(self) -> "DeltaGenerator":
        """Return this object, typed as the DeltaGenerator it is mixed into."""
        return cast("DeltaGenerator", self)
def _image_may_have_alpha_channel(image: PILImage) -> bool:
    """Report whether `image.mode` is one of "RGBA", "LA" or "P".

    These are the modes this module treats as possibly carrying an alpha
    channel.
    """
    return image.mode in ("RGBA", "LA", "P")
def _format_from_image_type(
    image: PILImage,
    output_format: str,
) -> Literal["JPEG", "PNG"]:
    """Resolve the requested output format to either "JPEG" or "PNG".

    Parameters
    ----------
    image : PILImage
        Image whose mode is consulted when the request is not explicit.
    output_format : str
        Requested format, matched case-insensitively. "JPEG", "JPG" and "PNG"
        are honoured directly; any other value (such as "auto") is decided
        from the image.

    Returns
    -------
    "JPEG" or "PNG"
        "PNG" when the choice is left to the image and the image may have an
        alpha channel, otherwise "JPEG" (unless "PNG" was requested).
    """
    requested = output_format.upper()

    if requested in ("JPEG", "PNG"):
        return cast(Literal["JPEG", "PNG"], requested)

    # We are forgiving on the spelling of JPEG
    if requested == "JPG":
        return "JPEG"

    # No usable explicit request: let the image decide.
    return "PNG" if _image_may_have_alpha_channel(image) else "JPEG"
def _PIL_to_bytes(
    image: PILImage,
    format: Literal["JPEG", "PNG"] = "JPEG",
    quality: int = 100,
) -> bytes:
    """Encode a PIL image into bytes using the given format and quality.

    Parameters
    ----------
    image : PILImage
        Image to encode.
    format : "JPEG" or "PNG"
        Format passed to `image.save`.
    quality : int
        Quality setting passed to `image.save`.

    Returns
    -------
    bytes
        The encoded image.
    """
    buffer = io.BytesIO()

    # JPEG was requested for an image that may carry an alpha channel, so
    # convert it to plain RGB before saving.
    needs_rgb = format == "JPEG" and _image_may_have_alpha_channel(image)
    encodable = image.convert("RGB") if needs_rgb else image

    encodable.save(buffer, format=format, quality=quality)
    return buffer.getvalue()
def _BytesIO_to_bytes(data: io.BytesIO) -> bytes:
data.seek(0)
return data.getvalue()
def _np_array_to_bytes(
    array: "npt.NDArray[Any]",
    output_format="JPEG",
) -> bytes:
    """Encode a numpy array as image bytes.

    The array is cast to uint8, turned into a PIL image, and encoded in the
    format resolved by `_format_from_image_type` from `output_format`.
    """
    pil_image = Image.fromarray(array.astype(np.uint8))
    resolved_format = _format_from_image_type(pil_image, output_format)
    return _PIL_to_bytes(pil_image, resolved_format)
def _4d_to_list_3d(array: "npt.NDArray[Any]") -> List["npt.NDArray[Any]"]:
return [array[i, :, :, :] for i in range(0, array.shape[0])]
def _verify_np_shape(array: "npt.NDArray[Any]") -> "npt.NDArray[Any]":
if len(array.shape) not in (2, 3):
raise StreamlitAPIException("Numpy shape has to be of length 2 or 3.")
if len(array.shape) == 3 and array.shape[-1] not in (1, 3, 4):
raise StreamlitAPIException(
"Channel can only be 1, 3, or 4 got %d. Shape is %s"
% (array.shape[-1], str(array.shape))
)
# If there's only one channel, convert is to x, y
if len(array.shape) == 3 and array.shape[-1] == 1:
array = array[:, :, 0]
return array
def _normalize_to_bytes(
    data,
    width: int,
    output_format: OutputFormat,
) -> Tuple[bytes, str]:
    """Downsize encoded image data if needed and determine its mimetype.

    Parameters
    ----------
    data
        Encoded image bytes that PIL can open.
    width : int
        Requested width. A negative value means no explicit width, in which
        case the image is limited to MAXIMUM_CONTENT_WIDTH; a positive value
        is the maximum width to keep. Images are only ever shrunk.
    output_format : 'JPEG', 'PNG', or 'auto'
        Requested output format.

    Returns
    -------
    (bytes, str)
        The (possibly re-encoded) image bytes and their mimetype. The input
        bytes are returned untouched when no resize is required.
    """
    pil_image = Image.open(io.BytesIO(data))
    actual_width, actual_height = pil_image.size
    target_format = _format_from_image_type(pil_image, output_format)
    format_mimetype = "image/" + target_format.lower()

    mimetype = None
    if output_format.lower() == "auto":
        # Try to keep the mimetype of the data as it was supplied.
        detected_ext = imghdr.what(None, data)
        mimetype = mimetypes.guess_type("image.%s" % detected_ext)[0]
    # if no other options, attempt to convert
    if mimetype is None:
        mimetype = format_mimetype

    # Without an explicit width the image is capped at the maximum content
    # width; otherwise the requested width is the cap.
    target_width = MAXIMUM_CONTENT_WIDTH if width < 0 else width

    if 0 < target_width < actual_width:
        new_height = int(1.0 * actual_height * target_width / actual_width)
        pil_image = pil_image.resize(
            (target_width, new_height), resample=Image.BILINEAR
        )
        data = _PIL_to_bytes(pil_image, format=target_format, quality=90)
        # The data was re-encoded, so the mimetype now follows the format.
        mimetype = format_mimetype

    return data, mimetype
def _clip_image(image: "npt.NDArray[Any]", clamp: bool) -> "npt.NDArray[Any]":
data = image
if issubclass(image.dtype.type, np.floating):
if clamp:
data = np.clip(image, 0, 1.0)
else:
if np.amin(image) < 0.0 or np.amax(image) > 1.0:
raise RuntimeError("Data is outside [0.0, 1.0] and clamp is not set.")
data = data * 255
else:
if clamp:
data = np.clip(image, 0, 255)
else:
if np.amin(image) < 0 or np.amax(image) > 255:
raise RuntimeError("Data is outside [0, 255] and clamp is not set.")
return data
def image_to_url(
    image: AtomicImage,
    width: int,
    clamp: bool,
    channels: Channels,
    output_format: OutputFormat,
    image_id: str,
    allow_emoji: bool = False,
):
    """Return a URL that can be used to display `image`.

    Parameters
    ----------
    image : AtomicImage
        A PIL image, a BytesIO buffer, a numpy array, or a string holding a
        URL or a local file path. Any other value is assumed to already be
        encoded image bytes.
    width : int
        Width forwarded to `_normalize_to_bytes`.
    clamp : bool
        Whether out-of-range numpy pixel values are clipped rather than
        rejected (see `_clip_image`). Only used for numpy input.
    channels : 'RGB' or 'BGR'
        Channel order of numpy input. Only used for numpy input.
    output_format : 'JPEG', 'PNG', or 'auto'
        Requested format for the encoded image.
    image_id : str
        Identifier under which the data is added to the in-memory file
        manager.
    allow_emoji : bool
        If true, a string that cannot be opened as a file is returned
        unchanged instead of raising.

    Returns
    -------
    The input string itself when it parses as a URL with a scheme (or, with
    `allow_emoji`, when it cannot be opened as a file); otherwise the URL of
    the file registered with the in-memory file manager.

    Raises
    ------
    StreamlitAPIException
        If `channels` is "BGR" and the numpy image is not 3-dimensional, or
        if the numpy shape is invalid.
    Exception
        Whatever `open` raises for a string that is not a readable file,
        unless `allow_emoji` is set.
    """
    # PIL Images
    if isinstance(image, (ImageFile.ImageFile, Image.Image)):
        image_format = _format_from_image_type(image, output_format)
        data = _PIL_to_bytes(image, image_format)

    # BytesIO
    # Note: This doesn't support SVG. We could convert to png (cairosvg.svg2png)
    # or just decode BytesIO to string and handle that way.
    elif isinstance(image, io.BytesIO):
        data = _BytesIO_to_bytes(image)

    # Numpy Arrays (ie opencv)
    elif isinstance(image, np.ndarray):
        image = _clip_image(
            _verify_np_shape(image),
            clamp,
        )

        if channels == "BGR":
            # NOTE(review): only the number of dimensions is checked here, so
            # a 4-channel array passes and loses its 4th channel in the
            # reorder below - confirm whether that is intended.
            if len(image.shape) == 3:
                image = image[:, :, [2, 1, 0]]
            else:
                raise StreamlitAPIException(
                    'When using `channels="BGR"`, the input image should '
                    "have exactly 3 color channels"
                )

        # Depending on the version of numpy that the user has installed, the
        # typechecker may not be able to deduce that indexing into a
        # `npt.NDArray[Any]` returns a `npt.NDArray[Any]`, so we need to
        # ignore redundant casts below.
        data = _np_array_to_bytes(
            array=cast("npt.NDArray[Any]", image),  # type: ignore[redundant-cast]
            output_format=output_format,
        )

    # Strings
    elif isinstance(image, str):
        # If it's a url, hand it back untouched and let the frontend fetch it.
        try:
            parsed = urlparse(image)
            if parsed.scheme:
                return image
        except UnicodeDecodeError:
            pass

        # Finally, see if it's a file.
        try:
            with open(image, "rb") as f:
                data = f.read()
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit are never mistaken for "this string is an emoji".
        except Exception:
            if allow_emoji:
                # This might be an emoji string, so just pass it to the frontend
                return image
            else:
                # Allow OS filesystem errors to raise
                raise

    # Assume input in bytes.
    else:
        data = image

    (data, mimetype) = _normalize_to_bytes(data, width, output_format)
    this_file = in_memory_file_manager.add(data, mimetype, image_id)
    return this_file.url
def marshall_images(
    coordinates: str,
    image: ImageOrImageList,
    caption: Optional[Union[str, "npt.NDArray[Any]", List[str]]],
    width: int,
    proto_imgs: ImageListProto,
    clamp: bool,
    channels: Channels = "RGB",
    output_format: OutputFormat = "auto",
) -> None:
    """Fill an ImageList proto from one or more images and their captions.

    Parameters
    ----------
    coordinates : str
        Prefix for the per-image identifiers; each image's index in the input
        is appended to it.
    image : AtomicImage or list of AtomicImage
        One image, a list of images, or a 4-D numpy array that is split along
        its first axis into separate images.
    caption : str, 1-D numpy.ndarray, list of str, or None
        Caption(s) to pair with the image(s). None gives every image an empty
        caption.
    width : int
        Value stored in the proto's width field and forwarded to
        `image_to_url`.
    proto_imgs : ImageListProto
        Proto that is filled in place.
    clamp : bool
        Forwarded to `image_to_url`.
    channels : 'RGB' or 'BGR'
        Channel order of numpy input; matched case-insensitively.
    output_format : 'JPEG', 'PNG', or 'auto'
        Forwarded to `image_to_url`.

    Raises
    ------
    AssertionError
        If the number of captions differs from the number of images.
    """
    channels = cast(Channels, channels.upper())

    # Turn a single image into a one-element list.
    images: Sequence[AtomicImage]
    if isinstance(image, list):
        images = image
    elif isinstance(image, np.ndarray) and len(image.shape) == 4:
        images = _4d_to_list_3d(image)
    else:
        images = [image]

    # Likewise, bring the captions into list form.
    captions: Sequence[Optional[str]]
    if type(caption) is list:
        captions = caption
    elif isinstance(caption, str):
        captions = [caption]
    elif isinstance(caption, np.ndarray) and len(caption.shape) == 1:
        # You can pass in a 1-D Numpy array as captions.
        captions = caption.tolist()
    elif caption is None:
        # If there are no captions then make the captions list the same size
        # as the images list.
        captions = [None] * len(images)
    else:
        captions = [str(caption)]

    assert type(captions) == list, "If image is a list then caption should be as well"
    assert len(captions) == len(images), "Cannot pair %d captions with %d images." % (
        len(captions),
        len(images),
    )

    proto_imgs.width = width

    # Each image in an image list needs to be kept track of at its own coordinates.
    for index, (single_image, single_caption) in enumerate(zip(images, captions)):
        proto_img = proto_imgs.imgs.add()
        if single_caption is not None:
            proto_img.caption = str(single_caption)

        # We use the index of the image in the input image list to identify this image inside
        # InMemoryFileManager. For this, we just add the index to the image's "coordinates".
        image_id = "%s-%i" % (coordinates, index)

        if isinstance(single_image, str):
            # Unpack local SVG image file to an SVG string
            is_local_svg_path = single_image.endswith(
                ".svg"
            ) and not single_image.startswith(("http://", "https://"))
            if is_local_svg_path:
                with open(single_image) as textfile:
                    single_image = textfile.read()

            # Following regex allows svg image files to start either via a "<?xml...>" tag
            # eventually followed by a "<svg...>" tag or directly starting with a "<svg>" tag
            if re.search(r"(^\s?(<\?xml[\s\S]*<svg )|^\s?<svg )", single_image):
                # SVG markup is embedded directly; no URL is generated for it.
                proto_img.markup = f"data:image/svg+xml,{single_image}"
                continue

        proto_img.url = image_to_url(
            single_image, width, clamp, channels, output_format, image_id
        )