Efficiently filling a multidimensional array that needs many if/else statements

I want to fill a 4D numpy array in a specific and efficient way. Because I didn't know a better way, I started writing the code with if/else statements, but that doesn't look nice, is probably slow, and I also cannot be really sure that I have thought of every combination. Here is the code, up to where I stopped writing:

import numpy

# gn and ewp are assumed to be defined earlier
sercnew2 = numpy.zeros((gn, gn, gn, gn))
for x1 in range(gn):
    for x2 in range(gn):
        for x3 in range(gn):
            for x4 in range(gn):
                if x1 == x2 == x3 == x4:
                    sercnew2[x1, x2, x3, x4] = ewp[x1]
                elif x1 == x2 == x3 != x4:
                    sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x4]
                elif x1 == x2 == x4 != x3:
                    sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x3]
                elif x1 == x3 == x4 != x2:
                    sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x2]
                elif x2 == x3 == x4 != x1:
                    sercnew2[x1, x2, x3, x4] = ewp[x2] * ewp[x1]
                elif x1 == x2 != x3 == x4:
                    sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x3]
                # elif ...: many more combinations which have to be considered

Basically, if all the variables (x1, x2, x3, x4) are different from each other, the entry should be:

sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x2] * ewp[x3] * ewp[x4]

Now if, let's say, x2 and x4 are the same, then:

sercnew2[x1, x2, x3, x4] = ewp[x1] * ewp[x2] * ewp[x3]

Other examples can be seen in the code above. Basically, if two or more variables are the same, then I only consider one of them. I hope the pattern is clear; otherwise please let me know and I will try to express my problem better. I am pretty sure there is a much more intelligent way to do it. Hope you know better, and thanks in advance :)
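The rule described above can be read as "multiply ewp over the set of distinct indices", which covers every combination of equalities without writing a branch for each one. Below is a minimal sketch under that reading; the helper names distinct_product and fill are hypothetical, and the ewp values are made up for illustration. It is still a plain Python loop over all gn**4 entries, so it addresses correctness rather than speed.

```python
import numpy as np


def distinct_product(x, ewp):
    """Hypothetical helper: product of ewp over the distinct indices in x.

    set(x) keeps each repeated index only once, so all patterns
    (all equal, three equal, two pairs, all different, ...) are handled
    by the same expression.
    """
    return np.prod([ewp[i] for i in set(x)])


def fill(ewp, n_dims=4):
    """Hypothetical loop-based fill of the n_dims-dimensional array."""
    gn = len(ewp)
    out = np.zeros((gn,) * n_dims)
    for x in np.ndindex(out.shape):  # every (x1, ..., xn) index tuple
        out[x] = distinct_product(x, ewp)
    return out


ewp = np.array([2.0, 3.0, 5.0, 7.0])  # example values only
sercnew2 = fill(ewp)
print(sercnew2[0, 1, 2, 3])  # all different: 2 * 3 * 5 * 7
print(sercnew2[0, 1, 2, 1])  # x2 == x4: 2 * 3 * 5
```

Because the rule no longer depends on listing cases, the same function also works for a different number of dimensions via n_dims.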

HighwayJohn asked Dec 31 '25 18:12


1 Answer

Really hoping I've got it! Here's a vectorized approach:

from itertools import product

import numpy as np

# gn and ewp as in the question
n_dims = 4  # Number of dims

# Create a 2D array with every possible combination of the X's as rows,
# then sort each row so that repeated indices end up next to each other
idx = np.sort(np.array(list(product(np.arange(gn), repeat=n_dims))), axis=1)

# Get all X's indexed values from ewp array
vals = ewp[idx]

# Set the duplicates along each row to 1. With the np.prod coming up next,
# these 1s do not affect the result, which is the expected pattern here.
vals[:, 1:][idx[:, 1:] == idx[:, :-1]] = 1

# Perform product along each row and reshape into multi-dim array
out = vals.prod(1).reshape([gn] * n_dims)
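A quick consistency check, offered as a sketch: the snippet below wraps the answer's steps in a function and compares the result with a brute-force loop that multiplies ewp over the distinct indices of each entry. The names fill_vectorized and fill_loop are hypothetical wrappers (not from the answer), and ewp is filled with random test values.

```python
from itertools import product

import numpy as np


def fill_vectorized(ewp, n_dims=4):
    """The answer's steps, wrapped in a hypothetical function."""
    ewp = np.asarray(ewp, dtype=float)
    gn = len(ewp)
    # All index tuples in C order (matching reshape), each row sorted
    idx = np.sort(np.array(list(product(np.arange(gn), repeat=n_dims))), axis=1)
    vals = ewp[idx]
    # After sorting, repeats are adjacent: replace every repeat with 1
    vals[:, 1:][idx[:, 1:] == idx[:, :-1]] = 1
    return vals.prod(1).reshape([gn] * n_dims)


def fill_loop(ewp, n_dims=4):
    """Hypothetical brute-force reference: product over distinct indices."""
    ewp = np.asarray(ewp, dtype=float)
    gn = len(ewp)
    out = np.zeros([gn] * n_dims)
    for x in np.ndindex(out.shape):
        out[x] = np.prod(ewp[sorted(set(x))])
    return out


rng = np.random.default_rng(0)
ewp = rng.random(5)  # random test values, not from the question
assert np.allclose(fill_vectorized(ewp), fill_loop(ewp))
```

Note that idx holds gn**n_dims rows of n_dims integers, so for the 4D case memory grows with the fourth power of gn; for large gn that may become the limiting factor rather than speed.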
Divakar answered Jan 03 '26 07:01