Is there a simple/clean way to iterate over the array of axes returned by subplots, like
    import matplotlib.pyplot as plt

    nrow = ncol = 2
    a = []
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    # collect every Axes from the 2-D array into a flat list
    for i, row in enumerate(axs):
        for j, ax in enumerate(row):
            a.append(ax)
    for i, ax in enumerate(a):
        ax.set_ylabel(str(i))
which also works when nrow or ncol == 1?
I tried a list comprehension like:
[element for tupl in tupleOfTuples for element in tupl]
but that fails if nrows or ncols == 1.
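To see why the flattening breaks, here is a small sketch that reports what plt.subplots returns for each grid shape. It assumes matplotlib's non-interactive Agg backend, and describe_subplots is a hypothetical helper written only for this illustration. With the default squeeze=True, a 2×2 grid comes back as a 2-D array, a 1×N or N×1 grid as a 1-D array whose elements are Axes rather than rows, and a 1×1 grid as a bare Axes object. In the last two cases the inner loop or comprehension has no second level to iterate over.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def describe_subplots(nrows, ncols):
    """Hypothetical helper: describe what plt.subplots returns for this grid."""
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
    plt.close(fig)  # only the return value matters here, not the figure
    if isinstance(axs, np.ndarray):
        return f"ndarray with shape {axs.shape}"
    return type(axs).__name__  # a single Axes object, not an array


for shape in [(2, 2), (1, 2), (2, 1), (1, 1)]:
    print(shape, "->", describe_subplots(*shape))
```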
The axs return value is a numpy array, which can be reshaped, I believe, without any copying of the data. If you use the following, you'll get a flat (1-D) array that you can iterate over cleanly.
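    import matplotlib.pyplot as plt

    nrow = 1
    ncol = 2
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    # reshape(-1) flattens any 1-D or 2-D array of axes to 1-D
    for i, ax in enumerate(axs.reshape(-1)):
        ax.set_ylabel(str(i))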
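For comparison, a sketch (again assuming the Agg backend) showing that reshape(-1), ravel() and the .flat iterator all visit a 2×3 grid of axes in the same row-major order: left to right along the top row, then the next row. Which one to use is mostly a matter of taste; .flat avoids building a new array at all.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=3)

# Three ways to get a flat, row-major sequence of the 6 axes.
by_reshape = list(axs.reshape(-1))
by_ravel = list(axs.ravel())
by_flat = list(axs.flat)  # .flat is an iterator over the existing array

for i, ax in enumerate(by_reshape):
    ax.set_ylabel(str(i))  # labels 0..5, left to right, top to bottom

print(by_reshape == by_ravel == by_flat)
plt.close(fig)
```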
This doesn't hold when nrows and ncols are both 1, since the return value is then a single Axes object rather than an array. You could wrap the return value in a one-element array for consistency, though it feels a bit like a kludge:
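    import matplotlib.pyplot as plt
    import numpy as np

    nrow = 1
    ncol = 1
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    axs = np.array(axs)  # wrap the single Axes so reshape(-1) is available
    for i, ax in enumerate(axs.reshape(-1)):
        ax.set_ylabel(str(i))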
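An alternative that avoids the wrapping step is plt.subplots' own squeeze parameter: with squeeze=False it always returns a 2-D array of Axes, even for a 1×1 grid, so iterating over axs.flat works for every grid size. The sketch below assumes the Agg backend, and label_all_axes is a hypothetical helper name used only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt


def label_all_axes(nrow, ncol):
    """Hypothetical helper: label every subplot with its index, for any grid size."""
    # squeeze=False keeps the result 2-D, so .flat works even for 1x1 grids
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol, squeeze=False)
    for i, ax in enumerate(axs.flat):
        ax.set_ylabel(str(i))
    return fig, axs


for nrow, ncol in [(1, 1), (1, 2), (2, 1), (2, 2)]:
    fig, axs = label_all_axes(nrow, ncol)
    print((nrow, ncol), axs.shape, [ax.get_ylabel() for ax in axs.flat])
    plt.close(fig)
```

The trade-off is that code indexing a single subplot must then write axs[0, 0] instead of using axs directly.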
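See the reshape docs. The argument -1 tells reshape to infer the size of that dimension from the total number of elements, so reshape(-1) on its own always yields a flat 1-D array.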
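A plain-numpy sketch of how -1 is inferred, independent of matplotlib. It also shows that reshaping a contiguous array returns a view (which backs the "without any copying" remark above), that a 0-d array, such as np.array(axs) for a single Axes, becomes a length-1 array, and that the sizes must divide evenly.

```python
import numpy as np

grid = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

flat = grid.reshape(-1)             # a lone -1 is inferred as 6
print(flat)                         # [0 1 2 3 4 5]

tall = grid.reshape(3, -1)          # -1 is inferred as 6 // 3 == 2
print(tall.shape)                   # (3, 2)

# For a contiguous array, reshape returns a view rather than a copy.
print(np.shares_memory(grid, flat)) # True

# A 0-d array (one element, no dimensions) becomes a length-1 array.
single = np.array(42).reshape(-1)
print(single.shape)                 # (1,)

# Only one dimension may be -1, and the total size must divide evenly.
try:
    grid.reshape(4, -1)
except ValueError as err:
    print("ValueError:", err)
```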