ImportError: cannot import name 'StratifiedGroupKFold' from 'sklearn.model_selection'

I'm getting an ImportError when I try to import the StratifiedGroupKFold splitter from sklearn.

I noticed that the nightly build is required to use it, and I've installed it, yet I still get the error. Any suggestions on how to solve this are welcome.
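StratifiedGroupKFold is documented as new in scikit-learn 1.0, so a quick first check is whether the installed version string is at least 1.0. The helper below is a minimal sketch: `supports_stratified_group_kfold` is a hypothetical name, not part of scikit-learn, and it compares only the leading major.minor numbers. It therefore treats any 1.0 dev build as supported, even though an early nightly could predate the feature.

```python
import re


def supports_stratified_group_kfold(version: str) -> bool:
    """Return True if `version` is at least 1.0 (hypothetical helper).

    Only the leading "major.minor" numbers are compared, so a development
    build such as "1.0.dev0" counts as 1.0. An early nightly may still lack
    the class, so read True as "probably", not "certainly".
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 0)


if __name__ == "__main__":
    try:
        import sklearn
    except ImportError:
        print("scikit-learn is not importable from this interpreter")
    else:
        version = sklearn.__version__
        print(version, supports_stratified_group_kfold(version))
```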

Edit: I tried updating conda and then scikit-learn, and it still did not work. When I install the nightly build, pip says the requirements are already satisfied.

Do I need to uninstall scikit-learn and then install the nightly build? Am I missing something here?
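The "Requirement already satisfied" message is consistent with pip's default behaviour of leaving an installed package alone unless it is asked to upgrade. A minimal sketch, assuming pip is the installer: build the command with `sys.executable -m pip` so it targets the interpreter that actually runs the code, and pass `--pre --upgrade` so a dev build can replace the installed release. The helper name `nightly_install_command` is hypothetical, and the index URL is a placeholder to be copied from scikit-learn's installation documentation rather than a real address.

```python
import subprocess
import sys
from typing import List

# Placeholder: copy the nightly-wheels index URL from scikit-learn's
# installation documentation; it is deliberately not hard-coded here.
NIGHTLY_INDEX_URL = "<nightly-index-url-from-the-scikit-learn-docs>"


def nightly_install_command(index_url: str) -> List[str]:
    """Build a pip command for the interpreter running this script (hypothetical helper).

    --pre lets pip consider pre-release (dev) versions, and --upgrade makes it
    replace an already installed release instead of reporting
    "Requirement already satisfied".
    """
    return [
        sys.executable, "-m", "pip", "install",
        "--pre", "--upgrade",
        "--extra-index-url", index_url,
        "scikit-learn",
    ]


if __name__ == "__main__":
    cmd = nightly_install_command(NIGHTLY_INDEX_URL)
    print(" ".join(cmd))
    # Only run pip once the placeholder has been replaced with a real URL.
    if not NIGHTLY_INDEX_URL.startswith("<"):
        subprocess.run(cmd, check=True)
```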

Edit 2: I created a new virtual environment and installed the nightly build in it. There, the import works fine in the terminal and in Spyder, but for some strange reason it does not work in Jupyter Notebook (which is also installed in the same virtual environment).
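One common explanation for "works in the terminal, fails in Jupyter" is that the notebook kernel runs a different Python interpreter than the shell, even when Jupyter itself is installed in the environment. A quick way to check, sketched below with a hypothetical helper `describe_environment`, is to run the same snippet in the terminal and in a notebook cell and compare which interpreter and which scikit-learn each one sees.

```python
import importlib.util
import sys
from typing import Any, Dict


def describe_environment() -> Dict[str, Any]:
    """Report the running interpreter and the scikit-learn it can see (hypothetical helper)."""
    info: Dict[str, Any] = {
        "executable": sys.executable,
        "sklearn_version": None,
        "sklearn_path": None,
        "has_StratifiedGroupKFold": False,
    }
    if importlib.util.find_spec("sklearn") is None:
        return info  # scikit-learn is not installed for this interpreter
    import sklearn
    import sklearn.model_selection

    info["sklearn_version"] = sklearn.__version__
    info["sklearn_path"] = sklearn.__file__
    info["has_StratifiedGroupKFold"] = hasattr(
        sklearn.model_selection, "StratifiedGroupKFold"
    )
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If `executable` differs between the terminal and the notebook, the notebook is not using the new environment; selecting that environment's kernel (or registering it with `python -m ipykernel install --user --name <env-name>`) is one way to line them up.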

user42 asked Oct 19 '25 09:10



1 Answer

Rather than trying to install the nightly version, I would install from source. See the instructions here.

BTW, I don't see any issues when importing StratifiedGroupKFold from the dev version:

[ins] In [2]: import sklearn

[ins] In [3]: sklearn.__version__
Out[3]: '1.0.dev0'

[ins] In [4]: import numpy as np
         ...: from sklearn.model_selection import StratifiedGroupKFold

[ins] In [5]:
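Once the import succeeds, a small smoke test can confirm the splitter behaves as a group-aware k-fold: no group label should appear in both the train and test indices of a fold. The sketch below uses toy data of my own choosing (12 samples, 6 groups of 2, two balanced classes) and a hypothetical helper `groups_overlap`; it only calls `StratifiedGroupKFold(n_splits=...)` and its `split(X, y, groups)` method, and it prints the import error instead of crashing if scikit-learn is missing or too old.

```python
from typing import Sequence


def groups_overlap(train_idx: Sequence[int], test_idx: Sequence[int], groups: Sequence) -> bool:
    """Return True if any group label occurs in both the train and test indices (hypothetical helper)."""
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    return bool(train_groups & test_groups)


if __name__ == "__main__":
    try:
        import numpy as np
        from sklearn.model_selection import StratifiedGroupKFold
    except ImportError as exc:
        print(f"Import failed: {exc}")
    else:
        # Toy data: 12 samples, 6 groups of 2 samples, classes alternate 0/1.
        X = np.arange(24).reshape(12, 2)
        y = np.array([0, 1] * 6)
        groups = np.repeat(np.arange(6), 2)

        cv = StratifiedGroupKFold(n_splits=3)
        for fold, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
            print(fold, test_idx, "overlap:", groups_overlap(train_idx, test_idx, groups))
```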

Venkatachalam answered Oct 21 '25 23:10
