Save SHAP summary plot as PDF/SVG

I am currently working on a classification problem and want to create visualizations of feature importance. I use the Python XGBoost package, which already provides feature importance plots. However, I found shap (https://github.com/slundberg/shap), a Python library that creates very nice feature importance plots for tree-based classifiers. Everything works fine and I can save the created plots as PNG; however, if I try to save them as PDF or SVG, I get an exception. Here is what I am doing:

First I train the XGBoost model and get the model back denoted by bst.

train = remove_labels_for_binary_df(dataset_fc_baseline_1[0].train)
test = remove_labels_for_binary_df(dataset_fc_baseline_1[0].test)
results, bst = xgboost_with_bst(*transform_feat_to_num(train, test))

Then I create the SHAP values, use them to create a summary plot, and save the created visualization. Everything works fine if I save the plot with plt.savefig('shap.png').

import shap
import matplotlib.pyplot as plt


# Build a SHAP tree explainer for the trained XGBoost model `bst`.
explainer = shap.TreeExplainer(bst)
# Compute the SHAP values for the rows of `train`.
shap_values = explainer.shap_values(train)
# show=False so the plot is not displayed right away and can be saved with
# plt.savefig afterwards.
fig = shap.summary_plot(shap_values, train, show=False)

However, I need PDF or SVG plots instead of png and therefore tried to save it with plt.savefig('shap.pdf') which normally works fine, but produces the following exception for the shap plot.

ValueError                                Traceback (most recent call last)
<ipython-input-39-49d17973f438> in <module>()
  1 fig = shap.summary_plot(shap_values, train, show=False)
----> 2 plt.savefig('shap.pdf')

 C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\pyplot.py in 
savefig(*args, **kwargs)
708 def savefig(*args, **kwargs):
709     fig = gcf()
--> 710     res = fig.savefig(*args, **kwargs)
711     fig.canvas.draw_idle()   # need this if 'transparent=True' to reset 
712     return res

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in 
savefig(self, fname, **kwargs)
2033             self.set_frameon(frameon)
-> 2035         self.canvas.print_figure(fname, **kwargs)
2037         if frameon:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backend_bases.py in 
print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, 
2261                 orientation=orientation,
2262                 bbox_inches_restore=_bbox_inches_restore,
-> 2263                 **kwargs)
2264         finally:
2265             if bbox_inches and restore_bbox:

packages\matplotlib\backends\backend_pdf.py in print_pdf(self, filename, 
2584                 RendererPdf(file, image_dpi, height, width),
2585                 bbox_inches_restore=_bbox_inches_restore)
-> 2586             self.figure.draw(renderer)
2587             renderer.finalize()
2588             if not isinstance(filename, PdfPages):

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in 
draw(self, renderer)
1474             mimage._draw_list_compositing_images(
-> 1475                 renderer, self, artists, self.suppressComposite)
1477             renderer.close_group('figure')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in 
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in 
draw(self, renderer, inframe)
2605             renderer.stop_rasterizing()
-> 2607         mimage._draw_list_compositing_images(renderer, self, artists)
2609         renderer.close_group('axes')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in 
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 58                 renderer.stop_filter(artist.get_agg_filter())
 59             if artist.get_rasterized():
---> 60                 renderer.stop_rasterizing()
 62     draw_wrapper._supports_rasterization = True

packages\matplotlib\backends\backend_mixed.py in stop_rasterizing(self)
129             height = self._height * self.dpi
--> 130             buffer, bounds = self._raster_renderer.tostring_rgba_minimized()
131             l, b, w, h = bounds
132             if w > 0 and h > 0:

packages\matplotlib\backends\backend_agg.py in tostring_rgba_minimized(self)
138                 [extents[0] + extents[2], self.height - extents[1]]]
139         region = self.copy_from_bbox(bbox)
--> 140         return np.array(region), extents
142     def draw_path(self, gc, path, transform, rgbFace=None):

ValueError: negative dimensions are not allowed

Do you have any idea how to fix this? Thanks in advance!

2 Answers

To be able to save the plot, you have to pass matplotlib=True, show=False to the shap.force_plot call:

def heart_disease_risk_factors(model, patient):
    """Draw a SHAP force plot for a single patient with matplotlib.

    Args:
        model: Trained tree-based model passed to ``shap.TreeExplainer``.
        patient: Feature values of the one sample to explain.

    Returns:
        Whatever ``shap.force_plot`` returns for the matplotlib renderer.
        The plot is left undisplayed so the caller can save it with
        ``plt.savefig``.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(patient)

    # Index 1 selects the second class's base value and SHAP values.
    # matplotlib=True draws with matplotlib instead of the JavaScript widget;
    # show=False keeps the figure open so it can be saved afterwards.
    return shap.force_plot(
        explainer.expected_value[1],
        shap_values[1],
        patient,
        matplotlib=True,
        show=False,
    )

# Take the row at position 2 of X_test and cast its values to float.
data_for_prediction = X_test.iloc[2,:].astype(float)
# Draw the force plot for that row.
heart_disease_risk_factors(model, data_for_prediction)
# Save the current matplotlib figure; bbox_inches='tight' trims the
# surrounding whitespace.
plt.savefig("gg.png",dpi=150, bbox_inches='tight')
This is an issue between NumPy and matplotlib caused when plotting with rasterized=True (which shap does if there are more than 500 datapoints) and has been resolved in the latest version of matplotlib.

