 

How to print variable name as a title in matplotlib

My goal is to create a simple function that titles a graph with the name of the variable that's been plotted.

So far I have:

import matplotlib.pyplot as plt
import numpy as np

def comparigraphside(rawvariable, filtervariable, cut):
    # Keep only entries with a positive filter value, in both arrays,
    # so the <= cut / > cut masks below line up with variable
    variable = rawvariable[filtervariable > 0]
    filtervariable = filtervariable[filtervariable > 0]
    upperbound = np.mean(variable) + 3*np.std(variable)
    plt.figure(figsize=(20,5))
    plt.subplot(121)
    plt.hist(variable[filtervariable <= cut], bins=20, range=(0,upperbound), density=True)
    plt.title("%s customers with filter less than or equal to %s" % (len(variable[filtervariable <= cut]), cut))
    plt.subplot(122)
    plt.hist(variable[filtervariable > cut], bins=20, range=(0,upperbound), density=True)
    plt.title("%s customers with filter greater than %s" % (len(variable[filtervariable > cut]), cut))

Where the title currently reads:

plt.title("%s customers with filter less/greater...") 

I'd love it to say:

plt.title("%s customers with %s less/greater...")

At the moment the only solution I can think of involves making a dictionary of my variables, which I would like to avoid. Any and all assistance is much appreciated.

W. Gillett asked Nov 10 '22 04:11



1 Answer

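Python has no easy, reliable way to get a variable's name (see this answer): a name is only a reference to an object, so the same array can be bound to several names, or to none at all if you pass an expression such as data * 2. For arguments passed to a function there are hacky solutions using inspect (details here); here is a solution for your case based on this answer: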

import matplotlib.pyplot as plt
import numpy as np
import inspect
import re

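def comparigraphside(rawvariable, filtervariable, cut):
    # Fallback title text, used if the caller's source line can't be parsed
    rawvariablename = "variable"

    # Look one frame up the stack (the caller) and read the source line
    # that made this call, e.g. "comparigraphside(normdist, normdist, 0.7)"
    calling_frame_record = inspect.stack()[1]
    frame = inspect.getframeinfo(calling_frame_record[0])
    if frame.code_context:  # None when the source isn't available
        m = re.search(r"comparigraphside\((.+)\)", frame.code_context[0])
        if m:
            # Source text of the first argument as written by the caller
            rawvariablename = m.group(1).split(',')[0].strip()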

    # Keep only entries with a positive filter value, in both arrays,
    # so the <= cut / > cut masks below line up with variable
    variable = rawvariable[filtervariable > 0]
    filtervariable = filtervariable[filtervariable > 0]
    upperbound = np.mean(variable) + 3*np.std(variable)
    plt.figure(figsize=(20,5))
    plt.subplot(121)
    # density=True replaces normed=True, which newer matplotlib no longer accepts
    plt.hist(variable[filtervariable <= cut], bins=20, range=(0,upperbound), density=True)
    title = "%s customers with %s less than or equal to %s" % (len(variable[filtervariable <= cut]), rawvariablename, cut)
    plt.title(title)
    plt.subplot(122)
    plt.hist(variable[filtervariable > cut], bins=20, range=(0,upperbound), density=True)
    plt.title("%s customers with %s greater than %s" % (len(variable[filtervariable > cut]), rawvariablename, cut))


#A solution using inspect
normdist = np.random.randn(1000)
randdist = np.random.rand(1000)

comparigraphside(normdist, normdist, 0.7)
plt.show()

comparigraphside(randdist, normdist, 0.7)
plt.show()
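To see why this is hacky, the name-extraction step can be tried on plain strings. The sketch below applies the same regular expression to a few made-up source lines; first_argument_text is a hypothetical helper that covers only the parsing, not the inspect lookup. Whenever the first argument isn't a simple name, or the line doesn't hold the whole call, the title comes out wrong or falls back.

```python
import re

# Same pattern the answer applies to the caller's source line
CALL_PATTERN = re.compile(r"comparigraphside\((.+)\)")


def first_argument_text(source_line):
    """Return the first argument's source text, or None if the line doesn't match."""
    m = CALL_PATTERN.search(source_line)
    if m is None:
        return None
    # Naive split on commas, as in the answer (plus whitespace stripping)
    return m.group(1).split(",")[0].strip()


print(first_argument_text("comparigraphside(normdist, normdist, 0.7)"))  # normdist
print(first_argument_text("result = comparigraphside(randdist, normdist, 0.7)"))  # randdist
# An expression argument containing a comma is cut in half:
print(first_argument_text("comparigraphside(np.clip(x, 0, 1), normdist, 0.7)"))  # np.clip(x
# A line without the closing parenthesis (e.g. part of a multi-line call) doesn't match:
print(first_argument_text("comparigraphside(\n"))  # None
```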

However, another possible solution, which may be neater in your case, is to accept **kwargs in your function: the keyword you choose at the call site is then what gets printed as the title, e.g.,

import matplotlib.pyplot as plt
import numpy as np

normdist = np.random.randn(1000)
randdist = np.random.rand(1000)

#Another solution using kwargs
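def print_fns(**kwargs):
    # Each keyword argument becomes one plot, titled with the keyword itself
    for name, value in kwargs.items():
        plt.figure()  # new figure per keyword, so several plots don't overlap
        plt.hist(value)
        plt.title(name)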

print_fns(normal_distribution=normdist)
plt.show()

print_fns(random_distribution=randdist)
plt.show()
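Since comparigraphside compares a single raw variable, one way to adapt the **kwargs idea is to insist on exactly one keyword argument and use that keyword as the name. A minimal sketch, assuming that restriction is acceptable; single_name is a hypothetical helper, not part of numpy or matplotlib.

```python
def single_name(**kwargs):
    """Return (keyword, value) for a call made with exactly one keyword argument.

    Hypothetical helper: lets a plotting function be called as
    f(normal_distribution=data) and use "normal_distribution" as the title.
    """
    if len(kwargs) != 1:
        raise TypeError(
            f"expected exactly one keyword argument, got {len(kwargs)}: {sorted(kwargs)}"
        )
    # The items view of a one-entry dict unpacks to a single (key, value) pair
    ((name, value),) = kwargs.items()
    return name, value


name, data = single_name(normal_distribution=[0.1, -0.4, 1.3])
print(name)  # normal_distribution
print(data)  # [0.1, -0.4, 1.3]
```

In comparigraphside this would mean a signature such as def comparigraphside(filtervariable, cut, **kwargs) with rawvariablename, rawvariable = single_name(**kwargs) as its first line; the trade-off is that callers must always pass the data by keyword.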

Personally, for anything other than a quick plotting script, I'd define a dictionary of all the variables you'd like to plot, with a name for each, and pass that into the function. This is more explicit, and it avoids the fragility of reading names from the caller's source when you use this plotting as part of a larger codebase.
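A minimal sketch of that explicit approach, assuming you're happy to supply the names yourself: comparigraphside_named is a hypothetical variant of the function above that takes name as its first argument, draws on its own figure via plt.subplots, and returns that figure. It uses density=True, the replacement for normed=True in hist. The Agg backend line only lets the sketch run without a display; drop it and call plt.show() in an interactive session.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np


def comparigraphside_named(name, rawvariable, filtervariable, cut):
    """Two-panel histogram comparison with an explicitly supplied name.

    Hypothetical variant of comparigraphside: the title text comes from
    the name argument instead of being recovered via inspect.
    """
    # Keep only entries with a positive filter value, in both arrays
    keep = filtervariable > 0
    variable = rawvariable[keep]
    filtervariable = filtervariable[keep]
    upperbound = np.mean(variable) + 3 * np.std(variable)

    fig, (ax_low, ax_high) = plt.subplots(1, 2, figsize=(20, 5))
    low = variable[filtervariable <= cut]
    high = variable[filtervariable > cut]
    ax_low.hist(low, bins=20, range=(0, upperbound), density=True)
    ax_low.set_title(f"{len(low)} customers with {name} less than or equal to {cut}")
    ax_high.hist(high, bins=20, range=(0, upperbound), density=True)
    ax_high.set_title(f"{len(high)} customers with {name} greater than {cut}")
    return fig


# Explicit name -> data mapping; the names are whatever you want in the titles
rng = np.random.default_rng(0)
normdist = rng.standard_normal(1000)
randdist = rng.random(1000)
datasets = {"normdist": normdist, "randdist": randdist}

figures = {
    name: comparigraphside_named(name, data, normdist, 0.7)
    for name, data in datasets.items()
}
```

Because the names live in the dictionary keys, renaming a variable in your script never silently changes a plot title, and any string (including spaces) can be used.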

Ed Smith answered Nov 14 '22 23:11
