matplotlib imshow complex 2D array

Is there a good way to plot a 2D array of complex numbers as an image in matplotlib?

It makes a lot of sense to map the magnitude of a complex number to "brightness" or "saturation" and the phase to "hue" (hue is, after all, just an angle in RGB color space). http://en.wikipedia.org/wiki/HSL_and_HSV

But as far as I know, imshow accepts only scalar values, which are then mapped through a colormap. Is there no way to plot real RGB pictures?

I think it would be easy to implement a version that accepts a 2D array of tuples (vectors) of 3 floating-point numbers, or an ndarray of floats of shape [:,:,3]. I guess this would be a generally useful feature. It would also be useful for plotting real RGB colored images, such as textures output from OpenCL.

2 Answers

This does almost the same as @Hooked's code, but much faster.

import numpy as np
from numpy import pi
import pylab as plt
from colorsys import hls_to_rgb

def colorize(z):
    r = np.abs(z)
    arg = np.angle(z) 

    h = (arg + pi)  / (2 * pi) + 0.5
    l = 1.0 - 1.0/(1.0 + r**0.3)
    s = 0.8

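    c = np.vectorize(hls_to_rgb)(h, l, s)  # --> tuple of three (n,m) arrays: r, g, b
    c = np.array(c)  # --> array of shape (3,n,m); imshow needs the channel axis last
    c = c.swapaxes(0,2)  # --> shape (m,n,3): channels last, rows and columns exchanged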
    return c

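N = 1000  # grid resolution
x,y = np.ogrid[-5:5:N*1j, -5:5:N*1j]
z = x + 1j*y

w = 1/(z+1j)**2 + 1/(z-2)**2
img = colorize(w)
# origin='lower' puts Im(z) = -5 at the bottom so the image matches the extent
plt.imshow(img, origin='lower', extent=(-5, 5, -5, 5))
plt.show()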
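
A note on the axis shuffle inside colorize: swapaxes(0,2) moves the channel axis to the end but also exchanges rows and columns, which is easy to miss on a square grid. The sketch below only illustrates the shapes on a made-up non-square array (the names planes, swapped and moved are placeholders, not part of the answer's code). With the ogrid above, the exchange happens to put Re(z) on the horizontal image axis; np.moveaxis is the alternative if you want the row/column order kept.

```python
import numpy as np

# Stand-in for the stacked (r, g, b) planes of a non-square 2x5 image.
n, m = 2, 5
planes = np.zeros((3, n, m))

swapped = planes.swapaxes(0, 2)      # channel axis last, rows/columns exchanged
moved = np.moveaxis(planes, 0, -1)   # channel axis last, rows/columns kept

print(swapped.shape)  # (5, 2, 3)
print(moved.shape)    # (2, 5, 3)
```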

Adapting the plotting code from mpmath, you can plot a numpy array with numpy and matplotlib even if you don't know the original function. If you do know the function, see my original answer using mpmath.cplot.

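import numpy as np
from colorsys import hls_to_rgb

def colorize(z):
    n,m = z.shape
    c = np.zeros((n,m,3))
    c[np.isinf(z)] = (1.0, 1.0, 1.0)  # infinities are drawn white
    c[np.isnan(z)] = (0.5, 0.5, 0.5)  # NaNs are drawn gray

    # Only finite entries get a phase/magnitude color
    idx = ~(np.isinf(z) | np.isnan(z))
    A = (np.angle(z[idx]) + np.pi) / (2*np.pi)  # phase mapped to [0, 1]
    A = (A + 0.5) % 1.0                         # hue
    B = 1.0 - 1.0/(1.0+np.abs(z[idx])**0.3)     # lightness grows with magnitude
    c[idx] = [hls_to_rgb(a, b, 0.8) for a,b in zip(A,B)]
    return c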

From here, you can plot an arbitrary complex numpy array:

N = 1000
A = np.zeros((N,N),dtype='complex')
axis_x = np.linspace(-5,5,N)
axis_y = np.linspace(-5,5,N)
X,Y = np.meshgrid(axis_x,axis_y)
Z = X + Y*1j

A = 1/(Z+1j)**2 + 1/(Z-2)**2

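# Plot the array "A" using colorize
import pylab as plt
# origin='lower' keeps row 0 (Im = -5) at the bottom so the image matches the extent
plt.imshow(colorize(A), interpolation='none', extent=(-5,5,-5,5), origin='lower')
plt.show()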
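
If the per-pixel Python loop is too slow for large arrays, one possible alternative is to let matplotlib do the color conversion on the whole array at once. The sketch below is not the code above: it swaps HLS for HSV via matplotlib.colors.hsv_to_rgb, so the colors will differ somewhat, and the name colorize_hsv is made up for this example. It keeps the same phase-to-hue mapping and the same white/gray handling of infinities and NaNs.

```python
import numpy as np
from matplotlib.colors import hsv_to_rgb

def colorize_hsv(z):
    """Hypothetical vectorized variant: phase -> hue, magnitude -> value (HSV, not HLS)."""
    z = np.asarray(z, dtype=complex)
    finite = np.isfinite(z)
    zf = np.where(finite, z, 0)          # placeholder value for inf/NaN entries

    hue = ((np.angle(zf) + np.pi) / (2 * np.pi) + 0.5) % 1.0
    sat = np.full(z.shape, 0.8)
    val = 1.0 - 1.0 / (1.0 + np.abs(zf) ** 0.3)

    # hsv_to_rgb takes an array of shape (..., 3) and returns the same shape
    rgb = hsv_to_rgb(np.stack([hue, sat, val], axis=-1))
    rgb[np.isinf(z)] = (1.0, 1.0, 1.0)   # infinities drawn white
    rgb[np.isnan(z)] = (0.5, 0.5, 0.5)   # NaNs drawn gray
    return rgb
```

The result has shape (n,m,3) with floats in [0,1], so it can be passed to plt.imshow in place of colorize(A).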


