
converting numpy array of string fields to numerical format

Tags: python, numpy

I have an array of strings grouped into three fields:

import numpy as np

x = np.array([(-1, 0, 1),
              (-1, 1, 0),
              (0, 1, -1),
              (0, -1, 1)],
             dtype=[('a', 'S2'),
                    ('b', 'S2'),
                    ('c', 'S2')])

I would like to convert it to a plain numerical array of shape 4x3 (preferably of type np.int8, though that is not required) rather than an array with fields.

My general approach is to transform it into a 4x3 array of type 'S2', then use astype to make it numerical. The problem is that the only way I can think of to do this involves both view and np.lib.stride_tricks.as_strided, which doesn't seem like a very robust solution:

y = np.lib.stride_tricks.as_strided(x.view(dtype='S2'),
                                    shape=(4, 3), strides=(6, 2))
z = y.astype(np.int8)

This works for the toy case shown here, but I feel like there must be a simpler way to unpack an array with fields all having the same dtype. What is a more robust alternative?

Mad Physicist asked Nov 07 '22 22:11



1 Answer

NumPy 1.16 added structured_to_unstructured, which serves exactly this purpose:

from numpy.lib.recfunctions import structured_to_unstructured
y = structured_to_unstructured(x)  # 2d array of 'S2'
z = y.astype(np.int8)
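
As a self-contained check, here is a minimal sketch that rebuilds the array from the question and confirms the result. It assumes NumPy 1.16 or later; the name `expected` is introduced here only for illustration and is not part of the original answer.

```python
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured

# Structured array from the question: three 'S2' (2-byte string) fields.
x = np.array([(-1, 0, 1),
              (-1, 1, 0),
              (0, 1, -1),
              (0, -1, 1)],
             dtype=[('a', 'S2'), ('b', 'S2'), ('c', 'S2')])

y = structured_to_unstructured(x)   # plain 2-D array, one column per field
z = y.astype(np.int8)               # parse the byte strings as integers

# 'expected' is a hypothetical name used only for this check.
expected = np.array([[-1, 0, 1],
                     [-1, 1, 0],
                     [0, 1, -1],
                     [0, -1, 1]], dtype=np.int8)
assert y.shape == (4, 3) and y.dtype == np.dtype('S2')
assert np.array_equal(z, expected)
```

Because the output shape is derived from the number of fields, nothing here depends on hand-written strides or item sizes.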

In earlier versions of NumPy, you can combine x.data and np.frombuffer to create another array over the same data in memory without having to specify strides. It doesn't bring any performance gain, though, since the run time is dominated by the cast from S2 to int8.

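import numpy as np

n = 1000

def f1(x):
    # Reinterpret the packed 6-byte records as rows of three 'S2' items.
    y = np.lib.stride_tricks.as_strided(x.view(dtype='S2'),
                                        shape=(n, 3),
                                        strides=(6, 2))
    return y.astype(np.int8)

def f2(x):
    # Same memory, no stride arithmetic: read the buffer as flat 'S2' items.
    y = np.frombuffer(x.data, dtype='S2').reshape((n, 3))
    return y.astype(np.int8)


# Build n records of three 'S2' fields with values cycling through -1, 0, 1.
x = np.array([(i%3-1, (i+1)%3-1, (i+2)%3-1)
              for i in range(n)],
             dtype='S2,S2,S2')

z1 = f1(x)
z2 = f2(x)
assert (z1==z2).all()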
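
To illustrate the point about performance, the following rough sketch times the reinterpretation step and the cast separately. The helper names `reinterpret` and `cast`, the array size, and the repeat count are choices made for this illustration and are not part of the original answer; absolute timings will vary with the machine and the NumPy version.

```python
import timeit
import numpy as np

n = 1000
x = np.array([(i % 3 - 1, (i + 1) % 3 - 1, (i + 2) % 3 - 1)
              for i in range(n)],
             dtype='S2,S2,S2')

def reinterpret(x):
    # Hypothetical helper: read the packed buffer as an (n, 3) array of 'S2'.
    return np.frombuffer(x.data, dtype='S2').reshape((len(x), 3))

def cast(y):
    # Hypothetical helper: parse each 2-byte string as an int8.
    return y.astype(np.int8)

y = reinterpret(x)
t_view = timeit.timeit(lambda: reinterpret(x), number=200)
t_cast = timeit.timeit(lambda: cast(y), number=200)
print("reinterpret: %.6f s, cast: %.6f s" % (t_view, t_cast))
```

If the cast accounts for most of the total time on your setup, switching between as_strided and np.frombuffer should make little measurable difference.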
M1L0U answered Nov 13 '22 16:11
