Putting HTML output from SHAP into a Dash layout from a callback

Tags:

plotly-dash

I am trying to build a dashboard that displays the output of shap.force_plot. The output of shap.force_plot is HTML decorated with JSON. The example is here

I made a very simple dashboard following the tutorial. It should plot the desired figure after the Submit button is clicked.

Here is the code:

# -*- coding: utf-8 -*-
import dash
from dash import dcc, html
from dash import Input, Output, State
import pandas as pd
from sqlalchemy import create_engine
import shap
# from sources import *  # asker's local module; not needed to reproduce the problem
import xgboost

external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

app.layout = html.Div([
    dcc.Input(id='input-cvr-state', type='text', value='12'),
    html.Button(id='submit-button', n_clicks=0, children='Submit'),
    html.Div(id='output-state'),
    html.Div(id='output-shap')
])


@app.callback(Output('output-shap', 'children'),
              [Input('submit-button', 'n_clicks')],
              [State('input-cvr-state', 'value')])
def update_shap_figure(n_clicks, input_cvr):
    shap.initjs()  # loads SHAP's JS into a Jupyter notebook; it does not affect the Dash page

    # train XGBoost model
    X, y = shap.datasets.boston()

    model = xgboost.train({"learning_rate": 0.01}, xgboost.DMatrix(X, label=y), 100)

    # explain the model's predictions using SHAP values (same syntax works for LightGBM, CatBoost, and scikit-learn models)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # visualize the first prediction's explanation

    return shap.force_plot(explainer.expected_value, shap_values[0, :], X.iloc[0, :])  # matplotlib=True

if __name__ == '__main__':
    app.run_server(debug=True)
Areza asked Jan 21 '19 14:01




2 Answers

I managed it with the following steps:

import base64
import io

import matplotlib.pyplot as plt
import shap
from dash import html
# private module: the import path may change between shap versions
from shap.plots._force_matplotlib import draw_additive_plot

# ... class dashApp
# ... callback as method
# matplotlib=False => returns an AdditiveForceVisualizer;
# with matplotlib=True, shap draws the plot itself instead of returning the visualizer
# x is the index of the wanted input row (a list such as [0])
# class_1 is the class to draw

force_plot = shap.force_plot(
    self.explainer.expected_value[class_1],
    self.shap_values[class_1][x[0], :],
    self.data.iloc[x, :].drop(columns=["TARGET"], errors="ignore"),
    matplotlib=False
)
# set show=False so the figure is returned instead of displayed
force_plot_mpl = draw_additive_plot(force_plot.data, (30, 7), show=False)
return figure_to_html_img(force_plot_mpl)


def figure_to_html_img(figure):
    """Convert a matplotlib figure to an html.Img holding a base64-encoded PNG."""
    try:
        tmpfile = io.BytesIO()
        figure.savefig(tmpfile, format='png')
        plt.close(figure)  # release the figure so a long-running server does not accumulate them
        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
        # no space after the comma in the data URI
        return html.Img(src=f"data:image/png;base64,{encoded}")
    except AttributeError:
        # figure was not a matplotlib figure (e.g. None)
        return ""

 

The result is the force plot rendered as a static PNG image.
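The encoding step inside figure_to_html_img can be checked on its own, without matplotlib or Dash. Below is a minimal standard-library sketch of that step; png_data_uri and buffer_to_data_uri are hypothetical helper names, and the 8-byte PNG file signature only stands in for a real rendered image.

```python
import base64
import io


def png_data_uri(png_bytes: bytes) -> str:
    """Encode raw PNG bytes as a data URI (hypothetical helper)."""
    encoded = base64.b64encode(png_bytes).decode("ascii")
    # The base64 payload follows the comma directly, with no space.
    return f"data:image/png;base64,{encoded}"


def buffer_to_data_uri(buffer: io.BytesIO) -> str:
    """Encode the whole content of an in-memory buffer, e.g. one
    filled by figure.savefig(buffer, format="png")."""
    return png_data_uri(buffer.getvalue())


if __name__ == "__main__":
    # The 8-byte PNG file signature stands in for a real image.
    png_signature = b"\x89PNG\r\n\x1a\n"
    print(buffer_to_data_uri(io.BytesIO(png_signature)))
    # prints data:image/png;base64,iVBORw0KGgo=
```

In the answer's code, the buffer is the io.BytesIO that figure.savefig(tmpfile, format='png') writes into, and the returned string is what goes into html.Img(src=...).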

Ali SAID OMAR answered Oct 31 '22 22:10


An alternative is to use html.Iframe, which produces a better-looking and fully interactive plot.

Here's an example whose return value can be used directly as a callback Output:

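import shap
from dash import html

def _force_plot_html(*args):
    # build the interactive (JavaScript) force plot instead of a matplotlib one
    force_plot = shap.force_plot(*args, matplotlib=False)
    # inline the SHAP JS bundle so the iframe's own document can render the plot
    shap_html = f"<head>{shap.getjs()}</head><body>{force_plot.html()}</body>"
    return html.Iframe(srcDoc=shap_html,
                       style={"width": "100%", "height": "200px", "border": 0})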
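This works because srcDoc gives the iframe a separate HTML document, so the SHAP JavaScript has to be inlined into that document rather than loaded by the Dash page. The string assembly can be sketched with the standard library alone; build_force_plot_srcdoc is a hypothetical helper, and the doctype and charset meta tag it adds are assumptions beyond the answer's code.

```python
def build_force_plot_srcdoc(script_html: str, body_html: str) -> str:
    """Assemble a standalone HTML document for an iframe's srcDoc.

    script_html: markup that loads the JS (e.g. the string from shap.getjs())
    body_html:   markup that renders the plot (e.g. force_plot.html())
    Hypothetical helper; the doctype and charset meta tag are additions.
    """
    return (
        "<!DOCTYPE html>"
        "<html>"
        # the script goes in <head> so it is loaded before the body uses it
        f"<head><meta charset='utf-8'>{script_html}</head>"
        f"<body>{body_html}</body>"
        "</html>"
    )


if __name__ == "__main__":
    doc = build_force_plot_srcdoc("<script>/* bundle */</script>",
                                  "<div id='plot'></div>")
    print(doc)
```

Inside _force_plot_html, the call would be html.Iframe(srcDoc=build_force_plot_srcdoc(shap.getjs(), force_plot.html()), ...), keeping the same style dictionary.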
Marigold answered Nov 01 '22 00:11