
Group multiple plots in one figure in Python

My function returns 28 plots (figures), but I need to group them in one figure. This is my code for generating the 28 plots:

for cat in df.ASS_ASSIGNMENT.unique() :
    a = df.loc[df['ASS_ASSIGNMENT'] == cat]
    dates = a['DATE']
    prediction = a['CSPL_RECEIVED_CALLS']
    plt.plot(dates,prediction)  
    plt.ylabel("nmb_app")
    plt.legend([cat.decode('utf-8')],loc='best')
    plt.xlabel(cat.decode('utf-8'))
Amal Kostali Targhi asked Jan 05 '17 19:01


1 Answer

Use plt.subplots. For example,

import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=7, nrows=4)

for i, ax in enumerate(axes.flatten()):
    x = np.random.randint(-5, 5, 20)
    y = np.random.randint(-5, 5, 20)
    ax.scatter(x, y)
    ax.set_title('Axis {}'.format(i))

plt.tight_layout()

Going a little deeper, as Mauve points out, it depends on whether you want 28 curves in a single plot in a single figure, or 28 individual plots, each with its own axis, all in one figure.

Assuming you have a dataframe, df, with 28 columns, you can put all 28 curves on a single plot in a single figure using plt.subplots like so (colors is defined at the end of this answer):

fig1, ax1 = plt.subplots()
df.plot(color=colors, ax=ax1)
plt.legend(ncol=4, loc='best')
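
The dataframe in the question is in long format (one row per date and category) rather than one column per curve, so the same idea might be sketched with groupby. This is only a sketch under assumptions: long_df is a made-up stand-in that reuses the question's column names, the category labels are invented, and the category values are taken to be plain Python 3 strings, so no .decode call is needed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's data: 28 categories, 10 dates each.
rng = np.random.default_rng(0)
dates = pd.date_range("2017-01-01", periods=10, freq="D")
categories = ["cat_{:02d}".format(i) for i in range(28)]
long_df = pd.DataFrame({
    "DATE": np.tile(dates, len(categories)),
    "ASS_ASSIGNMENT": np.repeat(categories, len(dates)),
    "CSPL_RECEIVED_CALLS": rng.integers(0, 50, len(dates) * len(categories)),
})

# One figure, one axes: every category becomes one labelled curve on ax.
fig, ax = plt.subplots(figsize=(10, 6))
for cat, group in long_df.groupby("ASS_ASSIGNMENT"):
    ax.plot(group["DATE"], group["CSPL_RECEIVED_CALLS"], label=cat)

ax.set_ylabel("nmb_app")
ax.legend(ncol=4, loc="best", fontsize="small")
fig.autofmt_xdate()  # slant the date labels so they do not overlap
```

Calling ax.plot on the same axes object, instead of plt.plot inside a function that opens a new figure each time, is what keeps all the curves together.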


If instead you want 28 individual axes all in one figure, you can use plt.subplots this way:

fig2, axes = plt.subplots(nrows=4, ncols=7)
for i, ax in enumerate(axes.flatten()):
    df[df.columns[i]].plot(color=colors[i], ax=ax)
    ax.set_title(df.columns[i])
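
For the long-format data in the question, a grid with one subplot per category could look roughly like the following. It is a sketch, not a definitive solution: plot_category_grid is a hypothetical helper, long_df is made-up data using the question's column names, and hiding the unused axes is an extra step for the case where the number of categories does not fill the grid.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's data: 28 categories, 10 dates each.
rng = np.random.default_rng(0)
dates = pd.date_range("2017-01-01", periods=10, freq="D")
categories = ["cat_{:02d}".format(i) for i in range(28)]
long_df = pd.DataFrame({
    "DATE": np.tile(dates, len(categories)),
    "ASS_ASSIGNMENT": np.repeat(categories, len(dates)),
    "CSPL_RECEIVED_CALLS": rng.integers(0, 50, len(dates) * len(categories)),
})


def plot_category_grid(frame, ncols=7):
    """Draw one subplot per ASS_ASSIGNMENT value (hypothetical helper)."""
    groups = list(frame.groupby("ASS_ASSIGNMENT"))
    nrows = math.ceil(len(groups) / ncols)
    # squeeze=False keeps axes two-dimensional even for a single row or column.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(3 * ncols, 2.5 * nrows), squeeze=False)
    flat = axes.flatten()
    for ax, (cat, group) in zip(flat, groups):
        ax.plot(group["DATE"], group["CSPL_RECEIVED_CALLS"])
        ax.set_title(cat)
        ax.set_ylabel("nmb_app")
        ax.tick_params(axis="x", labelrotation=45)
    # Hide any axes left over when the grid has more cells than categories.
    for ax in flat[len(groups):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig, axes


fig, axes = plot_category_grid(long_df, ncols=7)  # 4 rows x 7 columns
```

With 28 categories and ncols=7 the grid is exactly full; a call such as plot_category_grid(long_df, ncols=5) would give 6 rows with the last two axes hidden.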



Here df looks like

In [114]: df.shape
Out[114]: (15, 28)

In [115]: df.head()
Out[115]: 
         IYU        ZMK        DRO       UIC       DOF       ASG       DLU  \
0   0.970467   1.026171  -0.141261  1.719777  2.344803  2.956578  2.433358   
1   7.982833   7.667973   7.907016  7.897172  6.659990  5.623201  6.818639   
2   4.608682   4.494827   6.078604  5.634331  4.553364  5.418964  6.079736   
3   1.299400   3.235654   3.317892  2.689927  2.575684  4.844506  4.368858   
4  10.690242  10.375313  10.062212  9.150162  9.620630  9.164129  8.661847   

         BO1       JFN       S9Q    ...          X4K       ZQG       2TS  \
0   2.798409  2.425745  3.563515    ...     7.623710  7.678988  7.044471   
1   8.391905  7.242406  8.960973    ...     5.389336  5.083990  5.857414   
2   7.631030  7.822071  5.657916    ...     2.884925  2.570883  2.550461   
3   6.061272  4.224779  5.709211    ...     4.961713  5.803743  6.008319   
4  10.240355  9.792029  8.438934    ...     6.451223  5.072552  6.894701   

        RS0       P6T       FOU       LN9       CFG       C9D       ZG2  
0  9.380106  9.654287  8.065816  7.029103  7.701655  6.811254  7.315282  
1  3.931037  3.206575  3.728755  2.972959  4.436053  4.906322  4.796217  
2  3.784638  2.445668  1.423225  1.506143  0.786983 -0.666565  1.120315  
3  5.749563  7.084335  7.992780  6.998563  7.253861  8.845475  9.592453  
4  4.581062  5.807435  5.544668  5.249163  6.555792  8.299669  8.036408  

and was created by

import pandas as pd
import numpy as np
import string
import random

m = 28
n = 15

def random_data(rows, cols):
    # One long random walk of rows*cols steps, reshaped to (rows, cols)
    return np.cumsum(np.random.randn(rows * cols)).reshape(rows, cols)

def id_generator(number, size=6, chars=string.ascii_uppercase + string.digits):
    # Return `number` random labels, each `size` characters long
    return [''.join(random.choice(chars) for _ in range(size))
            for _ in range(number)]

# n rows (15) by m columns (28), with one random label per column
df = pd.DataFrame(random_data(n, m), columns=id_generator(number=m, size=3))

colors was defined as

import seaborn as sns
colors = sns.cubehelix_palette(28, rot=-0.4)
lanery answered Oct 04 '22 21:10