I construct a pandas dataframe like so:
import pandas as pd
import numpy as np
# Demo data: 100 rows of standard-normal noise in columns A, B and C,
# plus a categorical column X drawn at random from three class labels.
df = pd.DataFrame(np.random.randn(100, 3), columns=list('ABC'))
df['X'] = np.random.choice(
    ['Alpha', 'Beta', 'Theta'],
    size=100,
)
This gives me the following output from df.head():
A B C X
0 2.279163 -1.790076 1.187603 Beta
1 -0.590897 0.837605 -0.606424 Alpha
2 0.448334 -1.142946 0.002507 Beta
3 0.540165 -0.204184 1.389645 Beta
4 0.105643 -1.298379 -1.404680 Beta
Now if I plot Andrews Curves using column 'X' -- which has one of three values -- as the class, I expect to see 100 curves with three colors, based on class X. Instead, each curve has its own color.
# Draw one Andrews curve per row of df, using column 'X' as the class label.
pd.tools.plotting.andrews_curves(df, 'X')
(The legend looks as expected, which is interesting.)
Is there a bug here or am I misunderstanding things?
The following fixes the pandas code (https://github.com/pydata/pandas/pull/5378):
from pandas.compat import range, lrange, lmap, map, zip
from pandas.tools.plotting import _get_standard_colors
import pandas.core.common as com
def andrews_curves(data, class_column, ax=None, samples=200, colormap=None,
                   **kwds):
    """
    Plot one Andrews curve per row of ``data``, coloured by class.

    Every row (excluding ``class_column``) is used as the coefficient list of
    the finite Fourier series

        f(t) = x1 / sqrt(2) + x2 * sin(t) + x3 * cos(t) + x4 * sin(2t) + ...

    which is evaluated at ``samples`` evenly spaced points in [-pi, pi).
    All rows sharing a class value are drawn in the same colour, and each
    class appears exactly once in the legend.

    Parameters
    ----------
    data : DataFrame
        Data to be plotted, preferably normalized to (0.0, 1.0).
    class_column : Name of the column containing class names.
    ax : matplotlib axes object, default None
        Axes to draw on. When None, the current axes are used and their
        x-limits are set to (-pi, pi).
    samples : Number of points to plot in each curve.
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that
        name from matplotlib.
    kwds : Optional plotting arguments to be passed to matplotlib.
        A ``color`` entry is forwarded to the colour-selection helper
        instead of to ``ax.plot``.

    Returns
    -------
    ax : Matplotlib axis object
    """
    from math import sqrt, pi, sin, cos
    import matplotlib.pyplot as plt

    def function(amplitudes):
        """Return f(t), the Andrews curve defined by ``amplitudes``."""
        def f(x):
            result = amplitudes[0] / sqrt(2.0)
            harmonic = 1.0
            for x_even, x_odd in zip(amplitudes[1::2], amplitudes[2::2]):
                result += (x_even * sin(harmonic * x) +
                           x_odd * cos(harmonic * x))
                harmonic += 1.0
            # After the constant term the coefficients are consumed in
            # (sin, cos) pairs, so one coefficient is left unpaired exactly
            # when the total count is even; it belongs to the next sine term.
            if len(amplitudes) % 2 == 0:
                result += amplitudes[-1] * sin(harmonic * x)
            return result
        return f

    n = len(data)
    class_col = data[class_column]
    uniq_class = class_col.drop_duplicates()
    columns = [data[col] for col in data.columns if col != class_column]
    x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
    used_legends = set()

    # Remove 'color' from kwds so it is not passed to ax.plot a second time
    # alongside the per-class colour (which would raise a TypeError).
    color = kwds.pop('color', None)
    colors = _get_standard_colors(num_colors=len(uniq_class),
                                  colormap=colormap,
                                  color_type='random', color=color)
    # One colour per distinct class value.
    col_dict = dict(zip(uniq_class, colors))

    if ax is None:
        ax = plt.gca()
        ax.set_xlim(-pi, pi)

    for i in range(n):
        # Positional access keeps this correct for any index, not only the
        # default 0..n-1 integer index.
        row = [column.iloc[i] for column in columns]
        klass = class_col.iloc[i]
        f = function(row)
        y = [f(t) for t in x]
        label = com.pprint_thing(klass)
        if label not in used_legends:
            # First curve of this class: attach the label for the legend.
            used_legends.add(label)
            ax.plot(x, y, color=col_dict[klass], label=label, **kwds)
        else:
            ax.plot(x, y, color=col_dict[klass], **kwds)

    ax.legend(loc='upper right')
    ax.grid()
    return ax
It looks like a bug; you can work around it with the following code:
import pandas as pd
import numpy as np
# Same demo frame as in the question: three numeric columns plus class column X.
df = pd.DataFrame(np.random.randn(100, 3), columns=list('ABC'))
df['X'] = np.random.choice(['Alpha', 'Beta', 'Theta'], size=100)

ax = pd.tools.plotting.andrews_curves(df, 'X')

# Map every line label to that line's colour; the labelled (legend) lines
# supply the colour to use for each class.
colors = dict((curve.get_label(), curve.get_color()) for curve in ax.lines)

# Pair the plotted lines with the class values in order and repaint each
# line with the colour recorded for its class.
for line, klass in zip(ax.lines, df["X"]):
    line.set_color(colors[klass])
output:
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With