
Python itertools product, but conditional?

I have a function fun that takes several parameters p0, p1, ... For each parameter I give a list of possible values:

p0_list = ['a','b','c']
p1_list = [5,100]

I can now call my function for every combination of p0 and p1:

import itertools

for combo in itertools.product(p0_list, p1_list):
    print(fun(*combo))  # unpack the tuple into the separate parameters p0, p1

Now comes the problem: what if I already know that the parameter p1 only has an effect on the result of fun if p0 is 'a' or 'c'? In this case I need my list of parameter combinations to look like:

[('a', 5), ('a', 100), ('b', 5), ('c', 5), ('c', 100)]

So ('b', 100) is simply omitted, as it would be an unnecessary evaluation of fun.

My final solution:

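import itertools

param_lists = [['p0', ['a', 'b', 'c']], ['p1', [5, 100]]]
# Build the full product, then drop the combinations where p1 cannot matter.
combos = itertools.product(*[values for _, values in param_lists])
combos = [c for c in combos if c[0] != 'b' or c[1] == 5]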

I used this approach for 5 parameters and various conditions and it works fine. It's pretty easy to read as well. This code is inspired by Corley Brigman's and nmclean's answers.
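With more parameters, the same filter can be kept lazy and the conditions listed as separate predicates. This is only a minimal sketch: the helper name conditional_product and the rules list are hypothetical and not part of the original code.

```python
import itertools

def conditional_product(value_lists, rules):
    """Yield each combination from the product that passes every rule."""
    for combo in itertools.product(*value_lists):
        if all(rule(combo) for rule in rules):
            yield combo

p0_list = ['a', 'b', 'c']
p1_list = [5, 100]

# Rule: when p0 is 'b', p1 has no effect, so keep only its first value.
rules = [lambda c: c[0] != 'b' or c[1] == p1_list[0]]

combos = list(conditional_product([p0_list, p1_list], rules))
print(combos)
```

Each additional condition becomes one more entry in rules, so the list comprehension does not grow into a long boolean expression.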

HeinzKurt asked Oct 02 '22 17:10



2 Answers

Here's a general filter function that could work for this:

def without_duplicate_item(groups, index, item):
    seen = False
    for group in groups:
        if group[index] == item:
            if seen:
                continue
            seen = True
        yield group

Usage:

param_groups = itertools.product(p0_list, p1_list)

param_groups = without_duplicate_item(param_groups, 0, "b")

You can of course keep adding filters for different parameters. This should be quite memory-efficient, compared to storing previous calls, because it essentially only stores one boolean value seen per filter.
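As a rough illustration of stacking filters, the sketch below chains two of them. The value 'd' is a hypothetical extra entry, assumed to ignore p1 just like 'b'; it is not part of the question. The function is repeated so the example runs on its own.

```python
import itertools

def without_duplicate_item(groups, index, item):
    # Pass through only the first group whose element at `index` equals `item`.
    seen = False
    for group in groups:
        if group[index] == item:
            if seen:
                continue
            seen = True
        yield group

p0_list = ['a', 'b', 'c', 'd']  # 'd' is a hypothetical extra value
p1_list = [5, 100]

param_groups = itertools.product(p0_list, p1_list)
param_groups = without_duplicate_item(param_groups, 0, 'b')
param_groups = without_duplicate_item(param_groups, 0, 'd')

result = list(param_groups)
print(result)
```

Note that each filter keeps exactly one group for its item, the first one in iteration order. With more than two parameters that is only what you want if every other parameter is irrelevant for that item.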

nmclean answered Oct 05 '22 12:10



You could generate the full product and then suppress the 'duplicates', but it is probably better to generate the two groups separately:

p0_list = ['a', 'b', 'c']
p0_noarg1 = ['b']
sp0_noarg1 = set(p0_noarg1)
p0_arg1 = [x for x in p0_list if x not in sp0_noarg1]
p1_list = [5, 100]

total_list = list(itertools.product(p0_arg1, p1_list)) + list(itertools.product(p0_noarg1, p1_list[:1]))
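The same idea can be wrapped in a small function and kept lazy with itertools.chain. This is a sketch under assumptions: the name split_product and its argument names are hypothetical, and it handles only the two-parameter case from the question.

```python
import itertools

def split_product(p0_values, p1_values, p0_ignoring_p1):
    """Chain two products: all p1 values where p1 matters, one value where it does not."""
    ignoring = set(p0_ignoring_p1)
    with_p1 = [x for x in p0_values if x not in ignoring]
    without_p1 = [x for x in p0_values if x in ignoring]
    return itertools.chain(
        itertools.product(with_p1, p1_values),
        itertools.product(without_p1, p1_values[:1]),
    )

total_list = list(split_product(['a', 'b', 'c'], [5, 100], ['b']))
print(total_list)
```

The combinations for 'b' come last here rather than in their original position, so sort the result if the order matters.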
Corley Brigman answered Oct 05 '22 13:10
