 

Plotting multiple plots generated inside a for loop on the same axes in Python

My code is below. The problem is that instead of one plot, I get 242 plots. I tried putting plt.show() outside the loop, but that didn't help.

import csv

import numpy as np
import matplotlib.pyplot as plt

names = list()

# In Python 3 the csv module needs the file opened in text mode, not 'rb'
with open('selected.csv', 'r', newline='') as infile:
    reader = csv.reader(infile, delimiter=' ')
    for row in reader:
        names.append(row[0])

names.pop(0)  # drop the header row

for j in range(len(names)):
    filename = '/home/mh/Masters_Project/Sigma/%s.dat' % (names[j])
    average, sigma = np.loadtxt(filename, usecols=(0, 1), unpack=True, delimiter=' ')
    name = '%s' % (names[j])
    plt.figure()
    plt.xlabel('Magnitude(average)', fontsize=16)
    plt.ylabel(r'$\sigma$', fontsize=16)  # raw string so '\s' is not treated as an escape
    plt.plot(average, sigma, marker='+', linestyle='', label=name)
plt.legend(loc='best')
plt.show()
asked Oct 14 '14 07:10 by Michael Hlabathe


2 Answers

Your issue is that you create a new figure on every iteration by calling plt.figure(). Remove that line from your for loop (or call it once, before the loop) and every call to plt.plot will draw on the same axes, as this short example shows.

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)

for a in [1.0, 2.0, 3.0]:
    plt.plot(x, a*x)

plt.show()

Example plot
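To see where the 242 plots came from, it helps to count the figures pyplot keeps open. The sketch below is a minimal illustration, assuming the non-interactive `Agg` backend so it runs without a display; the helper name `count_figures` is hypothetical. It runs the same loop with and without `plt.figure()` inside it, then shows the equivalent object-oriented form, where one `Axes` is created before the loop and every iteration draws on it.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np


def count_figures(new_figure_each_iteration):
    """Run the plotting loop and return how many figures pyplot has open."""
    plt.close("all")  # start from a clean state
    x = np.arange(10)
    for a in [1.0, 2.0, 3.0]:
        if new_figure_each_iteration:
            plt.figure()  # what the question's loop does: a fresh figure per iteration
        plt.plot(x, a * x, label="a = {}".format(a))
    return len(plt.get_fignums())


print(count_figures(True))   # 3: one figure per iteration
print(count_figures(False))  # 1: every line lands on the same figure

# Object-oriented equivalent: create one Figure/Axes pair, then draw on it in the loop.
plt.close("all")
fig, ax = plt.subplots()
x = np.arange(10)
for a in [1.0, 2.0, 3.0]:
    ax.plot(x, a * x, label="a = {}".format(a))
ax.legend(loc="best")
```

With the explicit `ax` there is no implicit "current figure" that a stray `plt.figure()` can silently replace, which is one reason this style is often preferred for plotting inside loops.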

answered Oct 22 '22 15:10 by Ffisegydd


Let me improve your code a bit:

import numpy as np
import matplotlib.pyplot as plt

# set the font size globally to get the ticklabels big too:
plt.rcParams["font.size"] = 16

# use numpy to read in the names: first column only, skipping the header line
# (np.str and skiprows no longer exist in current NumPy; use str and skip_header)
names = np.genfromtxt("selected.csv", delimiter=" ", dtype=str,
                      skip_header=1, usecols=0, encoding="utf-8")

# not necessary, but you might want to pass options to the figure (e.g. figsize)
plt.figure()

# don't use a `for i in range(len(...))` loop to iterate over array elements
for name in names:
    # use the str.format method
    filename = '/home/mh/Masters_Project/Sigma/{}.dat'.format(name)

    # use genfromtxt because of its better error handling (missing values, etc.)
    average, sigma = np.genfromtxt(filename, usecols=(0, 1), unpack=True, delimiter=' ')

    plt.plot(average, sigma, marker='+', linestyle='', label=name)

# the axis labels only need to be set once, outside the loop
plt.xlabel('Magnitude(average)')
plt.ylabel(r'$\sigma$')
plt.legend(loc='best')
plt.show()
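Because this code depends on the asker's own files, here is a self-contained sketch of the same pipeline that can run anywhere. It writes a small, made-up `selected.csv` and two `.dat` files into a temporary directory (the names `star_a`/`star_b` and all values are invented for illustration), reads them back with `np.genfromtxt`, and draws every file on one `Axes`. It assumes the layout implied by the question: a header line followed by one name per row in `selected.csv`, and two space-separated columns (average magnitude, sigma) in each `.dat` file. It also assumes the `Agg` backend and saves to PNG instead of calling `plt.show()`.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical sample data standing in for the asker's real files.
fake_data = {
    "star_a": [(12.1, 0.02), (13.4, 0.05), (14.0, 0.09)],
    "star_b": [(11.8, 0.01), (12.9, 0.03), (15.2, 0.20)],
}

data_dir = tempfile.mkdtemp()
names_file = os.path.join(data_dir, "selected.csv")
with open(names_file, "w") as f:
    f.write("name other\n")  # header row, skipped when reading
    for name in fake_data:
        f.write("{} 0\n".format(name))

for name, rows in fake_data.items():
    with open(os.path.join(data_dir, "{}.dat".format(name)), "w") as f:
        for avg, sig in rows:
            f.write("{} {}\n".format(avg, sig))

# First column of every non-header row, as strings; atleast_1d guards a one-name file.
names = np.atleast_1d(
    np.genfromtxt(names_file, delimiter=" ", dtype=str,
                  skip_header=1, usecols=0, encoding="utf-8")
)

fig, ax = plt.subplots()  # one figure and one Axes, created once before the loop
for name in names:
    filename = os.path.join(data_dir, "{}.dat".format(name))
    average, sigma = np.genfromtxt(filename, usecols=(0, 1), unpack=True, delimiter=" ")
    ax.plot(average, sigma, marker="+", linestyle="", label=name)

ax.set_xlabel("Magnitude(average)")
ax.set_ylabel(r"$\sigma$")
ax.legend(loc="best")
out_path = os.path.join(data_dir, "sigma.png")
fig.savefig(out_path)  # with an interactive backend, plt.show() works instead
```

Whatever the number of names, the result is a single figure with one marker series per file, each labelled in the legend.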
answered Oct 22 '22 15:10 by MaxNoe