Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

how to plot gradient fill on the 3d bars in matplotlib

Tags:

matplotlib

Right now I have some statistics plotted as 3D bars over the (x, y) plane. The height of each bar represents the density of the points inside the corresponding square cell of the (x, y) grid. Currently, I can give each bar a different color. However, I want a progressive color along each 3D bar, similar to a cmap, so that each bar is gradient-filled depending on the density.

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

# Create the figure and attach a single 3D axes to it.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Bar heights: a 4x4 float grid in which every row reads 0, 1, 2, 3.
z = np.arange(4) * np.ones((4, 4))

# Anchor positions of the bars on the (x, y) grid, flattened column-major.
xpos, ypos = np.meshgrid(np.arange(4), np.arange(4))
xpos = xpos.flatten(order='F')
ypos = ypos.flatten(order='F')
zpos = np.zeros_like(xpos)  # every bar starts at z = 0

# Bar extents: a 0.5 x 0.5 footprint; heights come from z in row-major order.
dx = np.full(zpos.shape, 0.5)
dy = dx.copy()
dz = z.flatten(order='C')

# All bars share one solid color here; this is the plot the question starts from.
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b', zsort='average')

plt.show()

Output of the above code:

enter image description here

like image 449
QuantChris Avatar asked Mar 10 '23 11:03

QuantChris


1 Answers

Let me first say that matplotlib may not be the tool of choice when it comes to sophisticated 3D plots.

That said, there is no built-in method to produce bar plots with differing colors over the extent of the bar.

We therefore need to mimic the bar somehow. A possible solution can be found below. Here, we use a plot_surface plot to create a bar that contains a gradient.

enter image description here

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.colors 
import numpy as np

# Create the figure and a single 3D axes; the projection name is read from
# Axes3D.name rather than written out as a string literal.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=Axes3D.name)

def make_bar(ax, x0=0, y0=0, width=0.5, height=1, cmap="viridis",
             norm=None, **kwargs):
    """Draw one gradient-colored bar on a 3D axes as a single surface.

    The bar is built by sampling a unit sphere at four azimuth angles and
    clamping the coordinates to a box, which leaves a closed cuboid whose
    faces can be colored by ``plot_surface`` according to their z value.

    Parameters
    ----------
    ax
        Object providing ``plot_surface(x, y, z, cmap=..., norm=..., **kw)``.
    x0, y0
        Center of the bar in the (x, y) plane.
    width
        Half-width of the bar: it spans ``x0 - width`` to ``x0 + width``
        and ``y0 - width`` to ``y0 + width``.
    height
        Height of the bar; the surface runs from z = 0 to z = ``height``.
    cmap
        Colormap forwarded to ``plot_surface``.
    norm
        Normalization forwarded to ``plot_surface``. When ``None`` (the
        default) a fresh ``Normalize(vmin=0, vmax=1)`` is created for this
        call, so calls never share one mutable default instance.
    **kwargs
        Forwarded unchanged to ``plot_surface``.
    """
    if norm is None:
        norm = matplotlib.colors.Normalize(vmin=0, vmax=1)

    # Azimuth: the four corner directions of a square, plus a fifth sample
    # that repeats the first one to close the surface.
    u = np.linspace(0, 2*np.pi, 4+1) + np.pi/4.
    # Polar angle: 100 samples between pi/4 and 3*pi/4 form the side walls;
    # the extra end points 0 and pi give the centers of the top/bottom caps.
    v = np.concatenate(([0.], np.linspace(np.pi/4., 3./4*np.pi, 100), [np.pi]))

    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))

    # Clamp the sphere to a box so that the walls become flat.
    xthr = np.sin(np.pi/4.)**2
    zthr = np.sin(np.pi/4.)
    x = np.clip(x, -xthr, xthr)
    y = np.clip(y, -xthr, xthr)
    z = np.clip(z, -zthr, zthr)

    # Rescale the footprint to [-width, width] and the height to [0, height].
    x *= 1./xthr*width
    y *= 1./xthr*width
    z += zthr
    z *= height/(2.*zthr)

    # Move the bar to its position in the (x, y) plane.
    x += x0
    y += y0

    ax.plot_surface(x, y, z, cmap=cmap, norm=norm, **kwargs)

def make_bars(ax, x, y, height, width=1):
    """Draw one gradient-colored bar per (x, y, height) entry.

    All bars share a single color normalization running from 0 to the
    largest height, so the same color means the same z value on every bar.

    Parameters
    ----------
    ax
        Axes object handed on to ``make_bar``.
    x, y
        Bar centers; array-likes of any shape, flattened before use.
    height
        Bar heights; array-like with as many entries as ``x``.
    width
        Value passed to ``make_bar`` as ``width``; a scalar or an
        array-like that multiplies against ``x`` to one entry per bar.

    Raises
    ------
    ValueError
        If ``x``, ``y``, ``height`` and the expanded ``width`` do not all
        have the same number of entries.
    """
    x = np.asarray(x)
    # Expand a scalar width to one entry per bar.
    widths = (np.asarray(width) * np.ones_like(x)).ravel()
    xs = x.ravel()
    ys = np.asarray(y).ravel()
    heights = np.asarray(height).ravel()

    # Check up front rather than failing after some bars are already drawn.
    if not (xs.size == ys.size == heights.size == widths.size):
        raise ValueError(
            "x, y, height and width must describe the same number of bars"
        )
    if xs.size == 0:
        # Nothing to draw; also avoids taking max() of an empty array.
        return

    norm = matplotlib.colors.Normalize(vmin=0, vmax=heights.max())
    for bar_x, bar_y, bar_w, bar_h in zip(xs, ys, widths, heights):
        make_bar(ax, x0=bar_x, y0=bar_y, width=bar_w, height=bar_h, norm=norm)


# Sample data: a 3x3 grid of bar centers with heights in the range 0.5 to 2.5.
X, Y = np.meshgrid(np.arange(1, 4), np.arange(2, 5))
Z = 1.5 + np.sin(X * Y)

# Draw one gradient-filled bar per grid point, then display the figure.
make_bars(ax, X, Y, Z, width=0.2)
plt.show()
like image 134
ImportanceOfBeingErnest Avatar answered Mar 12 '23 06:03

ImportanceOfBeingErnest