Using scikit-learn's StratifiedKFold class, I get the error below. Can someone help me understand it?
My guess is that it has something to do with my input array of labels. When I print them (the first 16 in this example), the index goes from 0 to 15, but an extra 0 is printed above it that I wasn't expecting. Maybe I'm just a Python noob, but that looks weird.
Anyone see the goof-up here?
Documentation: http://scikit-learn.org...StratifiedKFold.html
Code:
import nltk
import sklearn
from sklearn import cross_validation
print('The nltk version is {}.'.format(nltk.__version__))
print('The scikit-learn version is {}.'.format(sklearn.__version__))
print type(skew_gendata_targets.values), skew_gendata_targets.values.shape
print skew_gendata_targets.head(16)
skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
Result:
The nltk version is 3.1.
The scikit-learn version is 0.17.
<type 'numpy.ndarray'> (500L, 1L)
    0
0   0
1   0
2   0
3   0
4   0
5   0
6   0
7   0
8   0
9   0
10  0
11  0
12  0
13  0
14  1
15  0
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-373-653b6010b806> in <module>()
8 print skew_gendata_targets.head(16)
9
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
11
12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')'
d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
531 for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
532 for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 533 label_test_folds = test_folds[y == label]
534 # the test split can be too big because we used
535 # KFold(max(c, self.n_folds), self.n_folds) instead of
IndexError: too many indices for array
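Check the shape of skew_gendata_targets.values. You'll see that it isn't a 1-D array of shape (500,), as StratifiedKFold expects, but a 2-D array of shape (500, 1). scikit-learn doesn't flatten it for you here, so inside StratifiedKFold the comparison y == label produces a (500, 1) boolean mask, and using that 2-D mask to index the 1-D test_folds array raises "IndexError: too many indices for array". Passing skew_gendata_targets.values.ravel() (or selecting the single column, e.g. skew_gendata_targets[0].values) gives it the 1-D array it wants. This also explains the extra 0 you noticed: judging by the output, skew_gendata_targets is a one-column DataFrame whose column is labeled 0, and head() prints that label as a header above the index. Let me know if that helps.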
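To see the mechanism in isolation, here is a minimal numpy-only sketch. The labels are randomly generated stand-ins for skew_gendata_targets.values, and test_folds is a hypothetical 1-D array assumed to mirror the one indexed on the failing line of the traceback; it does not call scikit-learn itself.

```python
import numpy as np

# Hypothetical stand-in for skew_gendata_targets.values:
# 500 binary labels stored as a (500, 1) column, the shape reported in the question.
rng = np.random.RandomState(20160121)
y_col = (rng.rand(500, 1) < 0.1).astype(int)
print(y_col.shape)  # (500, 1)

# Assumed to mirror the 1-D per-sample fold array indexed on the failing line
# (label_test_folds = test_folds[y == label]).
test_folds = np.zeros(y_col.shape[0], dtype=int)

# Comparing a (500, 1) array with a scalar gives a (500, 1) boolean mask;
# a 2-D mask on a 1-D array asks for two index dimensions and fails.
try:
    test_folds[y_col == 1]
    raised = False
except IndexError as exc:
    raised = True
    print("IndexError:", exc)

# Flattening the labels to shape (500,) makes the mask 1-D, so indexing works.
y = y_col.ravel()
print(y.shape)  # (500,)
label_test_folds = test_folds[y == 1]
print(label_test_folds.shape == (np.count_nonzero(y == 1),))  # True
```

The sketch sticks to numpy so it runs regardless of scikit-learn version; the sklearn.cross_validation module used in the question was removed in later scikit-learn releases.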