diff --git a/doc/python/imshow.md b/doc/python/imshow.md index 91e8976363..1b92bfd2f6 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.2' - jupytext_version: 1.4.2 + jupytext_version: 1.3.0 kernelspec: display_name: Python 3 language: python @@ -20,7 +20,7 @@ jupyter: name: python nbconvert_exporter: python pygments_lexer: ipython3 - version: 3.7.7 + version: 3.7.3 plotly: description: How to display image data in Python with Plotly. display_as: scientific @@ -399,6 +399,95 @@ for compression_level in range(0, 9): fig.show() ``` +### Exploring 3-D images, timeseries and sequences of images with `facet_col` + +*Introduced in plotly 4.14* + +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axis the image is sliced through to make the facets. With `facet_col_wrap`, one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to specify the axis by its name (label), thus passing a string to `facet_col`. + +It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time. + +See the [tutorial on facet plots](/python/facet-plots/) for more information on creating and styling facet plots. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +img = data[20:45:2] +fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5) +fig.show() +``` + +Facets can also be used to represent several images of equal shape, like in the example below where different values of the blurring parameter of a Gaussian filter are compared. + +```python +import plotly.express as px +import numpy as np +from skimage import data, filters, img_as_float +img = data.camera() +sigmas = [1, 2, 4] +img_sequence = [filters.gaussian(img, sigma=sigma) for sigma in sigmas] +fig = px.imshow(np.array(img_sequence), facet_col=0, binary_string=True, + labels={'facet_col':'sigma'}) +# Set facet titles +for i, sigma in enumerate(sigmas): + fig.layout.annotations[i]['text'] = 'sigma = %d' %sigma +fig.show() +``` + +```python +print(fig) +``` + +### Exploring 3-D images and timeseries with `animation_frame` + +*Introduced in plotly 4.14* + +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +img = data[25:40] +fig = px.imshow(img, animation_frame=0, binary_string=True) +fig.show() +``` + +### Animations of xarray datasets + +*Introduced in plotly 4.14* + +For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure. + +```python +import plotly.express as px +import xarray as xr +# Load xarray from dataset included in the xarray tutorial +ds = xr.tutorial.open_dataset('air_temperature').air[:20] +fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r') +fig.show() +``` + +### Combining animations and facets + +It is possible to view 4-dimensional datasets (for example, 3-D images evolving with time) using a combination of `animation_frame` and `facet_col`. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +data = data.reshape((15, 4, 256, 256))[5:] +fig = px.imshow(data, animation_frame=0, facet_col=1, binary_string=True) +fig.show() +``` + #### Reference See https://plotly.com/python/reference/image/ for more information and chart attribute options! diff --git a/doc/requirements.txt b/doc/requirements.txt index 71414c34ee..bf6e717f03 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -28,4 +28,5 @@ pyarrow cufflinks==0.17.3 kaleido umap-learn +pooch wget diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 88713e5436..27d1bc7349 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -1,9 +1,10 @@ import plotly.graph_objs as go from _plotly_utils.basevalidators import ColorscaleValidator -from ._core import apply_default_cascade +from ._core import apply_default_cascade, init_figure, configure_animation_controls from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types import pandas as pd import numpy as np +import itertools from plotly.utils import image_array_to_data_uri try: @@ -60,6 +61,11 @@ def imshow( labels={}, x=None, y=None, + animation_frame=None, + facet_col=None, + facet_col_wrap=None, + facet_col_spacing=None, + facet_row_spacing=None, color_continuous_scale=None, color_continuous_midpoint=None, range_color=None, @@ -113,6 +119,26 @@ def imshow( their lengths must match the lengths of the second and first dimensions of the img argument. They are auto-populated if the input is an xarray. + animation_frame: int or str, optional (default None) + axis number along which the image array is sliced to create an animation plot. + If `img` is an xarray, `animation_frame` can be the name of one the dimensions. + + facet_col: int or str, optional (default None) + axis number along which the image array is sliced to create a facetted plot. + If `img` is an xarray, `facet_col` can be the name of one the dimensions. + + facet_col_wrap: int + Maximum number of facet columns. Wraps the column variable at this width, + so that the column facets span multiple rows. + Ignored if `facet_col` is None. + + facet_col_spacing: float between 0 and 1 + Spacing between facet columns, in paper units. Default is 0.02. + + facet_row_spacing: float between 0 and 1 + Spacing between facet rows created when ``facet_col_wrap`` is used, in + paper units. Default is 0.0.7. + color_continuous_scale : str or list of str colormap used to map scalar data to colors (for a 2D image). This parameter is not used for RGB or RGBA images. If a string is provided, it should be the name @@ -204,11 +230,45 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() + nslices_facet = 1 + if facet_col is not None: + if isinstance(facet_col, str): + facet_col = img.dims.index(facet_col) + nslices_facet = img.shape[facet_col] + facet_slices = range(nslices_facet) + ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet + nrows = ( + nslices_facet // ncols + 1 + if nslices_facet % ncols + else nslices_facet // ncols + ) + else: + nrows = 1 + ncols = 1 + if animation_frame is not None: + if isinstance(animation_frame, str): + animation_frame = img.dims.index(animation_frame) + nslices_animation = img.shape[animation_frame] + animation_slices = range(nslices_animation) + slice_dimensions = (facet_col is not None) + ( + animation_frame is not None + ) # 0, 1, or 2 + facet_label = None + animation_label = None img_is_xarray = False # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): + dims = list(img.dims) img_is_xarray = True - y_label, x_label = img.dims[0], img.dims[1] + if facet_col is not None: + facet_slices = img.coords[img.dims[facet_col]].values + _ = dims.pop(facet_col) + facet_label = img.dims[facet_col] + if animation_frame is not None: + animation_slices = img.coords[img.dims[animation_frame]].values + _ = dims.pop(animation_frame) + animation_label = img.dims[animation_frame] + y_label, x_label = dims[0], dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): @@ -223,6 +283,10 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label + if labels.get("animation_frame", None) is None: + labels["animation_frame"] = animation_label + if labels.get("facet_col", None) is None: + labels["facet_col"] = facet_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -257,10 +321,29 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) + # Reshape array so that animation dimension comes first, then facets, then images + if facet_col is not None: + img = np.moveaxis(img, facet_col, 0) + if animation_frame is not None and animation_frame < facet_col: + animation_frame += 1 + facet_col = True + if animation_frame is not None: + img = np.moveaxis(img, animation_frame, 0) + animation_frame = True + args["animation_frame"] = ( + "animation_frame" + if labels.get("animation_frame") is None + else labels["animation_frame"] + ) + iterables = () + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: - binary_string = img.ndim >= 3 and not is_dataframe + binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe # Cast bools to uint8 (also one byte) if img.dtype == np.bool: @@ -272,7 +355,7 @@ def imshow( # -------- Contrast rescaling: either minmax or infer ------------------ if contrast_rescaling is None: - contrast_rescaling = "minmax" if img.ndim == 2 else "infer" + contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer" # We try to set zmin and zmax only if necessary, because traces have good defaults if contrast_rescaling == "minmax": @@ -288,19 +371,24 @@ def imshow( if zmin is None and zmax is not None: zmin = 0 - # For 2d data, use Heatmap trace, unless binary_string is True - if img.ndim == 2 and not binary_string: - if y is not None and img.shape[0] != len(y): + # For 2d data, use Heatmap trace, unless binary_string is True + if img.ndim == 2 + slice_dimensions and not binary_string: + y_index = slice_dimensions + if y is not None and img.shape[y_index] != len(y): raise ValueError( "The length of the y vector must match the length of the first " + "dimension of the img matrix." ) - if x is not None and img.shape[1] != len(x): + x_index = slice_dimensions + 1 + if x is not None and img.shape[x_index] != len(x): raise ValueError( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) - trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1") + traces = [ + go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)) + for i, index_tup in enumerate(itertools.product(*iterables)) + ] autorange = True if origin == "lower" else "reversed" layout = dict(yaxis=dict(autorange=autorange)) if aspect == "equal": @@ -319,7 +407,10 @@ def imshow( layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) # For 2D+RGB data, use Image trace - elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string): + elif ( + img.ndim >= 3 + and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string) + ) or (img.ndim == 2 and binary_string): rescale_image = True # to check whether image has been modified if zmin is not None and zmax is not None: zmin, zmax = ( @@ -366,12 +457,12 @@ def imshow( if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False - elif img.ndim == 2: + elif img.ndim == 2 + slice_dimensions: # single-channel image img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) else: - img_rescaled = np.dstack( + img_rescaled = np.stack( [ rescale_intensity( img[..., ch], @@ -379,27 +470,38 @@ def imshow( out_range=np.uint8, ) for ch in range(img.shape[-1]) - ] + ], + axis=-1, ) - img_str = image_array_to_data_uri( - img_rescaled, - backend=binary_backend, - compression=binary_compression_level, - ext=binary_format, - ) - trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) + img_str = [ + image_array_to_data_uri( + img_rescaled[index_tup], + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + for index_tup in itertools.product(*iterables) + ] + + traces = [ + go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy) + for i, img_str_slice in enumerate(img_str) + ] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - trace = go.Image( - z=img, - zmin=zmin, - zmax=zmax, - colormodel=colormodel, - x0=x0, - y0=y0, - dx=dx, - dy=dy, - ) + traces = [ + go.Image( + z=img[index_tup], + zmin=zmin, + zmax=zmax, + colormodel=colormodel, + x0=x0, + y0=y0, + dx=dx, + dy=dy, + ) + for index_tup in itertools.product(*iterables) + ] layout = {} if origin == "lower" or (dy is not None and dy < 0): layout["yaxis"] = dict(autorange=True) @@ -408,19 +510,44 @@ def imshow( else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " - "An image of shape %s was provided" % str(img.shape) + "An image of shape %s was provided." + "Alternatively, 3- or 4-D single or multichannel datasets can be" + "visualized using the `facet_col` or/and `animation_frame` arguments." + % str(img.shape) ) - layout_patch = dict() + # Now build figure + col_labels = [] + if facet_col is not None: + slice_label = ( + "facet_col" if labels.get("facet_col") is None else labels["facet_col"] + ) + col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) for attr_name in ["height", "width"]: if args[attr_name]: - layout_patch[attr_name] = args[attr_name] + layout[attr_name] = args[attr_name] if args["title"]: - layout_patch["title_text"] = args["title"] + layout["title_text"] = args["title"] elif args["template"].layout.margin.t is None: - layout_patch["margin"] = {"t": 60} - fig = go.Figure(data=trace, layout=layout) - fig.update_layout(layout_patch) + layout["margin"] = {"t": 60} + + frame_list = [] + for index, trace in enumerate(traces): + if (facet_col and index < nrows * ncols) or index == 0: + fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if animation_frame is not None: + for i, index in zip(range(nslices_animation), animation_slices): + frame_list.append( + dict( + data=traces[nslices_facet * i : nslices_facet * (i + 1)], + layout=layout, + name=str(index), + ) + ) + if animation_frame: + fig.frames = frame_list + fig.update_layout(layout) # Hover name, z or color if binary_string and rescale_image and not np.all(img == img_rescaled): # we rescaled the image, hence z is not displayed in hover since it does @@ -449,5 +576,6 @@ def imshow( fig.update_xaxes(title_text=labels["x"]) if labels["y"]: fig.update_yaxes(title_text=labels["y"]) + configure_animation_controls(args, go.Image, fig) fig.update_layout(template=args["template"], overwrite=True) return fig diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 313267aacb..912b4151ab 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -172,14 +172,26 @@ def test_zmin_zmax_range_color_source(): assert fig1 == fig2 -def test_imshow_xarray(): +@pytest.mark.parametrize("binary_string", [False, True]) +def test_imshow_xarray(binary_string): img = np.random.random((20, 30)) da = xr.DataArray(img, dims=["dim_rows", "dim_cols"]) - fig = px.imshow(da) + fig = px.imshow(da, binary_string=binary_string) # Dimensions are used for axis labels and coordinates assert fig.layout.xaxis.title.text == "dim_cols" assert fig.layout.yaxis.title.text == "dim_rows" - assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"])) + if not binary_string: + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"])) + + +def test_imshow_xarray_slicethrough(): + img = np.random.random((8, 9, 10)) + da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2"]) + fig = px.imshow(da, animation_frame="dim_0") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_2" + assert fig.layout.yaxis.title.text == "dim_1" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) def test_imshow_labels_and_ranges(): @@ -346,3 +358,50 @@ def test_imshow_hovertemplate(binary_string): fig.data[0].hovertemplate == "x: %{x}
y: %{y}
color: %{z}" ) + + +@pytest.mark.parametrize("facet_col", [0, 1, 2, -1]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_facet_col(facet_col, binary_string): + img = np.random.randint(255, size=(10, 9, 8)) + facet_col_wrap = 3 + fig = px.imshow( + img, + facet_col=facet_col, + facet_col_wrap=facet_col_wrap, + binary_string=binary_string, + ) + nslices = img.shape[facet_col] + ncols = int(facet_col_wrap) + nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols + nmax = ncols * nrows + assert "yaxis%d" % nmax in fig.layout + assert "yaxis%d" % (nmax + 1) not in fig.layout + assert len(fig.data) == nslices + + +@pytest.mark.parametrize("animation_frame", [0, 1, 2, -1]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_frame_grayscale(animation_frame, binary_string): + img = np.random.randint(255, size=(10, 9, 8)).astype(np.uint8) + fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,) + nslices = img.shape[animation_frame] + assert len(fig.frames) == nslices + + +@pytest.mark.parametrize("animation_frame", [0, 1, 2]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_frame_rgb(animation_frame, binary_string): + img = np.random.randint(255, size=(10, 9, 8, 3)).astype(np.uint8) + fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,) + nslices = img.shape[animation_frame] + assert len(fig.frames) == nslices + + +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_and_facet(binary_string): + img = np.random.randint(255, size=(10, 9, 8, 7)).astype(np.uint8) + fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=binary_string) + nslices = img.shape[0] + assert len(fig.frames) == nslices + assert len(fig.data) == img.shape[1]