From afb5c4dfbc0f329dff2256f6b961f884c5fff3fa Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 3 Sep 2020 11:12:37 +0200 Subject: [PATCH 01/28] use init_figure from main px core --- packages/python/plotly/plotly/express/_imshow.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 218a824e15..d55ee54dfa 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -1,6 +1,6 @@ import plotly.graph_objs as go from _plotly_utils.basevalidators import ColorscaleValidator -from ._core import apply_default_cascade +from ._core import apply_default_cascade, init_figure from io import BytesIO import base64 from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types @@ -133,6 +133,9 @@ def imshow( labels={}, x=None, y=None, + animation_frame=False, + facet_col=False, + facet_col_wrap=None, color_continuous_scale=None, color_continuous_midpoint=None, range_color=None, @@ -277,6 +280,14 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() + if facet_col: + nslices = img.shape[-1] + ncols = facet_col_wrap + nrows = nslices / ncols + else: + nrows = 1 + ncols = 1 + fig = init_figure(args, 'xy', [], nrows, ncols, [], []) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): if binary_string: @@ -449,7 +460,8 @@ def imshow( layout_patch["title_text"] = args["title"] elif args["template"].layout.margin.t is None: layout_patch["margin"] = {"t": 60} - fig = go.Figure(data=trace, layout=layout) + fig.add_trace(trace) + fig.update_layout(layout) fig.update_layout(layout_patch) # Hover name, z or color if binary_string and rescale_image and not np.all(img == img_rescaled): From 8be8ca04978214608ad75bfa7c0b339f30710128 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 3 Sep 2020 14:33:18 +0200 Subject: [PATCH 02/28] WIP: add facet_col arg to imshow --- .../python/plotly/plotly/express/_imshow.py | 79 ++++++++++++++----- .../tests/test_core/test_px/test_imshow.py | 21 +++++ 2 files changed, 80 insertions(+), 20 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index d55ee54dfa..edc2e372fe 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -134,7 +134,7 @@ def imshow( x=None, y=None, animation_frame=False, - facet_col=False, + facet_col=None, facet_col_wrap=None, color_continuous_scale=None, color_continuous_midpoint=None, @@ -189,6 +189,14 @@ def imshow( their lengths must match the lengths of the second and first dimensions of the img argument. They are auto-populated if the input is an xarray. + facet_col: int, optional (default None) + axis number along which the image array is slices to create a facetted plot. + + facet_col_wrap: int + Maximum number of facet columns. Wraps the column variable at this width, + so that the column facets span multiple rows. + Ignored if `facet_col` is None. + color_continuous_scale : str or list of str colormap used to map scalar data to colors (for a 2D image). This parameter is not used for RGB or RGBA images. If a string is provided, it should be the name @@ -280,14 +288,14 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() - if facet_col: - nslices = img.shape[-1] - ncols = facet_col_wrap - nrows = nslices / ncols + if facet_col is not None: + nslices = img.shape[facet_col] + ncols = int(facet_col_wrap) + nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols else: nrows = 1 ncols = 1 - fig = init_figure(args, 'xy', [], nrows, ncols, [], []) + fig = init_figure(args, "xy", [], nrows, ncols, [], []) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): if binary_string: @@ -345,10 +353,16 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) + if facet_col is not None: + img = np.moveaxis(img, facet_col, 0) + facet_col = True # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: - binary_string = img.ndim >= 3 and not is_dataframe + if facet_col: + binary_string = img.ndim >= 4 and not is_dataframe + else: + binary_string = img.ndim >= 3 and not is_dataframe # Cast bools to uint8 (also one byte) if img.dtype == np.bool: @@ -377,7 +391,7 @@ def imshow( zmin = 0 # For 2d data, use Heatmap trace, unless binary_string is True - if img.ndim == 2 and not binary_string: + if (img.ndim == 2 or (img.ndim == 3 and facet_col)) and not binary_string: if y is not None and img.shape[0] != len(y): raise ValueError( "The length of the y vector must match the length of the first " @@ -388,7 +402,13 @@ def imshow( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) - trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1") + if facet_col: + traces = [ + go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1") + for img_slice in img + ] + else: + traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")] autorange = True if origin == "lower" else "reversed" layout = dict(yaxis=dict(autorange=autorange)) if aspect == "equal": @@ -407,7 +427,11 @@ def imshow( layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) # For 2D+RGB data, use Image trace - elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string): + elif ( + img.ndim == 3 + and (img.shape[-1] in [3, 4] or (facet_col and binary_string)) + or (img.ndim == 2 and binary_string) + ): rescale_image = True # to check whether image has been modified if zmin is not None and zmax is not None: zmin, zmax = ( @@ -418,7 +442,7 @@ def imshow( if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False - elif img.ndim == 2: + elif img.ndim == 2 or (img.ndim == 3 and facet_col): img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) @@ -433,16 +457,30 @@ def imshow( for ch in range(img.shape[-1]) ] ) - img_str = _array_to_b64str( - img_rescaled, - backend=binary_backend, - compression=binary_compression_level, - ext=binary_format, - ) - trace = go.Image(source=img_str) + if facet_col: + img_str = [ + _array_to_b64str( + img_rescaled_slice, + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + for img_rescaled_slice in img_rescaled + ] + + else: + img_str = [ + _array_to_b64str( + img_rescaled, + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + ] + traces = [go.Image(source=img_str_slice) for img_str_slice in img_str] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel) + traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] layout = {} if origin == "lower": layout["yaxis"] = dict(autorange=True) @@ -460,7 +498,8 @@ def imshow( layout_patch["title_text"] = args["title"] elif args["template"].layout.margin.t is None: layout_patch["margin"] = {"t": 60} - fig.add_trace(trace) + for index, trace in enumerate(traces): + fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) fig.update_layout(layout) fig.update_layout(layout_patch) # Hover name, z or color diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 84e39c7833..1de6b32076 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -314,3 +314,24 @@ def test_imshow_hovertemplate(binary_string): fig.data[0].hovertemplate == "x: %{x}
y: %{y}
color: %{z}" ) + + +@pytest.mark.parametrize("facet_col", [0, 1, 2, -1]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_facet_col(facet_col, binary_string): + img = np.random.randint(255, size=(10, 9, 8)) + facet_col_wrap = 3 + fig = px.imshow( + img, + facet_col=facet_col, + facet_col_wrap=facet_col_wrap, + binary_string=binary_string, + ) + if facet_col is not None: + nslices = img.shape[facet_col] + ncols = int(facet_col_wrap) + nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols + nmax = ncols * nrows + assert "yaxis%d" % nmax in fig.layout + assert "yaxis%d" % (nmax + 1) not in fig.layout + assert len(fig.data) == nslices From d236bc227c1ef058fcd646b6df0d872396b8f44f Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 4 Sep 2020 23:38:13 +0200 Subject: [PATCH 03/28] animations work for grayscale images, with or without binary string --- .../python/plotly/plotly/express/_imshow.py | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index edc2e372fe..db0ee33aed 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -1,6 +1,6 @@ import plotly.graph_objs as go from _plotly_utils.basevalidators import ColorscaleValidator -from ._core import apply_default_cascade, init_figure +from ._core import apply_default_cascade, init_figure, configure_animation_controls from io import BytesIO import base64 from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types @@ -133,7 +133,7 @@ def imshow( labels={}, x=None, y=None, - animation_frame=False, + animation_frame=None, facet_col=None, facet_col_wrap=None, color_continuous_scale=None, @@ -353,13 +353,21 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) + slice_through = False if facet_col is not None: img = np.moveaxis(img, facet_col, 0) facet_col = True - + slice_through = True + if animation_frame is not None: + img = np.moveaxis(img, animation_frame, 0) + animation_frame = True + args["animation_frame"] = "plane" + slice_through = True + + print("slice_through", slice_through) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: - if facet_col: + if slice_through: binary_string = img.ndim >= 4 and not is_dataframe else: binary_string = img.ndim >= 3 and not is_dataframe @@ -391,7 +399,7 @@ def imshow( zmin = 0 # For 2d data, use Heatmap trace, unless binary_string is True - if (img.ndim == 2 or (img.ndim == 3 and facet_col)) and not binary_string: + if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string: if y is not None and img.shape[0] != len(y): raise ValueError( "The length of the y vector must match the length of the first " @@ -402,10 +410,10 @@ def imshow( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) - if facet_col: + if slice_through: traces = [ - go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1") - for img_slice in img + go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i)) + for i, img_slice in enumerate(img) ] else: traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")] @@ -429,7 +437,7 @@ def imshow( # For 2D+RGB data, use Image trace elif ( img.ndim == 3 - and (img.shape[-1] in [3, 4] or (facet_col and binary_string)) + and (img.shape[-1] in [3, 4] or (slice_through and binary_string)) or (img.ndim == 2 and binary_string) ): rescale_image = True # to check whether image has been modified @@ -442,7 +450,7 @@ def imshow( if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False - elif img.ndim == 2 or (img.ndim == 3 and facet_col): + elif img.ndim == 2 or (img.ndim == 3 and slice_through): img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) @@ -457,7 +465,7 @@ def imshow( for ch in range(img.shape[-1]) ] ) - if facet_col: + if slice_through: img_str = [ _array_to_b64str( img_rescaled_slice, @@ -477,7 +485,7 @@ def imshow( ext=binary_format, ) ] - traces = [go.Image(source=img_str_slice) for img_str_slice in img_str] + traces = [go.Image(source=img_str_slice, name=str(i)) for i, img_str_slice in enumerate(img_str)] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] @@ -498,8 +506,15 @@ def imshow( layout_patch["title_text"] = args["title"] elif args["template"].layout.margin.t is None: layout_patch["margin"] = {"t": 60} + + frame_list = [] for index, trace in enumerate(traces): - fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if facet_col or index == 0: + fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if animation_frame: + frame_list.append(dict(data=trace, layout=layout, name=str(index))) + if animation_frame: + fig.frames = frame_list fig.update_layout(layout) fig.update_layout(layout_patch) # Hover name, z or color @@ -530,5 +545,6 @@ def imshow( fig.update_xaxes(title_text=labels["x"]) if labels["y"]: fig.update_yaxes(title_text=labels["y"]) - fig.update_layout(template=args["template"], overwrite=True) + configure_animation_controls(args, go.Image, fig) + #fig.update_layout(template=args["template"], overwrite=True) return fig From c8e852e3b58ee0be5f90b98105bbe3af0617f910 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 5 Sep 2020 22:17:18 +0200 Subject: [PATCH 04/28] animations now work + tests --- .../python/plotly/plotly/express/_imshow.py | 33 ++++++++++++------ .../tests/test_core/test_px/test_imshow.py | 34 ++++++++++++++----- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index db0ee33aed..812190c7bd 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -364,7 +364,6 @@ def imshow( args["animation_frame"] = "plane" slice_through = True - print("slice_through", slice_through) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: if slice_through: @@ -382,7 +381,11 @@ def imshow( # -------- Contrast rescaling: either minmax or infer ------------------ if contrast_rescaling is None: - contrast_rescaling = "minmax" if img.ndim == 2 else "infer" + contrast_rescaling = ( + "minmax" + if (img.ndim == 2 or (img.ndim == 3 and slice_through)) + else "infer" + ) # We try to set zmin and zmax only if necessary, because traces have good defaults if contrast_rescaling == "minmax": @@ -436,10 +439,8 @@ def imshow( # For 2D+RGB data, use Image trace elif ( - img.ndim == 3 - and (img.shape[-1] in [3, 4] or (slice_through and binary_string)) - or (img.ndim == 2 and binary_string) - ): + img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string) + ) or (img.ndim == 2 and binary_string): rescale_image = True # to check whether image has been modified if zmin is not None and zmax is not None: zmin, zmax = ( @@ -455,7 +456,7 @@ def imshow( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) else: - img_rescaled = np.dstack( + img_rescaled = np.stack( [ rescale_intensity( img[..., ch], @@ -463,7 +464,8 @@ def imshow( out_range=np.uint8, ) for ch in range(img.shape[-1]) - ] + ], + axis=-1, ) if slice_through: img_str = [ @@ -485,10 +487,19 @@ def imshow( ext=binary_format, ) ] - traces = [go.Image(source=img_str_slice, name=str(i)) for i, img_str_slice in enumerate(img_str)] + traces = [ + go.Image(source=img_str_slice, name=str(i)) + for i, img_str_slice in enumerate(img_str) + ] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] + if slice_through: + traces = [ + go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel) + for img_slice in img + ] + else: + traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] layout = {} if origin == "lower": layout["yaxis"] = dict(autorange=True) @@ -546,5 +557,5 @@ def imshow( if labels["y"]: fig.update_yaxes(title_text=labels["y"]) configure_animation_controls(args, go.Image, fig) - #fig.update_layout(template=args["template"], overwrite=True) + # fig.update_layout(template=args["template"], overwrite=True) return fig diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 1de6b32076..7e2bcca413 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -327,11 +327,29 @@ def test_facet_col(facet_col, binary_string): facet_col_wrap=facet_col_wrap, binary_string=binary_string, ) - if facet_col is not None: - nslices = img.shape[facet_col] - ncols = int(facet_col_wrap) - nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols - nmax = ncols * nrows - assert "yaxis%d" % nmax in fig.layout - assert "yaxis%d" % (nmax + 1) not in fig.layout - assert len(fig.data) == nslices + nslices = img.shape[facet_col] + ncols = int(facet_col_wrap) + nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols + nmax = ncols * nrows + assert "yaxis%d" % nmax in fig.layout + assert "yaxis%d" % (nmax + 1) not in fig.layout + assert len(fig.data) == nslices + + +@pytest.mark.parametrize("animation_frame", [0, 1, 2, -1]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_frame_grayscale(animation_frame, binary_string): + img = np.random.randint(255, size=(10, 9, 8)).astype(np.uint8) + fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,) + nslices = img.shape[animation_frame] + assert len(fig.frames) == nslices + + +@pytest.mark.parametrize("animation_frame", [0, 1, 2]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_frame_rgb(animation_frame, binary_string): + img = np.random.randint(255, size=(10, 9, 8, 3)).astype(np.uint8) + fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,) + print(binary_string) + nslices = img.shape[animation_frame] + assert len(fig.frames) == nslices From 12cec34f1a5fb1ee5af15b5261ba34bc42e20c5a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sun, 6 Sep 2020 14:59:17 +0200 Subject: [PATCH 05/28] docs on facets and animations + add subplots titles --- doc/python/imshow.md | 52 +++++++++++++++++++ .../python/plotly/plotly/express/_imshow.py | 10 ++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index f8aa497083..533fcd88e2 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -399,5 +399,57 @@ for compression_level in range(0, 9): fig.show() ``` +### Exploring 3-D images and timeseries with `facet_col` + +*Introduced in plotly 4.11* + +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. + +It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time. + +See the [tutorial on facet plots](/python/facet-plots/) for more information on creating and styling facet plots. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +img = data[25:40] +fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700) +fig.show() +``` + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +img = data[25:40] +fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5) +# To have square facets one needs to unmatch axes +fig.update_xaxes(matches=None) +fig.update_yaxes(matches=None) +fig.show() +``` + +### Exploring 3-D images and timeseries with `animation_frame` + +*Introduced in plotly 4.11* + +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +img = data[25:40] +fig = px.imshow(img, animation_frame=0, binary_string=True) +fig.show() +``` + #### Reference See https://plotly.com/python/reference/#image for more information and chart attribute options! diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 812190c7bd..bddb543b5d 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -288,14 +288,17 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() + col_labels = [] if facet_col is not None: nslices = img.shape[facet_col] - ncols = int(facet_col_wrap) + ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols + col_labels = ["plane = %d" % i for i in range(nslices)] else: nrows = 1 ncols = 1 - fig = init_figure(args, "xy", [], nrows, ncols, [], []) + slice_through = (facet_col is not None) or (animation_frame is not None) + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): if binary_string: @@ -353,16 +356,13 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) - slice_through = False if facet_col is not None: img = np.moveaxis(img, facet_col, 0) facet_col = True - slice_through = True if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) animation_frame = True args["animation_frame"] = "plane" - slice_through = True # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: From 7a3a9f46b66a5a111dcfa101cd3ea83e4b6db6c9 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 13:44:48 +0200 Subject: [PATCH 06/28] solved old unnoticed conflict --- doc/python/imshow.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index e2f25a8e6a..8e357a7071 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.2' - jupytext_version: 1.4.2 + jupytext_version: 1.3.0 kernelspec: display_name: Python 3 language: python @@ -20,7 +20,7 @@ jupyter: name: python nbconvert_exporter: python pygments_lexer: ipython3 - version: 3.7.7 + version: 3.7.3 plotly: description: How to display image data in Python with Plotly. display_as: scientific @@ -452,8 +452,5 @@ fig.show() ``` #### Reference -<<<<<<< HEAD -See https://plotly.com/python/reference/#image for more information and chart attribute options! -======= + See https://plotly.com/python/reference/image/ for more information and chart attribute options! ->>>>>>> doc-prod From b689a2fd7d6edb4845cf52171dcf6126d0291043 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 14:54:33 +0200 Subject: [PATCH 07/28] attempt to use imshow with binary strings and xarrays --- .../python/plotly/plotly/express/_imshow.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index bddb543b5d..7ca491dc16 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -301,12 +301,12 @@ def imshow( fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): - if binary_string: - raise ValueError( - "It is not possible to use binary image strings for xarrays." - "Please pass your data as a numpy array instead using" - "`img.values`" - ) + # if binary_string: + # raise ValueError( + # "It is not possible to use binary image strings for xarrays." + # "Please pass your data as a numpy array instead using" + # "`img.values`" + # ) y_label, x_label = img.dims[0], img.dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: @@ -506,7 +506,10 @@ def imshow( else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " - "An image of shape %s was provided" % str(img.shape) + "An image of shape %s was provided" + "Alternatively, 3-D single or multichannel datasets can be" + "visualized using the `facet_col` or `animation_frame` arguments." + % str(img.shape) ) layout_patch = dict() From fbb3f6534d52ded9efc9a2ddbc12288fc7f60c1c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 14:56:30 +0200 Subject: [PATCH 08/28] added test --- .../plotly/plotly/tests/test_core/test_px/test_imshow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 7e2bcca413..61abd4f3a2 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -171,14 +171,16 @@ def test_zmin_zmax_range_color_source(): assert fig1 == fig2 -def test_imshow_xarray(): +@pytest.mark.parametrize("binary_string", [False, True]) +def test_imshow_xarray(binary_string): img = np.random.random((20, 30)) da = xr.DataArray(img, dims=["dim_rows", "dim_cols"]) - fig = px.imshow(da) + fig = px.imshow(da, binary_string=binary_string) # Dimensions are used for axis labels and coordinates assert fig.layout.xaxis.title.text == "dim_cols" assert fig.layout.yaxis.title.text == "dim_rows" - assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"])) + if not binary_string: + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"])) def test_imshow_labels_and_ranges(): From 882810f59520de14db30d347b806d8ed7d2af191 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 16:09:41 +0200 Subject: [PATCH 09/28] animation work for xarrays, still need to fix slider label --- doc/python/imshow.md | 13 +++++++++ .../python/plotly/plotly/express/_imshow.py | 27 ++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index 8e357a7071..c8a755f42d 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -451,6 +451,19 @@ fig = px.imshow(img, animation_frame=0, binary_string=True) fig.show() ``` +### Animations of xarray datasets + +*Introduced in plotly 4.11* + +```python +import plotly.express as px +import xarray as xr +# Load xarray from dataset included in the xarray tutorial +ds = xr.tutorial.open_dataset('air_temperature').air[:20] +fig = px.imshow(ds, animation_frame='lat', color_continuous_scale='RdBu_r') +fig.show() +``` + #### Reference See https://plotly.com/python/reference/image/ for more information and chart attribute options! diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 7ca491dc16..afc635cc10 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -290,6 +290,8 @@ def imshow( labels = labels.copy() col_labels = [] if facet_col is not None: + if isinstance(facet_col, str): + facet_col = img.dims.index(facet_col) nslices = img.shape[facet_col] ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols @@ -297,7 +299,11 @@ def imshow( else: nrows = 1 ncols = 1 + if animation_frame is not None: + if isinstance(animation_frame, str): + animation_frame = img.dims.index(animation_frame) slice_through = (facet_col is not None) or (animation_frame is not None) + plane_label = None fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): @@ -307,7 +313,14 @@ def imshow( # "Please pass your data as a numpy array instead using" # "`img.values`" # ) - y_label, x_label = img.dims[0], img.dims[1] + dims = list(img.dims) + print(dims) + if slice_through: + slice_index = facet_col if facet_col is not None else animation_frame + _ = dims.pop(slice_index) + plane_label = img.dims[slice_index] + y_label, x_label = dims[0], dims[1] + print(y_label, x_label) # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): @@ -322,6 +335,8 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label + if labels.get("plane", None) is None: + labels["plane"] = plane_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -362,7 +377,9 @@ def imshow( if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) animation_frame = True - args["animation_frame"] = "plane" + args["animation_frame"] = ( + "plane" if labels.get("plane") is None else labels["plane"] + ) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: @@ -403,12 +420,14 @@ def imshow( # For 2d data, use Heatmap trace, unless binary_string is True if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string: - if y is not None and img.shape[0] != len(y): + y_index = 1 if slice_through else 0 + if y is not None and img.shape[y_index] != len(y): raise ValueError( "The length of the y vector must match the length of the first " + "dimension of the img matrix." ) - if x is not None and img.shape[1] != len(x): + x_index = 2 if slice_through else 1 + if x is not None and img.shape[x_index] != len(x): raise ValueError( "The length of the x vector must match the length of the second " + "dimension of the img matrix." From ba65990c87073605db10fb89b64dfd89752fbccf Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 16:26:14 +0200 Subject: [PATCH 10/28] added test with xarray and animations --- .../plotly/tests/test_core/test_px/test_imshow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 61abd4f3a2..03a5c0f254 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -183,6 +183,16 @@ def test_imshow_xarray(binary_string): assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_cols"])) +def test_imshow_xarray_slicethrough(): + img = np.random.random((8, 9, 10)) + da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2"]) + fig = px.imshow(da, animation_frame="dim_0") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_2" + assert fig.layout.yaxis.title.text == "dim_1" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) + + def test_imshow_labels_and_ranges(): fig = px.imshow([[1, 2], [3, 4], [5, 6]],) assert fig.layout.xaxis.title.text is None From cf644e55f6824e29faf821b38c09dabe81ce6d23 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 17:12:31 +0200 Subject: [PATCH 11/28] added doc --- doc/python/imshow.md | 6 ++-- .../python/plotly/plotly/express/_imshow.py | 31 ++++++++++++------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index c8a755f42d..125066fcc5 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -403,7 +403,7 @@ fig.show() *Introduced in plotly 4.11* -For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to give an axis name as a string for `facet_col`. It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time. @@ -455,12 +455,14 @@ fig.show() *Introduced in plotly 4.11* +For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure. + ```python import plotly.express as px import xarray as xr # Load xarray from dataset included in the xarray tutorial ds = xr.tutorial.open_dataset('air_temperature').air[:20] -fig = px.imshow(ds, animation_frame='lat', color_continuous_scale='RdBu_r') +fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r') fig.show() ``` diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index afc635cc10..d63440a73d 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -288,23 +288,23 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() - col_labels = [] + nslices = 1 if facet_col is not None: if isinstance(facet_col, str): facet_col = img.dims.index(facet_col) nslices = img.shape[facet_col] ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols - col_labels = ["plane = %d" % i for i in range(nslices)] else: nrows = 1 ncols = 1 if animation_frame is not None: if isinstance(animation_frame, str): animation_frame = img.dims.index(animation_frame) + nslices = img.shape[animation_frame] slice_through = (facet_col is not None) or (animation_frame is not None) - plane_label = None - fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) + slice_label = None + slices = range(nslices) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): # if binary_string: @@ -314,13 +314,12 @@ def imshow( # "`img.values`" # ) dims = list(img.dims) - print(dims) if slice_through: slice_index = facet_col if facet_col is not None else animation_frame + slices = img.coords[img.dims[slice_index]].values _ = dims.pop(slice_index) - plane_label = img.dims[slice_index] + slice_label = img.dims[slice_index] y_label, x_label = dims[0], dims[1] - print(y_label, x_label) # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): @@ -335,8 +334,8 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label - if labels.get("plane", None) is None: - labels["plane"] = plane_label + if labels.get("slice", None) is None: + labels["slice"] = slice_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -378,7 +377,7 @@ def imshow( img = np.moveaxis(img, animation_frame, 0) animation_frame = True args["animation_frame"] = ( - "plane" if labels.get("plane") is None else labels["plane"] + "slice" if labels.get("slice") is None else labels["slice"] ) # Default behaviour of binary_string: True for RGB images, False for 2D @@ -531,6 +530,14 @@ def imshow( % str(img.shape) ) + # Now build figure + col_labels = [] + if facet_col is not None: + slice_label = "slice" if labels.get("slice") is None else labels["slice"] + if slices is None: + slices = range(nslices) + col_labels = ["%s = %d" % (slice_label, i) for i in slices] + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) layout_patch = dict() for attr_name in ["height", "width"]: if args[attr_name]: @@ -541,11 +548,11 @@ def imshow( layout_patch["margin"] = {"t": 60} frame_list = [] - for index, trace in enumerate(traces): + for index, (slice_index, trace) in enumerate(zip(slices, traces)): if facet_col or index == 0: fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) if animation_frame: - frame_list.append(dict(data=trace, layout=layout, name=str(index))) + frame_list.append(dict(data=trace, layout=layout, name=str(slice_index))) if animation_frame: fig.frames = frame_list fig.update_layout(layout) From 72674b799b7dd0057e8aaeeb859622021e398f4d Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 7 Sep 2020 21:30:05 +0200 Subject: [PATCH 12/28] added pooch to doc requirements --- doc/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/requirements.txt b/doc/requirements.txt index 63887be80a..471081539b 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -28,3 +28,4 @@ pyarrow cufflinks==0.17.3 kaleido umap-learn +pooch From bd42385869ec89514de789fd9997dd89fbed4108 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 8 Sep 2020 23:24:24 +0200 Subject: [PATCH 13/28] Update packages/python/plotly/plotly/express/_imshow.py Co-authored-by: Marianne Corvellec --- packages/python/plotly/plotly/express/_imshow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index d63440a73d..8d8c2efc50 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -524,7 +524,7 @@ def imshow( else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " - "An image of shape %s was provided" + "An image of shape %s was provided." "Alternatively, 3-D single or multichannel datasets can be" "visualized using the `facet_col` or `animation_frame` arguments." % str(img.shape) From fc2375b953c6074cd34d869f53b8118e8087f853 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 8 Sep 2020 23:24:34 +0200 Subject: [PATCH 14/28] Update doc/python/imshow.md Co-authored-by: Marianne Corvellec --- doc/python/imshow.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index 125066fcc5..1ed59e8251 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -403,7 +403,7 @@ fig.show() *Introduced in plotly 4.11* -For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to give an axis name as a string for `facet_col`. +For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axis the image is sliced through to make the facets. With `facet_col_wrap`, one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to specify the axis by its name (label), thus passing a string to `facet_col`. It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time. From a431fad2926483beacb96709d527de62ddcc2403 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 9 Sep 2020 22:48:12 +0200 Subject: [PATCH 15/28] remove commented-out code --- packages/python/plotly/plotly/express/_imshow.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index d63440a73d..1c5c619624 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -189,8 +189,13 @@ def imshow( their lengths must match the lengths of the second and first dimensions of the img argument. They are auto-populated if the input is an xarray. - facet_col: int, optional (default None) - axis number along which the image array is slices to create a facetted plot. + animation_frame: int or str, optional (default None) + axis number along which the image array is sliced to create an animation plot. + If `img` is an xarray, `animation_frame` can be the name of one the dimensions. + + facet_col: int or str, optional (default None) + axis number along which the image array is sliced to create a facetted plot. + If `img` is an xarray, `facet_col` can be the name of one the dimensions. facet_col_wrap: int Maximum number of facet columns. Wraps the column variable at this width, @@ -307,12 +312,6 @@ def imshow( slices = range(nslices) # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): - # if binary_string: - # raise ValueError( - # "It is not possible to use binary image strings for xarrays." - # "Please pass your data as a numpy array instead using" - # "`img.values`" - # ) dims = list(img.dims) if slice_through: slice_index = facet_col if facet_col is not None else animation_frame From b65203987326d03df5cfbcdca3d4c29108b6be98 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 17 Sep 2020 11:46:40 +0200 Subject: [PATCH 16/28] animation + facet kinda working now, but it broke labels --- .../python/plotly/plotly/express/_imshow.py | 102 +++++++++++++----- 1 file changed, 73 insertions(+), 29 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 1c5c619624..967e11f2c9 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -7,6 +7,7 @@ import pandas as pd from .png import Writer, from_array import numpy as np +import itertools try: import xarray @@ -293,31 +294,41 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() - nslices = 1 + nslices_facet = 1 if facet_col is not None: if isinstance(facet_col, str): facet_col = img.dims.index(facet_col) - nslices = img.shape[facet_col] - ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices - nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols + nslices_facet = img.shape[facet_col] + facet_slices = range(nslices_facet) + ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet + nrows = ( + nslices_facet // ncols + 1 + if nslices_facet % ncols + else nslices_facet // ncols + ) else: nrows = 1 ncols = 1 if animation_frame is not None: if isinstance(animation_frame, str): animation_frame = img.dims.index(animation_frame) - nslices = img.shape[animation_frame] + nslices_animation = img.shape[animation_frame] + animation_slices = range(nslices_animation) slice_through = (facet_col is not None) or (animation_frame is not None) - slice_label = None - slices = range(nslices) + double_slice_through = (facet_col is not None) and (animation_frame is not None) + facet_label = None + animation_label = None # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): dims = list(img.dims) - if slice_through: - slice_index = facet_col if facet_col is not None else animation_frame - slices = img.coords[img.dims[slice_index]].values - _ = dims.pop(slice_index) - slice_label = img.dims[slice_index] + if facet_col is not None: + facet_slices = img.coords[img.dims[facet_col]].values + _ = dims.pop(facet_col) + facet_label = img.dims[facet_col] + if animation_frame is not None: + animation_slices = img.coords[img.dims[animation_frame]].values + _ = dims.pop(animation_frame) + animation_label = img.dims[animation_frame] y_label, x_label = dims[0], dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: @@ -333,8 +344,10 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label - if labels.get("slice", None) is None: - labels["slice"] = slice_label + if labels.get("animation_slice", None) is None: + labels["animation_slice"] = animation_label + if labels.get("facet_slice", None) is None: + labels["facet_slice"] = facet_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -371,11 +384,15 @@ def imshow( img = np.asanyarray(img) if facet_col is not None: img = np.moveaxis(img, facet_col, 0) + print(img.shape) + if animation_frame is not None and animation_frame < facet_col: + animation_frame += 1 facet_col = True if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) + print(img.shape) animation_frame = True - args["animation_frame"] = ( + args["animation_frame"] = ( # TODO "slice" if labels.get("slice") is None else labels["slice"] ) @@ -431,9 +448,16 @@ def imshow( + "dimension of the img matrix." ) if slice_through: + iterables = () + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) traces = [ - go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i)) - for i, img_slice in enumerate(img) + go.Heatmap( + x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i) + ) + for i, index_tup in enumerate(itertools.product(*iterables)) ] else: traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")] @@ -464,11 +488,21 @@ def imshow( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) + if slice_through: + iterables = () + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False - elif img.ndim == 2 or (img.ndim == 3 and slice_through): + elif ( + img.ndim == 2 + or (img.ndim == 3 and slice_through) + or (img.ndim == 4 and double_slice_through) + ): img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) @@ -485,14 +519,15 @@ def imshow( axis=-1, ) if slice_through: + tuples = [index_tup for index_tup in itertools.product(*iterables)] img_str = [ _array_to_b64str( - img_rescaled_slice, + img_rescaled[index_tup], backend=binary_backend, compression=binary_compression_level, ext=binary_format, ) - for img_rescaled_slice in img_rescaled + for index_tup in itertools.product(*iterables) ] else: @@ -512,8 +547,10 @@ def imshow( colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" if slice_through: traces = [ - go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel) - for img_slice in img + go.Image( + z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel + ) + for index_tup in itertools.product(*iterables) ] else: traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] @@ -533,9 +570,9 @@ def imshow( col_labels = [] if facet_col is not None: slice_label = "slice" if labels.get("slice") is None else labels["slice"] - if slices is None: - slices = range(nslices) - col_labels = ["%s = %d" % (slice_label, i) for i in slices] + if facet_slices is None: + facet_slices = range(nslices_facet) + col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) layout_patch = dict() for attr_name in ["height", "width"]: @@ -547,11 +584,18 @@ def imshow( layout_patch["margin"] = {"t": 60} frame_list = [] - for index, (slice_index, trace) in enumerate(zip(slices, traces)): - if facet_col or index == 0: + for index, trace in enumerate(traces): + if (facet_col and index < nrows * ncols) or index == 0: fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) - if animation_frame: - frame_list.append(dict(data=trace, layout=layout, name=str(slice_index))) + if animation_frame is not None: + for i in range(nslices_animation): + frame_list.append( + dict( + data=traces[nslices_facet * i : nslices_facet * (i + 1)], + layout=layout, + name=str(i), + ) + ) if animation_frame: fig.frames = frame_list fig.update_layout(layout) From 59c6622815fa4d4a0dccc2def71c4b3da6c497cd Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 17 Sep 2020 11:52:17 +0200 Subject: [PATCH 17/28] added test --- .../plotly/plotly/tests/test_core/test_px/test_imshow.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 03a5c0f254..4f87f88aae 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -362,6 +362,13 @@ def test_animation_frame_grayscale(animation_frame, binary_string): def test_animation_frame_rgb(animation_frame, binary_string): img = np.random.randint(255, size=(10, 9, 8, 3)).astype(np.uint8) fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,) - print(binary_string) nslices = img.shape[animation_frame] assert len(fig.frames) == nslices + + +def test_animation_and_facet(): + img = np.random.randint(255, size=(10, 9, 8, 7)).astype(np.uint8) + fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=True) + nslices = img.shape[0] + assert len(fig.frames) == nslices + assert len(fig.data) == img.shape[1] From c7285a346feeb0a3ad3153cc1b9a1d9980f4f7f5 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 17 Sep 2020 12:00:18 +0200 Subject: [PATCH 18/28] simplified code --- .../python/plotly/plotly/express/_imshow.py | 64 +++++++------------ .../tests/test_core/test_px/test_imshow.py | 5 +- 2 files changed, 27 insertions(+), 42 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 967e11f2c9..a5388c3450 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -434,7 +434,11 @@ def imshow( zmin = 0 # For 2d data, use Heatmap trace, unless binary_string is True - if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string: + if ( + img.ndim == 2 + or (img.ndim == 3 and slice_through) + or (img.ndim == 4 and double_slice_through) + ) and not binary_string: y_index = 1 if slice_through else 0 if y is not None and img.shape[y_index] != len(y): raise ValueError( @@ -447,20 +451,16 @@ def imshow( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) + iterables = () if slice_through: - iterables = () if animation_frame is not None: iterables += (range(nslices_animation),) if facet_col is not None: iterables += (range(nslices_facet),) - traces = [ - go.Heatmap( - x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i) - ) - for i, index_tup in enumerate(itertools.product(*iterables)) - ] - else: - traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")] + traces = [ + go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)) + for i, index_tup in enumerate(itertools.product(*iterables)) + ] autorange = True if origin == "lower" else "reversed" layout = dict(yaxis=dict(autorange=autorange)) if aspect == "equal": @@ -488,8 +488,8 @@ def imshow( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) + iterables = () if slice_through: - iterables = () if animation_frame is not None: iterables += (range(nslices_animation),) if facet_col is not None: @@ -518,42 +518,26 @@ def imshow( ], axis=-1, ) - if slice_through: - tuples = [index_tup for index_tup in itertools.product(*iterables)] - img_str = [ - _array_to_b64str( - img_rescaled[index_tup], - backend=binary_backend, - compression=binary_compression_level, - ext=binary_format, - ) - for index_tup in itertools.product(*iterables) - ] + img_str = [ + _array_to_b64str( + img_rescaled[index_tup], + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + for index_tup in itertools.product(*iterables) + ] - else: - img_str = [ - _array_to_b64str( - img_rescaled, - backend=binary_backend, - compression=binary_compression_level, - ext=binary_format, - ) - ] traces = [ go.Image(source=img_str_slice, name=str(i)) for i, img_str_slice in enumerate(img_str) ] else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - if slice_through: - traces = [ - go.Image( - z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel - ) - for index_tup in itertools.product(*iterables) - ] - else: - traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] + traces = [ + go.Image(z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel) + for index_tup in itertools.product(*iterables) + ] layout = {} if origin == "lower": layout["yaxis"] = dict(autorange=True) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 4f87f88aae..ac0c1a96ea 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -366,9 +366,10 @@ def test_animation_frame_rgb(animation_frame, binary_string): assert len(fig.frames) == nslices -def test_animation_and_facet(): +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_and_facet(binary_string): img = np.random.randint(255, size=(10, 9, 8, 7)).astype(np.uint8) - fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=True) + fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=binary_string) nslices = img.shape[0] assert len(fig.frames) == nslices assert len(fig.data) == img.shape[1] From 91c066e7df97717b5b2097108829b5a237cf024e Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 17 Sep 2020 12:03:59 +0200 Subject: [PATCH 19/28] simplified code --- .../python/plotly/plotly/express/_imshow.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index a5388c3450..f902f96925 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -395,6 +395,12 @@ def imshow( args["animation_frame"] = ( # TODO "slice" if labels.get("slice") is None else labels["slice"] ) + iterables = () + if slice_through: + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: @@ -451,12 +457,6 @@ def imshow( "The length of the x vector must match the length of the second " + "dimension of the img matrix." ) - iterables = () - if slice_through: - if animation_frame is not None: - iterables += (range(nslices_animation),) - if facet_col is not None: - iterables += (range(nslices_facet),) traces = [ go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)) for i, index_tup in enumerate(itertools.product(*iterables)) @@ -488,12 +488,6 @@ def imshow( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) - iterables = () - if slice_through: - if animation_frame is not None: - iterables += (range(nslices_animation),) - if facet_col is not None: - iterables += (range(nslices_facet),) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img From ac5aa1fa82e31d9d8b09f6985913b1ec237c900c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 17 Sep 2020 13:50:08 +0200 Subject: [PATCH 20/28] polished code and added doc example --- doc/python/imshow.md | 17 ++++- .../python/plotly/plotly/express/_imshow.py | 71 +++++++------------ 2 files changed, 43 insertions(+), 45 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index 125066fcc5..1e1d3b1610 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -415,7 +415,7 @@ from skimage import io from skimage.data import image_fetcher path = image_fetcher.fetch('data/cells.tif') data = io.imread(path) -img = data[25:40] +img = data[20:45:2] fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700) fig.show() ``` @@ -466,6 +466,21 @@ fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous fig.show() ``` +### Combining animations and facets + +It is possible to view 4-dimensional datasets (for example, 3-D images evolving with time) using a combination of `animation_frame` and `facet_col`. + +```python +import plotly.express as px +from skimage import io +from skimage.data import image_fetcher +path = image_fetcher.fetch('data/cells.tif') +data = io.imread(path) +data = data.reshape((15, 4, 256, 256))[5:] +fig = px.imshow(data, animation_frame=0, facet_col=1, binary_string=True) +fig.show() +``` + #### Reference See https://plotly.com/python/reference/image/ for more information and chart attribute options! diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index f902f96925..e979097a48 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -314,8 +314,9 @@ def imshow( animation_frame = img.dims.index(animation_frame) nslices_animation = img.shape[animation_frame] animation_slices = range(nslices_animation) - slice_through = (facet_col is not None) or (animation_frame is not None) - double_slice_through = (facet_col is not None) and (animation_frame is not None) + slice_dimensions = (facet_col is not None) + ( + animation_frame is not None + ) # 0, 1, or 2 facet_label = None animation_label = None # ----- Define x and y, set labels if img is an xarray ------------------- @@ -344,10 +345,10 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label - if labels.get("animation_slice", None) is None: - labels["animation_slice"] = animation_label - if labels.get("facet_slice", None) is None: - labels["facet_slice"] = facet_label + if labels.get("animation", None) is None: + labels["animation"] = animation_label + if labels.get("facet", None) is None: + labels["facet"] = facet_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -382,32 +383,27 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) + # Reshape array so that animation dimension comes first, then facets, then images if facet_col is not None: img = np.moveaxis(img, facet_col, 0) - print(img.shape) if animation_frame is not None and animation_frame < facet_col: animation_frame += 1 facet_col = True if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) - print(img.shape) animation_frame = True - args["animation_frame"] = ( # TODO - "slice" if labels.get("slice") is None else labels["slice"] + args["animation_frame"] = ( + "slice" if labels.get("animation") is None else labels["animation"] ) iterables = () - if slice_through: - if animation_frame is not None: - iterables += (range(nslices_animation),) - if facet_col is not None: - iterables += (range(nslices_facet),) + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: - if slice_through: - binary_string = img.ndim >= 4 and not is_dataframe - else: - binary_string = img.ndim >= 3 and not is_dataframe + binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe # Cast bools to uint8 (also one byte) if img.dtype == np.bool: @@ -419,11 +415,7 @@ def imshow( # -------- Contrast rescaling: either minmax or infer ------------------ if contrast_rescaling is None: - contrast_rescaling = ( - "minmax" - if (img.ndim == 2 or (img.ndim == 3 and slice_through)) - else "infer" - ) + contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer" # We try to set zmin and zmax only if necessary, because traces have good defaults if contrast_rescaling == "minmax": @@ -439,19 +431,15 @@ def imshow( if zmin is None and zmax is not None: zmin = 0 - # For 2d data, use Heatmap trace, unless binary_string is True - if ( - img.ndim == 2 - or (img.ndim == 3 and slice_through) - or (img.ndim == 4 and double_slice_through) - ) and not binary_string: - y_index = 1 if slice_through else 0 + # For 2d data, use Heatmap trace, unless binary_string is True + if img.ndim == 2 + slice_dimensions and not binary_string: + y_index = slice_dimensions if y is not None and img.shape[y_index] != len(y): raise ValueError( "The length of the y vector must match the length of the first " + "dimension of the img matrix." ) - x_index = 2 if slice_through else 1 + x_index = slice_dimensions + 1 if x is not None and img.shape[x_index] != len(x): raise ValueError( "The length of the x vector must match the length of the second " @@ -480,7 +468,8 @@ def imshow( # For 2D+RGB data, use Image trace elif ( - img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string) + img.ndim >= 3 + and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string) ) or (img.ndim == 2 and binary_string): rescale_image = True # to check whether image has been modified if zmin is not None and zmax is not None: @@ -492,11 +481,7 @@ def imshow( if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img rescale_image = False - elif ( - img.ndim == 2 - or (img.ndim == 3 and slice_through) - or (img.ndim == 4 and double_slice_through) - ): + elif img.ndim == 2 + slice_dimensions: # single-channel image img_rescaled = rescale_intensity( img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 ) @@ -547,9 +532,7 @@ def imshow( # Now build figure col_labels = [] if facet_col is not None: - slice_label = "slice" if labels.get("slice") is None else labels["slice"] - if facet_slices is None: - facet_slices = range(nslices_facet) + slice_label = "slice" if labels.get("facet") is None else labels["facet"] col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) layout_patch = dict() @@ -566,12 +549,12 @@ def imshow( if (facet_col and index < nrows * ncols) or index == 0: fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) if animation_frame is not None: - for i in range(nslices_animation): + for i, index in zip(range(nslices_animation), animation_slices): frame_list.append( dict( data=traces[nslices_facet * i : nslices_facet * (i + 1)], layout=layout, - name=str(i), + name=str(index), ) ) if animation_frame: @@ -607,5 +590,5 @@ def imshow( if labels["y"]: fig.update_yaxes(title_text=labels["y"]) configure_animation_controls(args, go.Image, fig) - # fig.update_layout(template=args["template"], overwrite=True) + fig.update_layout(template=args["template"], overwrite=True) return fig From 8cdc6afd486d4492a06310b12d3c1865d4249c3e Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Nov 2020 14:45:29 +0100 Subject: [PATCH 21/28] updated doc --- doc/python/imshow.md | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index a75edd3e6b..81c350e706 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.2' - jupytext_version: 1.3.0 + jupytext_version: 1.7.1 kernelspec: display_name: Python 3 language: python @@ -401,7 +401,7 @@ fig.show() ### Exploring 3-D images and timeseries with `facet_col` -*Introduced in plotly 4.11* +*Introduced in plotly 4.13* For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axis the image is sliced through to make the facets. With `facet_col_wrap`, one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to specify the axis by its name (label), thus passing a string to `facet_col`. @@ -416,27 +416,13 @@ from skimage.data import image_fetcher path = image_fetcher.fetch('data/cells.tif') data = io.imread(path) img = data[20:45:2] -fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700) -fig.show() -``` - -```python -import plotly.express as px -from skimage import io -from skimage.data import image_fetcher -path = image_fetcher.fetch('data/cells.tif') -data = io.imread(path) -img = data[25:40] fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5) -# To have square facets one needs to unmatch axes -fig.update_xaxes(matches=None) -fig.update_yaxes(matches=None) fig.show() ``` ### Exploring 3-D images and timeseries with `animation_frame` -*Introduced in plotly 4.11* +*Introduced in plotly 4.13* For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation. From 5d1d8d8024a7ec9d8c55d26b90dd7fb76f9cad75 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Nov 2020 17:30:33 +0100 Subject: [PATCH 22/28] add facet_col_spacing and facet_row_spacing --- packages/python/plotly/plotly/express/_imshow.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index efe264bd7c..8d66a27e3a 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -64,6 +64,8 @@ def imshow( animation_frame=None, facet_col=None, facet_col_wrap=None, + facet_col_spacing=None, + facet_row_spacing=None, color_continuous_scale=None, color_continuous_midpoint=None, range_color=None, @@ -130,6 +132,13 @@ def imshow( so that the column facets span multiple rows. Ignored if `facet_col` is None. + facet_col_spacing: float between 0 and 1 + Spacing between facet columns, in paper units. Default is 0.02. + + facet_row_spacing: float between 0 and 1 + Spacing between facet rows created when ``facet_col_wrap`` is used, in + paper units. Default is 0.0.7. + color_continuous_scale : str or list of str colormap used to map scalar data to colors (for a 2D image). This parameter is not used for RGB or RGBA images. If a string is provided, it should be the name From c27f88a9d518175ff6a6278822ab4d02a7d76a34 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Nov 2020 17:48:41 +0100 Subject: [PATCH 23/28] modify error message + animation_frame label --- packages/python/plotly/plotly/express/_imshow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 8d66a27e3a..552599e577 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -331,7 +331,9 @@ def imshow( img = np.moveaxis(img, animation_frame, 0) animation_frame = True args["animation_frame"] = ( - "slice" if labels.get("animation") is None else labels["animation"] + "animation_frame" + if labels.get("animation_frame") is None + else labels["animation_frame"] ) iterables = () if animation_frame is not None: @@ -509,8 +511,8 @@ def imshow( raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " "An image of shape %s was provided." - "Alternatively, 3-D single or multichannel datasets can be" - "visualized using the `facet_col` or `animation_frame` arguments." + "Alternatively, 3- or 4-D single or multichannel datasets can be" + "visualized using the `facet_col` or/and `animation_frame` arguments." % str(img.shape) ) From 502fdfd5178788ecd6f37872ced5da4e1db132d0 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Nov 2020 17:57:40 +0100 Subject: [PATCH 24/28] improve code readibility --- packages/python/plotly/plotly/express/_imshow.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 552599e577..deafa8dcde 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -522,14 +522,13 @@ def imshow( slice_label = "slice" if labels.get("facet") is None else labels["facet"] col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) - layout_patch = dict() for attr_name in ["height", "width"]: if args[attr_name]: - layout_patch[attr_name] = args[attr_name] + layout[attr_name] = args[attr_name] if args["title"]: - layout_patch["title_text"] = args["title"] + layout["title_text"] = args["title"] elif args["template"].layout.margin.t is None: - layout_patch["margin"] = {"t": 60} + layout["margin"] = {"t": 60} frame_list = [] for index, trace in enumerate(traces): @@ -547,7 +546,6 @@ def imshow( if animation_frame: fig.frames = frame_list fig.update_layout(layout) - fig.update_layout(layout_patch) # Hover name, z or color if binary_string and rescale_image and not np.all(img == img_rescaled): # we rescaled the image, hence z is not displayed in hover since it does From 135b01bd174a76cc9f634cbc9ac30d85b9e3ba82 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Nov 2020 21:50:47 +0100 Subject: [PATCH 25/28] added example with sequence of images --- doc/python/imshow.md | 31 ++++++++++++++++--- .../python/plotly/plotly/express/_imshow.py | 4 ++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index 01477f15a8..dbd19b1336 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.2' - jupytext_version: 1.7.1 + jupytext_version: 1.3.0 kernelspec: display_name: Python 3 language: python @@ -399,9 +399,9 @@ for compression_level in range(0, 9): fig.show() ``` -### Exploring 3-D images and timeseries with `facet_col` +### Exploring 3-D images, timeseries and sequences of images with `facet_col` -*Introduced in plotly 4.13* +*Introduced in plotly 4.14* For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axis the image is sliced through to make the facets. With `facet_col_wrap`, one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to specify the axis by its name (label), thus passing a string to `facet_col`. @@ -420,9 +420,30 @@ fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5) fig.show() ``` +Facets can also be used to represent several images of equal shape, like in the example below where different values of the blurring parameter of a Gaussian filter are compared. + +```python +import plotly.express as px +import numpy as np +from skimage import data, filters, img_as_float +img = data.camera() +sigmas = [1, 2, 4] +img_sequence = [filters.gaussian(img, sigma=sigma) for sigma in sigmas] +fig = px.imshow(np.array(img_sequence), facet_col=0, binary_string=True, + labels={'facet_col':'sigma'}) +# Set facet titles +for sigma in sigmas: + fig.layout.annotations[i]['text'] = 'sigma = %d' %sigma +fig.show() +``` + +```python +print(fig) +``` + ### Exploring 3-D images and timeseries with `animation_frame` -*Introduced in plotly 4.13* +*Introduced in plotly 4.14* For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation. @@ -439,7 +460,7 @@ fig.show() ### Animations of xarray datasets -*Introduced in plotly 4.11* +*Introduced in plotly 4.14* For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure. diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index deafa8dcde..f81e3523a2 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -519,7 +519,9 @@ def imshow( # Now build figure col_labels = [] if facet_col is not None: - slice_label = "slice" if labels.get("facet") is None else labels["facet"] + slice_label = ( + "slice" if labels.get("facet_col") is None else labels["facet_col"] + ) col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) for attr_name in ["height", "width"]: From 6ac3e36e5b1812f4614d12f3f470445e47843d8e Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Nov 2020 22:14:09 +0100 Subject: [PATCH 26/28] typoe --- doc/python/imshow.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/python/imshow.md b/doc/python/imshow.md index dbd19b1336..1b92bfd2f6 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -432,7 +432,7 @@ img_sequence = [filters.gaussian(img, sigma=sigma) for sigma in sigmas] fig = px.imshow(np.array(img_sequence), facet_col=0, binary_string=True, labels={'facet_col':'sigma'}) # Set facet titles -for sigma in sigmas: +for i, sigma in enumerate(sigmas): fig.layout.annotations[i]['text'] = 'sigma = %d' %sigma fig.show() ``` From a5a225267bf65ff3c45a25a029e0a221075fe8a4 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 27 Nov 2020 17:10:23 +0100 Subject: [PATCH 27/28] label names --- packages/python/plotly/plotly/express/_imshow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index f81e3523a2..e7b6f72f9e 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -283,10 +283,10 @@ def imshow( labels["x"] = x_label if labels.get("y", None) is None: labels["y"] = y_label - if labels.get("animation", None) is None: - labels["animation"] = animation_label - if labels.get("facet", None) is None: - labels["facet"] = facet_label + if labels.get("animation_frame", None) is None: + labels["animation_frame"] = animation_label + if labels.get("facet_col", None) is None: + labels["facet_col"] = facet_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") From 77cb5cdf1ce295dea49367a70a6144f308642b19 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 30 Nov 2020 16:47:52 +0100 Subject: [PATCH 28/28] label name --- packages/python/plotly/plotly/express/_imshow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index e7b6f72f9e..27d1bc7349 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -520,7 +520,7 @@ def imshow( col_labels = [] if facet_col is not None: slice_label = ( - "slice" if labels.get("facet_col") is None else labels["facet_col"] + "facet_col" if labels.get("facet_col") is None else labels["facet_col"] ) col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])