
Plotly: How to make a figure with multiple lines and shaded area for standard deviations?

How can I use Plotly to produce a line plot with a shaded standard deviation? I am trying to achieve something similar to seaborn.tsplot. Any help is appreciated.

Asked by Tyesh on Apr 29 '20 at 04:04




4 Answers

I was able to come up with something similar. I am posting the code here in case it is useful to someone else, and I welcome any suggestions for improvement.

import random

import matplotlib.colors as mcolors
import numpy as np
import plotly.graph_objects as go


# collect matplotlib's named colors as hex strings, to pick a random one below
hex_colors_only = list(mcolors.cnames.values())

data = [[1, 3, 5, 4],
        [2, 3, 5, 4],
        [1, 1, 4, 5],
        [2, 3, 5, 4]]
#calculating mean and standard deviation
mean=np.mean(data,axis=0)
std=np.std(data,axis=0)

#draw figure
fig = go.Figure()
c = random.choice(hex_colors_only)
fig.add_trace(go.Scatter(x=np.arange(4), y=mean+std,
                                     mode='lines',
                                     line=dict(color=c,width =0.1),
                                     name='upper bound'))
fig.add_trace(go.Scatter(x=np.arange(4), y=mean,
                         mode='lines',
                         line=dict(color=c),
                         fill='tonexty',
                         name='mean'))
fig.add_trace(go.Scatter(x=np.arange(4), y=mean-std,
                         mode='lines',
                         line=dict(color=c, width =0.1),
                         fill='tonexty',
                         name='lower bound'))
fig.show()
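
To make explicit what `np.mean(data, axis=0)` and `np.std(data, axis=0)` compute in the snippet above, here is a minimal standard-library sketch of the same column-wise arithmetic. The helper name `column_stats` is hypothetical and not part of the answer; it uses `statistics.pstdev` because `np.std` defaults to the population standard deviation (`ddof=0`).

```python
from statistics import fmean, pstdev

def column_stats(rows):
    """Return (means, stds, upper, lower) computed column by column."""
    columns = list(zip(*rows))               # transpose: one tuple per x position
    means = [fmean(col) for col in columns]
    stds = [pstdev(col) for col in columns]  # population std, like np.std's default
    upper = [m + s for m, s in zip(means, stds)]
    lower = [m - s for m, s in zip(means, stds)]
    return means, stds, upper, lower

data = [[1, 3, 5, 4],
        [2, 3, 5, 4],
        [1, 1, 4, 5],
        [2, 3, 5, 4]]
means, stds, upper, lower = column_stats(data)
print(means)  # [1.5, 2.5, 4.75, 4.25]
```

Here `upper` and `lower` correspond to the `mean+std` and `mean-std` arrays passed as `y` to the first and third traces.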
Answered by Tyesh on Oct 18 '22 at 04:10


The following approach works for any number of columns in a pandas dataframe and uses the default color cycle of plotly. If the number of lines exceeds the number of colors, the colors are reused from the start. px.colors.qualitative.Plotly can be replaced with any hex color sequence available in px.colors.qualitative:

Alphabet = ['#AA0DFE', '#3283FE', '#85660D', '#782AB6', '#565656', '#1...
Alphabet_r = ['#FA0087', '#FBE426', '#B00068', '#FC1CBF', '#C075A6', '...
[...]
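
The two ideas in this paragraph, cycling through a palette and making each color semi-transparent for the shaded area, can be sketched with the standard library alone. The helper name `hex_to_rgba` and the three-color `palette` list are assumptions for illustration (the hex values are meant to match the first three entries of `px.colors.qualitative.Plotly`), and `itertools.cycle` stands in for a hand-written color generator.

```python
from itertools import cycle, islice

def hex_to_rgba(hex_color, alpha):
    """Convert '#RRGGBB' to an 'rgba(r,g,b,a)' string that Plotly accepts."""
    digits = hex_color.lstrip('#')
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

# Assumed palette: any list of hex colors works here.
palette = ['#636EFA', '#EF553B', '#00CC96']

# Five series but only three colors: the cycle starts over after the third.
fill_colors = [hex_to_rgba(c, 0.2) for c in islice(cycle(palette), 5)]
print(fill_colors[0])  # rgba(99,110,250,0.2)
```

The fourth and fifth entries of `fill_colors` repeat the first and second, which is the reuse behavior described above.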


Complete code:

# imports
import plotly.graph_objs as go
import plotly.express as px
import pandas as pd
import numpy as np

# sample data in a pandas dataframe
np.random.seed(1)
df=pd.DataFrame(dict(A=np.random.uniform(low=-1, high=2, size=25).tolist(),
                    B=np.random.uniform(low=-4, high=3, size=25).tolist(),
                    C=np.random.uniform(low=-1, high=3, size=25).tolist(),
                    ))
df = df.cumsum()

# define colors as a list 
colors = px.colors.qualitative.Plotly

# convert plotly hex colors to rgba to enable transparency adjustments
def hex_rgba(hex, transparency):
    col_hex = hex.lstrip('#')
    col_rgb = list(int(col_hex[i:i+2], 16) for i in (0, 2, 4))
    col_rgb.extend([transparency])
    areacol = tuple(col_rgb)
    return areacol

rgba = [hex_rgba(c, transparency=0.2) for c in colors]
colCycle = ['rgba'+str(elem) for elem in rgba]

# Make sure the colors run in cycles if there are more lines than colors
def next_col(cols):
    while True:
        for col in cols:
            yield col
line_color=next_col(cols=colCycle)

# plotly  figure
fig = go.Figure()

# add line and shaded area for each series and its standard deviation
for i, col in enumerate(df):
    new_col = next(line_color)
    x = list(df.index.values+1)
    y1 = df[col]
    std = np.std(df[col])  # computed once per column instead of once per point
    y1_upper = [(y + std) for y in df[col]]
    y1_lower = [(y - std) for y in df[col]]
    y1_lower = y1_lower[::-1]

    # standard deviation area: a closed polygon (upper bound, then reversed lower bound)
    fig.add_traces(go.Scatter(x=x+x[::-1],
                                y=y1_upper+y1_lower,
                                fill='toself',
                                fillcolor=new_col,
                                line=dict(color='rgba(255,255,255,0)'),
                                showlegend=False,
                                name=col))

    # line trace
    fig.add_traces(go.Scatter(x=x,
                              y=y1,
                              line=dict(color=new_col, width=2.5),
                              mode='lines',
                              name=col)
                                )
# set x-axis
fig.update_layout(xaxis=dict(range=[1,len(df)]))

fig.show()
Answered by vestland on Oct 18 '22 at 04:10


Others have posted great custom solutions. If you are interested in the code from the official Plotly documentation, see https://plotly.com/python/continuous-error-bars/
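
As far as I can tell, the pattern on that page amounts to drawing the band as one closed polygon: walk along the upper bound from left to right, then back along the lower bound from right to left, and fill the shape with `fill='toself'`. Below is a minimal sketch of that idea, not a copy of the official example. The helper names `band_polygon` and `band_figure` and the fill color are assumptions made for illustration.

```python
def band_polygon(x, y, err):
    """Return the (x, y) outline of the band y +/- err as a closed polygon."""
    upper = [yi + ei for yi, ei in zip(y, err)]
    lower = [yi - ei for yi, ei in zip(y, err)]
    # Upper bound left to right, then lower bound right to left.
    return list(x) + list(x)[::-1], upper + lower[::-1]

def band_figure(x, y, err, name='mean'):
    """Build a figure with one line and its shaded band (requires plotly)."""
    import plotly.graph_objects as go  # lazy import: band_polygon stays dependency-free
    poly_x, poly_y = band_polygon(x, y, err)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=poly_x, y=poly_y, fill='toself',
                             fillcolor='rgba(0,100,80,0.2)',          # arbitrary color
                             line=dict(color='rgba(255,255,255,0)'),  # invisible outline
                             hoverinfo='skip', showlegend=False))
    fig.add_trace(go.Scatter(x=list(x), y=list(y), mode='lines',
                             line=dict(color='rgb(0,100,80)'), name=name))
    return fig

xs, ys = band_polygon([1, 2, 3], [10, 20, 30], [1, 2, 3])
print(xs)  # [1, 2, 3, 3, 2, 1]
print(ys)  # [11, 22, 33, 27, 18, 9]
```

Calling `band_figure([1, 2, 3], [10, 20, 30], [1, 2, 3]).show()` should render the line with its band; for several series, add one polygon and one line per series.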

Answered by Moritz on Oct 18 '22 at 04:10


I wrote a function that extends plotly.express.line while keeping the high-level interface of Plotly Express. The line function (source code below) is called exactly like plotly.express.line but adds continuous error bands through the argument error_y_mode, which can be either 'band' or 'bar'. With 'bar' it produces the same result as the original plotly.express.line. Here is a usage example (it assumes the line function defined further down is already in scope):

import plotly.express as px

df = px.data.gapminder().query('continent=="Americas"')
df = df[df['country'].isin({'Argentina','Brazil','Colombia'})]
df['lifeExp std'] = df['lifeExp']*.1 # Invent some error data...

for error_y_mode in {'band', 'bar'}:
    fig = line(
        data_frame = df,
        x = 'year',
        y = 'lifeExp',
        error_y = 'lifeExp std',
        error_y_mode = error_y_mode, # Here you say `band` or `bar`.
        color = 'country',
        title = f'Using error {error_y_mode}',
        markers = True, # px.line expects a boolean here.
    )
    fig.show()

which produces two plots, one with error bands and one with error bars.

The source code of the line function that extends plotly.express.line is this:

import plotly.express as px
import plotly.graph_objs as go

def line(error_y_mode=None, **kwargs):
    """Extension of `plotly.express.line` to use error bands."""
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        fig = px.line(**kwargs)
    elif error_y_mode in {'band','bands'}:
        if 'error_y' not in kwargs:
            raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")
        figure_with_error_bars = px.line(**kwargs)
        fig = px.line(**{arg: val for arg,val in kwargs.items() if arg != 'error_y'})
        for data in figure_with_error_bars.data:
            x = list(data['x'])
            y_upper = list(data['y'] + data['error_y']['array'])
            error_minus = data['error_y']['array'] if data['error_y']['arrayminus'] is None else data['error_y']['arrayminus']
            y_lower = list(data['y'] - error_minus)
            # Convert the trace's hex line color to a semi-transparent rgba string.
            r, g, b = (int(data['line']['color'].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
            color = f"rgba({r},{g},{b},.3)"
            fig.add_trace(
                go.Scatter(
                    x = x+x[::-1],
                    y = y_upper+y_lower[::-1],
                    fill = 'toself',
                    fillcolor = color,
                    line = dict(
                        color = 'rgba(255,255,255,0)'
                    ),
                    hoverinfo = "skip",
                    showlegend = False,
                    legendgroup = data['legendgroup'],
                    xaxis = data['xaxis'],
                    yaxis = data['yaxis'],
                )
            )
        # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
        reordered_data = []
        half = len(fig.data) // 2  # first half: lines, second half: bands
        for i in range(half):
            reordered_data.append(fig.data[i+half])
            reordered_data.append(fig.data[i])
        fig.data = tuple(reordered_data)
    return fig
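
The reordering step at the end of the function relies on Plotly drawing traces in the order they appear in `fig.data`, so a band placed before its line ends up underneath it. Because the bands are appended after all the lines, the first half of `fig.data` holds the lines and the second half holds the bands, and the loop interleaves them. A minimal sketch of that interleaving on plain strings (the function name `interleave_bands_first` is hypothetical):

```python
def interleave_bands_first(traces):
    """Reorder [line_1..line_n, band_1..band_n] into [band_1, line_1, ...]."""
    half = len(traces) // 2
    lines, bands = traces[:half], traces[half:]
    reordered = []
    for band, line in zip(bands, lines):
        reordered.extend([band, line])  # band first, so the line is drawn on top
    return reordered

order = interleave_bands_first(['line A', 'line B', 'band A', 'band B'])
print(order)  # ['band A', 'line A', 'band B', 'line B']
```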
Answered by user171780 on Oct 18 '22 at 02:10