
StratifiedKFold : IndexError: too many indices for array

I'm using sklearn's StratifiedKFold. Can someone help me understand the error here?

My guess is that it has something to do with my input array of labels. When I print them (the first 16 in this example), the index goes from 0 to 15, but an extra 0 is printed above the rows that I wasn't expecting. Maybe I'm just a Python noob, but that looks weird.

Anyone see the goof-up here?

Documentation: http://scikit-learn.org...StratifiedKFold.html

Code:

import nltk
import sklearn

print('The nltk version is {}.'.format(nltk.__version__))
print('The scikit-learn version is {}.'.format(sklearn.__version__))

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape
print skew_gendata_targets.head(16)

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)

Result

The nltk version is 3.1.
The scikit-learn version is 0.17.
<type 'numpy.ndarray'> (500L, 1L)
    0
0   0
1   0
2   0
3   0
4   0
5   0
6   0
7   0
8   0
9   0
10  0
11  0
12  0
13  0
14  1
15  0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-373-653b6010b806> in <module>()
      8 print skew_gendata_targets.head(16)
      9 
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
     11 
     12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')'

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
    531         for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
    532             for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 533                 label_test_folds = test_folds[y == label]
    534                 # the test split can be too big because we used
    535                 # KFold(max(c, self.n_folds), self.n_folds) instead of

IndexError: too many indices for array
David Parks asked Jan 26 '16 19:01


1 Answer

Check the shape of skew_gendata_targets.values. You'll see that it isn't a 1-D array (shape (500,)), as StratifiedKFold expects, but a 2-D array of shape (500, 1). scikit-learn treats these as different shapes rather than coercing one into the other, so the mask y == label on the failing line is 2-D and can't be used to index the 1-D test_folds array. Let me know if that helps.
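Here is a minimal sketch of the shape mismatch and one way to flatten the labels. It uses only NumPy. The array y_2d is a hypothetical stand-in for skew_gendata_targets.values (the real data isn't shown in the question), and test_folds only mimics the 1-D array indexed on the failing line of the traceback.

```python
import numpy as np

# Hypothetical stand-in for skew_gendata_targets.values:
# 500 labels stored as a single column, i.e. shape (500, 1).
y_2d = np.zeros((500, 1), dtype=int)
y_2d[14, 0] = 1  # one positive label, as in the printed head(16)

# ravel() flattens the column into the 1-D shape StratifiedKFold expects.
y_1d = y_2d.ravel()

print(y_2d.shape)
print(y_1d.shape)

# Mimic the failing line: a 1-D array indexed with the mask `y == label`.
test_folds = np.zeros(len(y_1d), dtype=int)

mask_ok = (y_1d == 1)            # 1-D boolean mask
selected = test_folds[mask_ok]   # works: mask and array are both 1-D

mask_bad = (y_2d == 1)           # 2-D boolean mask
try:
    test_folds[mask_bad]         # 2-D mask on a 1-D array
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)
```

Applied to the question's code, that would mean passing skew_gendata_targets.values.ravel() as the first argument to StratifiedKFold. The extra 0 printed above the rows looks like the column label of a single-column DataFrame, so selecting that column (for example skew_gendata_targets[0].values, assuming the label is the integer 0) should also give a 1-D array.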

Brian answered Nov 09 '22 19:11