I am currently working on a classification problem and want to create visualizations of feature importance. I use the Python XGBoost package, which already provides feature importance plots. However, I found shap (https://github.com/slundberg/shap), a Python library that creates very nice feature-importance plots for tree-based classifiers. Everything works fine and I can save the created plots as PNG; however, if I try to save them as PDF or SVG, I get an exception. Here is what I am doing:
First, I train the XGBoost model and get back the trained model, denoted by `bst`:
# Prepare the train/test frames from the first element of dataset_fc_baseline_1.
# NOTE(review): remove_labels_for_binary_df is a project helper not shown
# here; from its name it presumably strips the label column(s) — confirm.
train = remove_labels_for_binary_df(dataset_fc_baseline_1[0].train)
test = remove_labels_for_binary_df(dataset_fc_baseline_1[0].test)
# Encode the features numerically and train XGBoost. `bst` is the trained
# model that shap.TreeExplainer consumes below; `results` is presumably the
# evaluation output of xgboost_with_bst (assumed from the name — verify).
results, bst = xgboost_with_bst(*transform_feat_to_num(train, test))
Then I compute the SHAP values, use them to create a summary plot, and save the created visualization. Everything works fine if I save the plot with `plt.savefig('shap.png')`:
import shap
import matplotlib.pyplot as plt
# Presumably only needed for SHAP's JavaScript (notebook) visualizations;
# kept so notebook behaviour is unchanged.
shap.initjs()
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(train)
# summary_plot draws onto the current pyplot figure instead of returning it,
# so don't bind its return value; show=False keeps the figure open so it can
# still be saved afterwards.
shap.summary_plot(shap_values, train, show=False)
fig = plt.gcf()
# NOTE: saving this figure to a vector format ('shap.pdf' / 'shap.svg') can
# raise "ValueError: negative dimensions are not allowed" on older matplotlib
# releases. SHAP draws the points with rasterized=True for large datasets
# (more than 500 points), which triggers a NumPy/matplotlib incompatibility.
# Upgrading matplotlib resolves it.
fig.savefig('shap.png')
However, I need PDF or SVG plots instead of PNG, so I tried to save it with `plt.savefig('shap.pdf')`. That normally works fine, but for the SHAP plot it raises the following exception:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-39-49d17973f438> in <module>()
1 fig = shap.summary_plot(shap_values, train, show=False)
----> 2 plt.savefig('shap.pdf')
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\pyplot.py in
savefig(*args, **kwargs)
708 def savefig(*args, **kwargs):
709 fig = gcf()
--> 710 res = fig.savefig(*args, **kwargs)
711 fig.canvas.draw_idle() # need this if 'transparent=True' to reset
colors
712 return res
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in
savefig(self, fname, **kwargs)
2033 self.set_frameon(frameon)
2034
-> 2035 self.canvas.print_figure(fname, **kwargs)
2036
2037 if frameon:
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backend_bases.py in
print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format,
**kwargs)
2261 orientation=orientation,
2262 bbox_inches_restore=_bbox_inches_restore,
-> 2263 **kwargs)
2264 finally:
2265 if bbox_inches and restore_bbox:
C:\Users\Studio\Anaconda3\lib\site-
packages\matplotlib\backends\backend_pdf.py in print_pdf(self, filename,
**kwargs)
2584 RendererPdf(file, image_dpi, height, width),
2585 bbox_inches_restore=_bbox_inches_restore)
-> 2586 self.figure.draw(renderer)
2587 renderer.finalize()
2588 if not isinstance(filename, PdfPages):
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in
draw(self, renderer)
1473
1474 mimage._draw_list_compositing_images(
-> 1475 renderer, self, artists, self.suppressComposite)
1476
1477 renderer.close_group('figure')
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139 if not_composite or not has_images:
140 for a in artists:
--> 141 a.draw(renderer)
142 else:
143 # Composite any adjacent images together
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in
draw(self, renderer, inframe)
2605 renderer.stop_rasterizing()
2606
-> 2607 mimage._draw_list_compositing_images(renderer, self,
artists)
2608
2609 renderer.close_group('axes')
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139 if not_composite or not has_images:
140 for a in artists:
--> 141 a.draw(renderer)
142 else:
143 # Composite any adjacent images together
C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer, *args, **kwargs)
58 renderer.stop_filter(artist.get_agg_filter())
59 if artist.get_rasterized():
---> 60 renderer.stop_rasterizing()
61
62 draw_wrapper._supports_rasterization = True
C:\Users\Studio\Anaconda3\lib\site-
packages\matplotlib\backends\backend_mixed.py in stop_rasterizing(self)
128
129 height = self._height * self.dpi
--> 130 buffer, bounds =
self._raster_renderer.tostring_rgba_minimized()
131 l, b, w, h = bounds
132 if w > 0 and h > 0:
C:\Users\Studio\Anaconda3\lib\site-
packages\matplotlib\backends\backend_agg.py in tostring_rgba_minimized(self)
138 [extents[0] + extents[2], self.height - extents[1]]]
139 region = self.copy_from_bbox(bbox)
--> 140 return np.array(region), extents
141
142 def draw_path(self, gc, path, transform, rgbFace=None):
ValueError: negative dimensions are not allowed
Do you have any idea how to fix this? Thanks in advance!
By default, `summary_plot` calls `plt.show()` to ensure the plot is displayed. If you pass `show=False` to `summary_plot`, the figure stays open and you can save it.
6. SHAP Summary Plot — The summary plot combines feature importance with feature effects. Each point on the summary plot is a Shapley value for one feature and one instance. The position on the y-axis is determined by the feature, and the position on the x-axis by the Shapley value.
The SHAP force plot shows you exactly which features had the most influence on the model's prediction for a single observation.
When saving a force plot, you have to pass `matplotlib=True, show=False`:
def heart_disease_risk_factors(model, patient):
    """Draw a SHAP force plot explaining *model*'s prediction for *patient*.

    Args:
        model: A tree-based model accepted by ``shap.TreeExplainer``.
        patient: The observation to explain. In the calling code, this is a
            single row of ``X_test`` cast to float.

    Returns:
        The return value of ``shap.force_plot`` called with
        ``matplotlib=True, show=False``. The plot is rendered with
        matplotlib and not shown, so the caller can save the current
        figure with ``plt.savefig`` afterwards.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(patient)
    # Presumably only needed for SHAP's JavaScript plots; kept for parity.
    shap.initjs()
    # Index 1 picks the second class's expected value and SHAP values. This
    # assumes the explainer returns one entry per class (e.g. a list for a
    # binary classifier). NOTE(review): confirm for the model type in use.
    # matplotlib=True renders with matplotlib instead of JS; show=False keeps
    # the figure open so it can be saved.
    return shap.force_plot(
        explainer.expected_value[1],
        shap_values[1],
        patient,
        matplotlib=True,
        show=False,
    )
# Clear the current pyplot figure before drawing the force plot.
plt.clf()
# Explain a single observation: the row at position 2 of X_test, cast to float.
data_for_prediction = X_test.iloc[2,:].astype(float)
heart_disease_risk_factors(model, data_for_prediction)
# Save the current pyplot figure at 150 dpi; bbox_inches='tight' trims the
# surrounding whitespace so the force plot's labels are not clipped or padded.
plt.savefig("gg.png",dpi=150, bbox_inches='tight')
This is an incompatibility between NumPy and matplotlib that occurs when plotting with `rasterized=True` (which SHAP does when there are more than 500 data points). It has been resolved in newer versions of matplotlib, so upgrading matplotlib fixes the error.
If you love us, you can donate to us via PayPal or buy us a coffee so we can maintain and grow! Thank you!
Donate Us With