 

Numpy: vectorization for multiple values

Imagine you have an RGB image and want to process every pixel:

import numpy as np
image = np.zeros((1024, 1024, 3))

def rgb_to_something(rgb):
    pass  # placeholder: should receive one (r, g, b) pixel per call

vfunc = np.vectorize(rgb_to_something)
vfunc(image)

vfunc should now be called once per RGB value. The problem is that np.vectorize applies the function to every scalar element (as if the array were flattened), so the function gets r0, g0, b0, r1, g1, b1, ... when it should get rgb0, rgb1, rgb2, .... Can this be done somehow?

http://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html
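The behaviour described above is easy to observe. A minimal sketch (the `record` function and the small 2x2 test image are hypothetical): with `otypes` given, np.vectorize makes no extra trial call, so it calls the function exactly once per scalar element. A (2, 2, 3) input therefore produces 12 calls and an output of the same (2, 2, 3) shape, not one value per pixel.

```python
import numpy as np

received = []

def record(x):
    # Hypothetical helper: store whatever np.vectorize passes in, return a dummy value.
    received.append(x)
    return 0.0

image = np.arange(2 * 2 * 3, dtype=float).reshape(2, 2, 3)

# otypes is given, so vectorize does not call record() once more to guess the output type.
vfunc = np.vectorize(record, otypes=[float])
out = vfunc(image)

print(out.shape)      # (2, 2, 3): one output per scalar, not per pixel
print(len(received))  # 12: each call received a single channel value
```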

Maybe by converting the numpy array to some special datatype beforehand?

For example (this of course does not work):

image = image.astype(np.float32)
import ctypes
RGB = ctypes.c_float * 3
image.astype(RGB)
ValueError: shape mismatch: objects cannot be broadcast to a single shape
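In NumPy, the usual way to treat three floats as a single item is a structured dtype rather than a ctypes array type. A minimal sketch, assuming a C-contiguous float64 image such as the one created with np.zeros above; the field names `r`, `g`, `b` are illustrative. This only changes how the data is viewed (no copy); it does not by itself make a per-pixel Python function any faster.

```python
import numpy as np

image = np.random.rand(4, 5, 3)  # C-contiguous float64, like np.zeros((1024, 1024, 3))

# Illustrative structured dtype: one record of three float64 fields per pixel.
rgb_dtype = np.dtype([('r', np.float64), ('g', np.float64), ('b', np.float64)])

# Reinterpret the 3 channels as one record; the last axis shrinks from 3 to 1,
# and [..., 0] drops it, leaving one record per pixel.
pixels = image.view(rgb_dtype)[..., 0]   # shape (4, 5), shares memory with image

# Each field is an ordinary (H, W) float view, so vectorized arithmetic still applies.
summed = pixels['r'] + pixels['g'] + pixels['b']
```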

Update: The main purpose here is efficiency. A non-vectorized version could simply look like this:

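import numpy as np
image = np.zeros((1024, 1024, 3))
shape = image.shape[0:2]
image = image.reshape((-1, 3))  # one row per pixel: shape (1024*1024, 3)
def rgb_to_something(rgb):
    # Python 3 removed tuple unpacking in the parameter list, so unpack here
    r, g, b = rgb
    return r + g + b
# Python-level loop over every pixel, then restore the 2-D image shape
transformed_image = np.array([rgb_to_something(rgb) for rgb in image]).reshape(shape)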
asked Mar 14 '12 11:03 by tauran


1 Answer

The easy way to solve this kind of problem is to pass the entire array to the function and use vectorized idioms inside it. Specifically, your rgb_to_something can also be written as

def rgb_to_something(pixels):
    # pixels has shape (N, 3); sum the channels of every pixel in one call
    return pixels.sum(axis=1)

which is about 15 times faster than your version:

In [16]: %timeit np.array([old_rgb_to_something(rgb) for rgb in image]).reshape(shape)
1 loops, best of 3: 3.03 s per loop

In [19]: %timeit image.sum(axis=1).reshape(shape)
1 loops, best of 3: 192 ms per loop
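The same idea extends to other per-pixel functions and to the full 3-D image. A sketch, assuming the operation can be expressed with whole-array operations: summing over the last axis (axis=-1) makes the reshape unnecessary, and a weighted per-pixel combination becomes a matrix product with a weight vector. The `rgb_to_gray` helper is hypothetical, and its default weights (the ITU-R BT.601 luma coefficients) are only an example.

```python
import numpy as np

def rgb_to_something(pixels):
    # Sum over the last axis (the colour channels). Works for an (N, 3) array
    # of pixels and for a full (H, W, 3) image alike, so no reshape is needed.
    return pixels.sum(axis=-1)

def rgb_to_gray(pixels, weights=(0.299, 0.587, 0.114)):
    # Hypothetical example: weighted combination of the channels of every pixel,
    # computed as a matrix product over the last axis.
    return pixels @ np.asarray(weights)

image = np.random.rand(64, 48, 3)
summed = rgb_to_something(image)  # shape (64, 48)
gray = rgb_to_gray(image)         # shape (64, 48)
```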

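The problem with np.vectorize is that it calls the Python function once per element, so on large arrays it necessarily incurs a lot of Python function-call overhead. The NumPy documentation itself says vectorize is provided primarily for convenience, not performance, and is essentially a for loop.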
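For the literal question (making vectorize pass one pixel at a time), np.vectorize accepts a `signature` argument (NumPy 1.12 and later) that declares core dimensions. A sketch, assuming the r + g + b function from the question: `'(n)->()'` says the function consumes the last axis as a 1-D array and returns a scalar. This is still a Python-level loop, so it gives none of the speed-up of the sum(axis=...) version.

```python
import numpy as np

def rgb_to_something(rgb):
    # rgb is a length-3 array holding one pixel's (r, g, b) values
    r, g, b = rgb
    return r + g + b

# '(n)->()': consume the last axis (the 3 channels) as a 1-D array, return a scalar.
vfunc = np.vectorize(rgb_to_something, signature='(n)->()')

image = np.random.rand(4, 5, 3)
result = vfunc(image)  # one call per pixel; result has shape (4, 5)
```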

answered Oct 10 '22 09:10 by Fred Foo