 

What does X_set[y_set == j, 0] mean?

Recently I have been following a tutorial, in which I came across the following code:

for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
        c = ListedColormap(('red', 'green'))(i), label = j)

Here, y_set is a vector of binary values (0 and 1) and X_set is an array with two columns. Specifically, I don't understand how to interpret the following code:

X_set[y_set == j, 0], X_set[y_set == j, 1]
asked Jul 31 '18 12:07 by Becky


1 Answer

There are a few things going on here. For now I will drop the loop, but we know that j takes the unique values in y_set, so it will be either zero or one. First, make the two arrays:

import numpy as np

X_set = np.arange(20).reshape(10, 2)
y_set = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0, 1])

With the loop stripped away, the plotting line in the question is basically doing:

plt.scatter(filtered_values_in_first_column_of_X_set,
            filtered_values_in_second_column_of_X_set)
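To make that pseudocode concrete, here is a small sketch using the example arrays above with j fixed to 0. The names mask, xs and ys are illustrative and not from the tutorial. It also checks that the single index X_set[mask, 0] selects the same values as the two-step X_set[mask][:, 0].

```python
import numpy as np

X_set = np.arange(20).reshape(10, 2)
y_set = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0, 1])

j = 0
mask = y_set == j            # Boolean array with one entry per row of X_set

# Row filter and column index combined in a single indexing operation
xs = X_set[mask, 0]          # first column of the rows where y_set == j
ys = X_set[mask, 1]          # second column of the same rows

# Equivalent two-step version: filter the rows first, then take a column
assert np.array_equal(xs, X_set[mask][:, 0])
assert np.array_equal(ys, X_set[mask][:, 1])

print(xs.tolist(), ys.tolist())  # [0, 8, 10, 16] [1, 9, 11, 17]
```

These two arrays are exactly what ends up as the x and y arguments of plt.scatter for the class j = 0.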

y_set provides the filter. We can see how by building up the steps one at a time:

print("Where y_set == 0: Boolean mask.")
print(y_set == 0)
print()

print("All rows of X_set indexed by the Boolean mask")
print(X_set[y_set == 0])
print()

print("2D indexing to get only the first column of the above")
print(X_set[y_set == 0, 0])
print()
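One way to follow what each of those steps does is to look at the shapes, since that is what changes from one step to the next. This sketch reuses the same example arrays; rows and col0 are illustrative names.

```python
import numpy as np

X_set = np.arange(20).reshape(10, 2)
y_set = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0, 1])

mask = y_set == 0         # one True/False per row of X_set
rows = X_set[mask]        # keeps only the rows where the mask is True
col0 = X_set[mask, 0]     # ...and then only column 0 of those rows

print(mask.shape, rows.shape, col0.shape)  # (10,) (4, 2) (4,)
print(rows.tolist())  # [[0, 1], [8, 9], [10, 11], [16, 17]]
```

The mask has four True entries, so four rows survive; adding the column index 0 then drops the second axis and leaves a 1D array.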

You can read more about boolean array indexing in the NumPy indexing documentation. Once you break the steps down, it's not too complicated, but I think it is an unnecessarily complex way of achieving this task.

The for loop is there so that the plot is repeated with two different colours: one pass for the rows where y_set equals 0, and one for the rows where it equals 1.
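To see what each pass of the loop hands to plt.scatter, here is a sketch that replaces the plotting call with a print. It assumes the same example arrays; the colour names are copied from the ListedColormap in the question (calling ListedColormap(('red', 'green'))(i) would give the RGBA tuple for that colour rather than its name), and the groups dict is purely illustrative.

```python
import numpy as np

X_set = np.arange(20).reshape(10, 2)
y_set = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0, 1])
colours = ('red', 'green')  # same order as the ListedColormap in the question

groups = {}
# np.unique(y_set) is array([0, 1]), so enumerate yields (0, 0) then (1, 1):
# i picks the colour, j is the class value used for the mask and the label
for i, j in enumerate(np.unique(y_set)):
    mask = y_set == j
    xs, ys = X_set[mask, 0], X_set[mask, 1]
    groups[int(j)] = (colours[i], xs.tolist(), ys.tolist())
    print(int(j), colours[i], xs.tolist(), ys.tolist())

# 0 red [0, 8, 10, 16] [1, 9, 11, 17]
# 1 green [2, 4, 6, 12, 14, 18] [3, 5, 7, 13, 15, 19]
```

In the real code each iteration draws one coloured group of points, and label=j gives each group its own legend entry.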

answered Sep 30 '22 11:09 by roganjosh