
How to type-hint a matplotlib.axes._subplots.AxesSubplot object in Python 3

I was wondering what the best way is to type-hint the axes object returned by matplotlib's subplots().

Running

from matplotlib import pyplot as plt

f, ax = plt.subplots()
print(type(ax))

returns

<class 'matplotlib.axes._subplots.AxesSubplot'>

and running

from matplotlib import axes
print(type(axes._subplots))
print(type(axes._subplots.AxesSubplots))

yields

<class 'module'>
AttributeError: module 'matplotlib.axes._subplots' has no attribute 'AxesSubplots'

So far, one type hint that works is the following:

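import numpy as np
from matplotlib import pyplot as plt
# roc_curve and auc are assumed to come from scikit-learn
from sklearn.metrics import auc, roc_curve

def multi_rocker(
    axy: type(plt.subplots()[1]),  # works, but opaque to readers
    y_trues: np.ndarray,
    y_preds: np.ndarray,
):
    """One-vs-all ROC curves: one curve per class column."""
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    n_classes = y_trues.shape[1]
    for i in range(n_classes):
        # Treat column i as a binary problem and compute its ROC curve.
        fpr[i], tpr[i], _ = roc_curve(y_trues[:, i], y_preds[:, i])
        roc_auc[i] = round(auc(fpr[i], tpr[i]), 2)
    for i in range(n_classes):
        axy.plot(fpr[i], tpr[i])
    return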

The problem is that this annotation is not clear enough for code that will be shared with others.

asked Sep 07 '20 19:09 by Gaston




2 Answers

As described in "Type hints for context manager":

import matplotlib.pyplot as plt

def plot_func(ax: plt.Axes):
    ...
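
A quick check, offered as a sketch rather than a statement about every matplotlib version, of why this annotation fits: `plt.Axes` is the same class object as `matplotlib.axes.Axes`, and the object returned by `plt.subplots()` is an instance of it, even when its concrete type prints as `AxesSubplot`. Selecting the `Agg` backend is an assumption made here only so the snippet runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so no window is needed

import matplotlib.axes
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot re-exports the public Axes class, so both names refer to one class.
same_class = plt.Axes is matplotlib.axes.Axes

# Whatever the concrete runtime type is called, it derives from Axes.
is_axes = isinstance(ax, matplotlib.axes.Axes)

print(same_class, is_axes)
plt.close(fig)
```

Because the check relies only on the public `Axes` class, the annotation does not depend on the private `matplotlib.axes._subplots` module.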
answered Oct 17 '22 15:10 by felice


import matplotlib.axes as mpl_axes
import numpy as np

def multi_rocker(
    axy: mpl_axes.Axes,
    y_trues: np.ndarray,
    y_preds: np.ndarray,
):
    ...  # function body unchanged from the question
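
For illustration, here is a minimal, self-contained sketch of a function annotated this way. The helper `plot_curves` and its synthetic data are hypothetical; they stand in for the ROC computation in the question so that the example needs only NumPy and matplotlib. The `Agg` backend is again an assumption for running without a display.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so no window is needed

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def plot_curves(ax: Axes, xs: np.ndarray, ys: np.ndarray) -> Axes:
    """Draw one line per column of ys on the given Axes and return it."""
    for i in range(ys.shape[1]):
        ax.plot(xs, ys[:, i], label=f"class {i}")
    ax.legend()
    return ax


xs = np.linspace(0.0, 1.0, 50)
# Synthetic curves for demonstration only, not real ROC data.
ys = np.column_stack([xs, xs**2, np.sqrt(xs)])

fig, ax = plt.subplots()
result = plot_curves(ax, xs, ys)
n_lines = len(result.lines)  # one Line2D per column of ys
plt.close(fig)
```

Annotating the return value as `Axes` as well lets callers chain further calls on the same object while keeping the signature readable.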
answered Oct 17 '22 16:10 by Ward Van Driessche