How to obtain reproducible but distinct instances of GroupKFold

In the GroupKFold source, random_state is hard-coded to None:

    def __init__(self, n_splits=3):
        super(GroupKFold, self).__init__(n_splits, shuffle=False,
                                         random_state=None)

Hence, running the following code multiple times (code from here) always gives the same splits:

import numpy as np
from sklearn.model_selection import GroupKFold

for i in range(0,10):
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    y = np.array([1, 2, 3, 4])
    groups = np.array([0, 0, 2, 2])
    group_kfold = GroupKFold(n_splits=2)
    group_kfold.get_n_splits(X, y, groups)

    print(group_kfold)

    for train_index, test_index in group_kfold.split(X, y, groups):
        print("TRAIN:", train_index, "TEST:", test_index)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        print(X_train, X_test, y_train, y_test)
    print 
    print 

Output:

GroupKFold(n_splits=2)
('TRAIN:', array([0, 1]), 'TEST:', array([2, 3]))
(array([[1, 2],
       [3, 4]]), array([[5, 6],
       [7, 8]]), array([1, 2]), array([3, 4]))
('TRAIN:', array([2, 3]), 'TEST:', array([0, 1]))
(array([[5, 6],
       [7, 8]]), array([[1, 2],
       [3, 4]]), array([3, 4]), array([1, 2]))


GroupKFold(n_splits=2)
('TRAIN:', array([0, 1]), 'TEST:', array([2, 3]))
(array([[1, 2],
       [3, 4]]), array([[5, 6],
       [7, 8]]), array([1, 2]), array([3, 4]))
('TRAIN:', array([2, 3]), 'TEST:', array([0, 1]))
(array([[5, 6],
       [7, 8]]), array([[1, 2],
       [3, 4]]), array([3, 4]), array([1, 2]))

...and so on for every iteration.

The splits are identical.

How do I set a random_state for GroupKFold in order to get a different (but reproducible) set of splits over a few different trials of cross-validation?

For example, I want:

GroupKFold(n_splits=2, random_state=42)
('TRAIN:', array([0, 1]), 'TEST:', array([2, 3]))
('TRAIN:', array([2, 3]), 'TEST:', array([0, 1]))


GroupKFold(n_splits=2, random_state=13)
('TRAIN:', array([0, 2]), 'TEST:', array([1, 3]))
('TRAIN:', array([1, 3]), 'TEST:', array([0, 2]))

So far, one possible strategy is to apply sklearn.utils.shuffle first, as suggested in this post. However, this only rearranges the rows within each fold; it doesn't give us new splits, because GroupKFold assigns groups to folds based on the group labels and their sizes, not on the order of the rows.

from sklearn.utils import shuffle
from sklearn.model_selection import GroupKFold
import numpy as np
import sys

random_state = int(sys.argv[1])

X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])

def cv(X, y, groups, random_state):
    # Shuffle the rows (X, y and groups stay aligned), then split by group.
    X_s, y_s, groups_s = shuffle(X, y, groups, random_state=random_state)
    cv_out = GroupKFold(n_splits=2)
    for train, test in cv_out.split(X_s, y_s, groups_s):
        print("---")
        print(X_s[test])
        print(y_s[test])
        print("test groups", groups_s[test])
        print("train groups", groups_s[train])

print("***")
cv(X, y, groups, random_state)

The output:

>python sshuf.py 32

***
---
[[ 2  3]
 [ 4  5]
 [ 0  1]
 [ 8  9]
 [12 13]]
[1 2 0 4 6]
test groups [0 0 0 2 4]
train groups [7 6 1 3 5]
---
[[18 19]
 [16 17]
 [ 6  7]
 [10 11]
 [14 15]]
[9 8 3 5 7]
test groups [7 6 1 3 5]
train groups [0 0 0 2 4]

>python sshuf.py 234

***
---
[[12 13]
 [ 4  5]
 [ 0  1]
 [ 2  3]
 [ 8  9]]
[6 2 0 1 4]
test groups [4 0 0 0 2]
train groups [7 3 1 5 6]
---
[[18 19]
 [10 11]
 [ 6  7]
 [14 15]
 [16 17]]
[9 5 3 7 8]
test groups [7 3 1 5 6]
train groups [4 0 0 0 2]
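The two runs above hold out the same groups ({0, 2, 4} and {1, 3, 5, 6, 7}), only in a different row order. A minimal check of that invariance, assuming GroupKFold's default (unshuffled) behaviour and reusing the toy data from the question; `held_out_group_sets` is a hypothetical helper written for this illustration:

```python
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.utils import shuffle

X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])


def held_out_group_sets(X, y, groups):
    """Hypothetical helper: the set of group labels held out by each fold, ignoring fold order."""
    gkf = GroupKFold(n_splits=2)
    return {frozenset(groups[test].tolist()) for _, test in gkf.split(X, y, groups)}


baseline = held_out_group_sets(X, y, groups)
same_partition = {}
for seed in (32, 234, 0, 1):
    # Permute the rows (keeping X, y and groups aligned) and re-split.
    X_s, y_s, groups_s = shuffle(X, y, groups, random_state=seed)
    same_partition[seed] = held_out_group_sets(X_s, y_s, groups_s) == baseline

print(same_partition)
```

Every seed should report True: row order does not enter into which groups end up together.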
asked Jan 25 '17 at 19:01 by user0


3 Answers

  • KFold is only randomized if shuffle=True. Some datasets should not be shuffled.
  • GroupKFold is not randomized at all. Hence the random_state=None.
  • GroupShuffleSplit may be closer to what you're looking for.
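A small sketch of the first two points, assuming a scikit-learn version where GroupKFold is left at its default arguments: KFold with shuffle=True is seeded and repeatable, while GroupKFold returns the same folds on every call. `fold_test_indices` is a hypothetical helper for comparing runs.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, KFold

X = np.arange(20).reshape(10, 2)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])


def fold_test_indices(cv, groups=None):
    """Hypothetical helper: test indices of every fold as hashable tuples."""
    return tuple(tuple(test.tolist()) for _, test in cv.split(X, groups=groups))


# KFold with shuffle=True: the seed selects the folds, and reusing a seed repeats them.
kfold_runs = {seed: fold_test_indices(KFold(n_splits=2, shuffle=True, random_state=seed))
              for seed in range(10)}

# GroupKFold with default arguments has no randomness to seed, so two calls agree.
gkf_first = fold_test_indices(GroupKFold(n_splits=2), groups)
gkf_second = fold_test_indices(GroupKFold(n_splits=2), groups)

print("distinct KFold splits over 10 seeds:", len(set(kfold_runs.values())))
print("GroupKFold identical across calls:", gkf_first == gkf_second)
```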

A comparison of the group-based splitters:

  • In GroupKFold, the test sets form a complete partition of all the data.
  • LeavePGroupsOut leaves out every possible subset of P groups, combinatorially; test sets will overlap for P > 1. Since this means C(n_groups, P) (n_groups choose P) splits altogether, you often want a small P, and most often LeaveOneGroupOut, which is basically the same as GroupKFold with n_splits equal to the number of groups.
  • GroupShuffleSplit makes no statement about the relationship between successive test sets; each train/test split is performed independently.
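The split counts and the seeding behaviour described above can be checked on the question's toy data. This is a sketch; `gss_test_groups` is a hypothetical helper, and test_size=0.25 (a fraction of the 8 groups, so 2 groups per test set) is an arbitrary choice.

```python
from math import comb

import numpy as np
from sklearn.model_selection import GroupShuffleSplit, LeaveOneGroupOut, LeavePGroupsOut

X = np.arange(20).reshape(10, 2)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
n_groups = len(np.unique(groups))  # 8 distinct groups

# LeavePGroupsOut enumerates every subset of P groups: C(n_groups, P) splits.
lpgo_counts = {p: LeavePGroupsOut(n_groups=p).get_n_splits(X, groups=groups)
               for p in (1, 2, 3)}

# LeaveOneGroupOut yields one split per group (the P = 1 case).
logo_count = LeaveOneGroupOut().get_n_splits(X, groups=groups)


def gss_test_groups(seed):
    """Hypothetical helper: groups held out by 3 independent GroupShuffleSplit splits."""
    gss = GroupShuffleSplit(n_splits=3, test_size=0.25, random_state=seed)
    return tuple(tuple(np.unique(groups[test]).tolist())
                 for _, test in gss.split(X, groups=groups))


print(lpgo_counts, logo_count)
print(gss_test_groups(42))  # same seed -> same splits on every run
print(gss_test_groups(13))  # a different seed will generally give different splits
```

Note that, unlike GroupKFold, the three GroupShuffleSplit test sets may share groups with each other.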

As an aside, Dmytro Lituiev has proposed an alternative GroupShuffleSplit algorithm which is better at getting the right number of samples (not merely the right number of groups) in the test set for a specified test_size.

answered Nov 17 '22 at 10:11 by joeln


Inspired by user0's answer (I can't comment yet), but faster:

import numpy as np
import pandas as pd


def RandomGroupKFold_split(groups, n, seed=None):  # noqa: N802
    """
    Random analogue of sklearn.model_selection.GroupKFold.split.

    :return: list of (train, test) indices
    """
    groups = pd.Series(groups)
    ix = np.arange(len(groups))
    unique = np.unique(groups)
    # Seeded shuffle of the group labels: reproducible, but different per seed.
    np.random.RandomState(seed).shuffle(unique)
    result = []
    # Each chunk of shuffled labels becomes one test fold.
    for split in np.array_split(unique, n):
        mask = groups.isin(split).to_numpy()
        train, test = ix[~mask], ix[mask]
        result.append((train, test))

    return result
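The same idea (shuffle the unique group labels with a seeded RNG, then cut them into n chunks) can be written without pandas. The sketch below is a variant, not the answer's code: `random_group_kfold` is a hypothetical name, and np.isin and NumPy's default_rng Generator are substitutions for pandas' isin and RandomState.

```python
import numpy as np


def random_group_kfold(groups, n_splits, seed=None):
    """Hypothetical pandas-free variant: shuffle group labels, chunk them into folds."""
    groups = np.asarray(groups)
    unique = np.unique(groups)               # sorted, so only the seed decides the order
    rng = np.random.default_rng(seed)
    rng.shuffle(unique)                      # reproducible for a fixed seed
    ix = np.arange(len(groups))
    result = []
    for fold_groups in np.array_split(unique, n_splits):
        mask = np.isin(groups, fold_groups)  # samples whose group is in this fold
        result.append((ix[~mask], ix[mask]))
    return result


groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
for seed in (42, 13):
    for train, test in random_group_kfold(groups, n_splits=2, seed=seed):
        print(seed, "test groups:", np.unique(groups[test]))
```

As with the original, np.array_split balances the number of groups per fold, not the number of samples.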
answered Nov 17 '22 at 11:11 by xrr


My solution so far has been to split the groups randomly. This could lead to very unbalanced folds (which I think GroupKFold was designed to prevent), but the hope is that the number of observations per group is small.
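In the outputs shown for this answer, the test folds hold 3, 5 and 2 samples for random state 4. If that imbalance matters, one option, offered here as an assumption rather than part of this answer, is to keep the seeded shuffle only as a tie-breaker and otherwise balance sample counts greedily: place the largest remaining group into the fold with the fewest samples so far. `balanced_random_group_folds` is a hypothetical helper.

```python
import numpy as np


def balanced_random_group_folds(groups, n_splits, seed=None):
    """Hypothetical helper: seeded group folds with roughly equal sample counts.

    Groups are put in a random (seeded) order, then stably sorted by size,
    largest first; each group goes to the fold holding the fewest samples so
    far. The seed therefore only changes how equal-sized groups are dealt out.
    """
    labels, inverse, counts = np.unique(groups, return_inverse=True, return_counts=True)
    order = np.random.RandomState(seed).permutation(len(labels))  # random tie-breaking
    order = order[np.argsort(-counts[order], kind="stable")]     # largest groups first
    fold_sizes = np.zeros(n_splits, dtype=int)
    fold_of_group = np.empty(len(labels), dtype=int)
    for g in order:
        k = int(np.argmin(fold_sizes))  # currently lightest fold
        fold_of_group[g] = k
        fold_sizes[k] += counts[g]
    fold_ids = fold_of_group[inverse]   # fold of each sample
    ix = np.arange(len(fold_ids))
    return [(ix[fold_ids != k], ix[fold_ids == k]) for k in range(n_splits)]


groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
for seed in (4, 5):
    splits = balanced_random_group_folds(groups, n_splits=3, seed=seed)
    print(seed, [len(test) for _, test in splits])
```

On this data every seed gives test folds of 4, 3 and 3 samples; the seed only decides which singleton groups share a fold.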

from numpy.random import RandomState
import numpy as np
import sys

random_state = int(sys.argv[1])

X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
for el in zip(range(len(y)), X, y, groups):
    print("ix, X, y, groups", el)

def RandGroupKfold(groups, n_splits, random_state=None, shuffle_groups=False):
    """Split samples into n_splits folds so that no group spans train and test."""
    ix = np.arange(len(groups))
    unique_groups = np.unique(groups)
    if shuffle_groups:
        # Seeded shuffle of the group labels: reproducible, but differs per seed.
        prng = RandomState(random_state)
        prng.shuffle(unique_groups)
    # Chunk the (possibly shuffled) labels; each chunk is one test fold.
    splits = np.array_split(unique_groups, n_splits)
    train_test_indices = []

    for split in splits:
        mask = np.isin(groups, split)  # True where the sample's group is in this chunk
        train = ix[~mask]
        test = ix[mask]
        train_test_indices.append((train, test))
    return train_test_indices

splits = RandGroupKfold(groups, n_splits=3, random_state=random_state, shuffle_groups=True)

for train, test in splits:
    print("---")
    for el in zip(train, X[train], y[train], groups[train]):
        print("train ix, X, y, groups", el)
    for el in zip(test, X[test], y[test], groups[test]):
        print("test ix, X, y, groups", el)

Data:

ix, X, y, groups (0, array([0, 1]), 0, 0)
ix, X, y, groups (1, array([2, 3]), 1, 0)
ix, X, y, groups (2, array([4, 5]), 2, 0)
ix, X, y, groups (3, array([6, 7]), 3, 1)
ix, X, y, groups (4, array([8, 9]), 4, 2)
ix, X, y, groups (5, array([10, 11]), 5, 3)
ix, X, y, groups (6, array([12, 13]), 6, 4)
ix, X, y, groups (7, array([14, 15]), 7, 5)
ix, X, y, groups (8, array([16, 17]), 8, 6)
ix, X, y, groups (9, array([18, 19]), 9, 7)

With random state 4:

---
train ix, X, y, groups (0, array([0, 1]), 0, 0)
train ix, X, y, groups (1, array([2, 3]), 1, 0)
train ix, X, y, groups (2, array([4, 5]), 2, 0)
train ix, X, y, groups (3, array([6, 7]), 3, 1)
train ix, X, y, groups (4, array([8, 9]), 4, 2)
train ix, X, y, groups (7, array([14, 15]), 7, 5)
train ix, X, y, groups (8, array([16, 17]), 8, 6)
test ix, X, y, groups (5, array([10, 11]), 5, 3)
test ix, X, y, groups (6, array([12, 13]), 6, 4)
test ix, X, y, groups (9, array([18, 19]), 9, 7)
---
train ix, X, y, groups (4, array([8, 9]), 4, 2)
train ix, X, y, groups (5, array([10, 11]), 5, 3)
train ix, X, y, groups (6, array([12, 13]), 6, 4)
train ix, X, y, groups (8, array([16, 17]), 8, 6)
train ix, X, y, groups (9, array([18, 19]), 9, 7)
test ix, X, y, groups (0, array([0, 1]), 0, 0)
test ix, X, y, groups (1, array([2, 3]), 1, 0)
test ix, X, y, groups (2, array([4, 5]), 2, 0)
test ix, X, y, groups (3, array([6, 7]), 3, 1)
test ix, X, y, groups (7, array([14, 15]), 7, 5)
---
train ix, X, y, groups (0, array([0, 1]), 0, 0)
train ix, X, y, groups (1, array([2, 3]), 1, 0)
train ix, X, y, groups (2, array([4, 5]), 2, 0)
train ix, X, y, groups (3, array([6, 7]), 3, 1)
train ix, X, y, groups (5, array([10, 11]), 5, 3)
train ix, X, y, groups (6, array([12, 13]), 6, 4)
train ix, X, y, groups (7, array([14, 15]), 7, 5)
train ix, X, y, groups (9, array([18, 19]), 9, 7)
test ix, X, y, groups (4, array([8, 9]), 4, 2)
test ix, X, y, groups (8, array([16, 17]), 8, 6)

With random state 5:

---
train ix, X, y, groups (0, array([0, 1]), 0, 0)
train ix, X, y, groups (1, array([2, 3]), 1, 0)
train ix, X, y, groups (2, array([4, 5]), 2, 0)
train ix, X, y, groups (3, array([6, 7]), 3, 1)
train ix, X, y, groups (5, array([10, 11]), 5, 3)
train ix, X, y, groups (7, array([14, 15]), 7, 5)
train ix, X, y, groups (8, array([16, 17]), 8, 6)
test ix, X, y, groups (4, array([8, 9]), 4, 2)
test ix, X, y, groups (6, array([12, 13]), 6, 4)
test ix, X, y, groups (9, array([18, 19]), 9, 7)
---
train ix, X, y, groups (4, array([8, 9]), 4, 2)
train ix, X, y, groups (5, array([10, 11]), 5, 3)
train ix, X, y, groups (6, array([12, 13]), 6, 4)
train ix, X, y, groups (8, array([16, 17]), 8, 6)
train ix, X, y, groups (9, array([18, 19]), 9, 7)
test ix, X, y, groups (0, array([0, 1]), 0, 0)
test ix, X, y, groups (1, array([2, 3]), 1, 0)
test ix, X, y, groups (2, array([4, 5]), 2, 0)
test ix, X, y, groups (3, array([6, 7]), 3, 1)
test ix, X, y, groups (7, array([14, 15]), 7, 5)
---
train ix, X, y, groups (0, array([0, 1]), 0, 0)
train ix, X, y, groups (1, array([2, 3]), 1, 0)
train ix, X, y, groups (2, array([4, 5]), 2, 0)
train ix, X, y, groups (3, array([6, 7]), 3, 1)
train ix, X, y, groups (4, array([8, 9]), 4, 2)
train ix, X, y, groups (6, array([12, 13]), 6, 4)
train ix, X, y, groups (7, array([14, 15]), 7, 5)
train ix, X, y, groups (9, array([18, 19]), 9, 7)
test ix, X, y, groups (5, array([10, 11]), 5, 3)
test ix, X, y, groups (8, array([16, 17]), 8, 6)
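Since scikit-learn's cross_val_score accepts an iterable of (train, test) index pairs as its cv argument, splits like the ones above can be fed straight into it, once per seed, to get several reproducible but distinct cross-validation trials. A sketch under these assumptions: `random_group_splits` is a hypothetical generator following the same shuffle-then-chunk idea as RandGroupKfold, and LinearRegression with neg_mean_absolute_error is an arbitrary model and metric choice.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score


def random_group_splits(groups, n_splits, seed):
    """Hypothetical generator: shuffle unique groups with a seed, chunk into test folds."""
    shuffled = np.random.RandomState(seed).permutation(np.unique(groups))
    idx = np.arange(len(groups))
    for fold_groups in np.array_split(shuffled, n_splits):
        test_mask = np.isin(groups, fold_groups)
        yield idx[~test_mask], idx[test_mask]


X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])

scores_by_seed = {}
for seed in (4, 5):  # one cross-validation trial per seed
    splits = list(random_group_splits(groups, n_splits=3, seed=seed))
    scores_by_seed[seed] = cross_val_score(
        LinearRegression(), X, y, cv=splits, scoring="neg_mean_absolute_error"
    )

print(scores_by_seed)
```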
answered Nov 17 '22 at 10:11 by user0