Creating a subplot of images with plotly

I wanted to display the first 10 images from the MNIST dataset with Plotly. This is turning out to be more complicated than I thought. This does not work:

import numpy as np
np.random.seed(123)

import plotly.express as px
from plotly import subplots
from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

fig = subplots.make_subplots(rows=1, cols=10)
# fails: px.imshow returns a whole Figure, but add_trace expects a single trace
fig.add_trace(px.imshow(X_train[0]), row=1, col=1)

as it results in

ValueError: 
    Invalid element(s) received for the 'data' property of 
        Invalid elements include: [Figure({
    'data': [{'coloraxis': 'coloraxis',
              'type': 'heatmap',
              'z': array([[0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          ...,
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0],
                          [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}],
    'layout': {'coloraxis': {'colorscale': [[0.0, '#0d0887'], [0.1111111111111111,
                                            '#46039f'], [0.2222222222222222,
                                            '#7201a8'], [0.3333333333333333,
                                            '#9c179e'], [0.4444444444444444,
                                            '#bd3786'], [0.5555555555555556,
                                            '#d8576b'], [0.6666666666666666,
                                            '#ed7953'], [0.7777777777777778,
                                            '#fb9f3a'], [0.8888888888888888,
                                            '#fdca26'], [1.0, '#f0f921']]},
               'margin': {'t': 60},
               'template': '...',
               'xaxis': {'constrain': 'domain', 'scaleanchor': 'y'},
               'yaxis': {'autorange': 'reversed', 'constrain': 'domain'}}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])

neither does

fig.add_trace(go.Image(X_train[0]), row=1, col=1)

or

fig.add_trace(go.Figure(go.Heatmap(z=X_train[0])), 1,1)

I am starting to run out of ideas. It should be possible to have a row of images as a header.

Felix B. asked May 14 '26 13:05


1 Answer

This worked, though I hope it is not the final answer:

import plotly.express as px
from plotly import subplots

# X_train as loaded in the question
fig = subplots.make_subplots(rows=2, cols=5)

for n, image in enumerate(X_train[:10]):
    # px.imshow returns a Figure; .data[0] is the Heatmap trace inside it.
    # 255 - image inverts the digit (dark strokes on a light background)
    fig.add_trace(px.imshow(255 - image).data[0], row=n // 5 + 1, col=n % 5 + 1)

# The layout gets lost, so we have to carry it over, but we cannot simply do
# fig.layout = layout, since the layout has to be slightly different for subplots:
# fig.layout.yaxis in a subplot figure refers only to the first y axis, for example,
# while update_yaxes updates *all* y axes
layout = px.imshow(X_train[0], color_continuous_scale='gray').layout
fig.layout.coloraxis = layout.coloraxis
fig.update_xaxes(**layout.xaxis.to_plotly_json())
fig.update_yaxes(**layout.yaxis.to_plotly_json())
fig.show()

[Screenshot: the resulting figure]

Interestingly, the picture you get by clicking the "download plot as png" icon in the Plotly toolbar is not the same; it looks like this (cf. GitHub issue):

[Screenshot: the PNG downloaded via the toolbar icon]

Felix B. answered May 16 '26 02:05


