/
plotly_chart.py
224 lines (178 loc) · 7.38 KB
/
plotly_chart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit support for Plotly charts."""
import json
import urllib.parse
from typing import Any, cast, Dict, List, Set, TYPE_CHECKING, Union
from typing_extensions import Final, Literal, TypeAlias
from streamlit.legacy_caching import caching
from streamlit import type_util
from streamlit.logger import get_logger
from streamlit.proto.PlotlyChart_pb2 import PlotlyChart as PlotlyChartProto
if TYPE_CHECKING:
import matplotlib
import plotly.graph_objs as go
from plotly.basedatatypes import BaseFigure
from streamlit.delta_generator import DeltaGenerator
LOGGER: Final = get_logger(__name__)

# Every value accepted by the ``sharing`` argument of ``st.plotly_chart``.
SharingMode: TypeAlias = Literal["streamlit", "private", "public", "secret"]

# Runtime counterpart of ``SharingMode``, used for validation.
#   - "streamlit": the chart is sent to the Streamlit app, not to Plotly.
#   - "private", "public", "secret": the chart is hosted on Plotly; these are
#     the names Plotly itself uses for those modes.
SHARING_MODES: Set[SharingMode] = {"streamlit", "private", "public", "secret"}

# A single Plotly figure or data object. The names are quoted because plotly
# is only imported under ``TYPE_CHECKING``.
_AtomicFigureOrData: TypeAlias = Union["go.Figure", "go.Data"]

# Everything ``st.plotly_chart`` accepts as ``figure_or_data``.
# Plotly has not embraced typing yet, so it is hard to pin down exactly which
# dicts are supported; the dict form below was chosen to match the docstring.
FigureOrData: TypeAlias = Union[
    _AtomicFigureOrData,
    List[_AtomicFigureOrData],
    Dict[str, _AtomicFigureOrData],
    "BaseFigure",
    "matplotlib.figure.Figure",
]
class PlotlyMixin:
    def plotly_chart(
        self,
        figure_or_data: FigureOrData,
        use_container_width: bool = False,
        sharing: SharingMode = "streamlit",
        **kwargs: Any,
    ) -> "DeltaGenerator":
        """Display an interactive Plotly chart.

        Plotly is a charting library for Python. The arguments to this function
        closely follow the ones for Plotly's `plot()` function. You can find
        more about Plotly at https://plotly.com/python.

        To show Plotly charts in Streamlit, call `st.plotly_chart` wherever you
        would call Plotly's `py.plot` or `py.iplot`.

        Parameters
        ----------
        figure_or_data : plotly.graph_objs.Figure, plotly.graph_objs.Data,
            dict/list of plotly.graph_objs.Figure/Data, or
            matplotlib.figure.Figure
            See https://plotly.com/python/ for examples of graph descriptions.
            A Matplotlib figure is first converted to a Plotly figure with
            `plotly.tools.mpl_to_plotly`.

        use_container_width : bool
            If True, set the chart width to the column width. This takes
            precedence over the figure's native `width` value.

        sharing : {'streamlit', 'private', 'secret', 'public'}
            Use 'streamlit' to insert the plot and all its dependencies
            directly in the Streamlit app using plotly's offline mode (default).
            Use any other sharing mode to send the chart to Plotly chart studio, which
            requires an account. See https://plotly.com/chart-studio/ for more information.

        **kwargs
            Any argument accepted by Plotly's `plot()` function. With
            sharing='streamlit', only `config`, `show_link` and `link_text`
            are used; with the other sharing modes, all of them are forwarded
            to Plotly chart studio.

        Returns
        -------
        DeltaGenerator
            Whatever `DeltaGenerator._enqueue` returns for the new element.

        Example
        -------
        The example below comes straight from the examples at
        https://plotly.com/python:

        >>> import streamlit as st
        >>> import plotly.figure_factory as ff
        >>> import numpy as np
        >>>
        >>> # Add histogram data
        >>> x1 = np.random.randn(200) - 2
        >>> x2 = np.random.randn(200)
        >>> x3 = np.random.randn(200) + 2
        >>>
        >>> # Group data together
        >>> hist_data = [x1, x2, x3]
        >>>
        >>> group_labels = ['Group 1', 'Group 2', 'Group 3']
        >>>
        >>> # Create distplot with custom bin_size
        >>> fig = ff.create_distplot(
        ...         hist_data, group_labels, bin_size=[.1, .25, .5])
        >>>
        >>> # Plot!
        >>> st.plotly_chart(fig, use_container_width=True)

        .. output::
           https://doc-plotly-chart.streamlitapp.com/
           height: 400px

        """
        # NOTE: "figure_or_data" is the name used in Plotly's .plot() method
        # for their main parameter. I don't like the name, but it's best to
        # keep it in sync with what Plotly calls it.
        plotly_chart_proto = PlotlyChartProto()
        # marshall() validates `sharing` and fills the proto in place.
        marshall(
            plotly_chart_proto, figure_or_data, use_container_width, sharing, **kwargs
        )
        return self.dg._enqueue("plotly_chart", plotly_chart_proto)

    @property
    def dg(self) -> "DeltaGenerator":
        """Get our DeltaGenerator."""
        # This mixin is expected to be combined with DeltaGenerator; `cast`
        # only informs the type checker and returns `self` unchanged.
        return cast("DeltaGenerator", self)
def marshall(
    proto: PlotlyChartProto,
    figure_or_data: FigureOrData,
    use_container_width: bool,
    sharing: SharingMode,
    **kwargs: Any,
) -> None:
    """Marshall a proto with a Plotly spec.

    See DeltaGenerator.plotly_chart for docs.

    Parameters
    ----------
    proto : PlotlyChartProto
        Filled in place. In "streamlit" mode ``proto.figure.spec`` and
        ``proto.figure.config`` are set to JSON strings; in every other mode
        ``proto.url`` is set to a Plotly chart studio embed URL.
    figure_or_data : FigureOrData
        A Plotly figure/data object (or a dict/list of them), or a Matplotlib
        figure, which is converted with ``plotly.tools.mpl_to_plotly``.
    use_container_width : bool
        Copied to ``proto.use_container_width``.
    sharing : SharingMode
        One of ``SHARING_MODES``, matched case-insensitively.
    **kwargs
        In "streamlit" mode only ``config``, ``show_link`` and ``link_text``
        are read. Otherwise, all kwargs are forwarded to
        ``_plot_to_url_or_load_cached_url``.

    Raises
    ------
    ValueError
        If ``sharing`` is not a string naming one of ``SHARING_MODES``.
    """
    # NOTE: "figure_or_data" is the name used in Plotly's .plot() method
    # for their main parameter. I don't like the name, but it's best to keep
    # it in sync with what Plotly calls it.

    # Validate the cheap argument before any (possibly expensive) figure
    # conversion. An f-string is used so that a non-string such as a tuple
    # yields this ValueError rather than a TypeError from %-formatting.
    if not isinstance(sharing, str) or sharing.lower() not in SHARING_MODES:
        raise ValueError(f"Invalid sharing mode for Plotly chart: {sharing}")

    # Validation is case-insensitive, so branch on the normalized value too;
    # otherwise e.g. "Streamlit" would pass validation and then be uploaded to
    # Plotly chart studio instead of being rendered in the app.
    mode = sharing.lower()

    import plotly.tools

    if type_util.is_type(figure_or_data, "matplotlib.figure.Figure"):
        figure = plotly.tools.mpl_to_plotly(figure_or_data)
    else:
        figure = plotly.tools.return_figure_from_figure_or_data(
            figure_or_data, validate_figure=True
        )

    proto.use_container_width = use_container_width

    if mode == "streamlit":
        import plotly.utils

        # Treat an explicit config=None like an absent config instead of
        # letting dict(None) raise a TypeError.
        config = dict(kwargs.get("config") or {})
        # Copy over some kwargs to config dict. Plotly does the same in plot().
        config.setdefault("showLink", kwargs.get("show_link", False))
        config.setdefault("linkText", kwargs.get("link_text", False))

        proto.figure.spec = json.dumps(figure, cls=plotly.utils.PlotlyJSONEncoder)
        proto.figure.config = json.dumps(config)
    else:
        url = _plot_to_url_or_load_cached_url(
            figure, sharing=mode, auto_open=False, **kwargs
        )
        proto.url = _get_embed_url(url)
@caching.cache
def _plot_to_url_or_load_cached_url(*args: Any, **kwargs: Any) -> str:
    """Call Plotly chart studio's ``plot()`` wrapped in st.cache.

    This is so we don't unnecessarily upload data to Plotly's SaaS if nothing
    changed since the previous upload.

    All positional and keyword arguments are forwarded unchanged to ``plot()``.

    Returns
    -------
    str
        The URL of the uploaded plot, as returned by ``plot()``; callers turn
        it into an embed URL.
    """
    try:
        # Plotly 4 moved the online-plotting API into the separate
        # ``chart_studio`` package.
        import chart_studio.plotly as ply
    except ImportError:
        # Older Plotly releases still ship it as ``plotly.plotly``.
        import plotly.plotly as ply

    return ply.plot(*args, **kwargs)
def _get_embed_url(url: str) -> str:
parsed_url = urllib.parse.urlparse(url)
# Plotly's embed URL is the normal URL plus ".embed".
# (Note that our use namedtuple._replace is fine because that's not a
# private method! It just has an underscore to avoid clashing with the
# tuple field names)
parsed_embed_url = parsed_url._replace(path=parsed_url.path + ".embed")
return urllib.parse.urlunparse(parsed_embed_url)