
Explaining CNN (Keras) outputs with LIME

I am trying to explain the outputs of my convolutional neural network built in Keras with LIME.

My neural network is a multi-class text classifier where every class is independent (i.e. multi-label): a text can contain classes 1 and 2, only class 1, and so on. A fifth "class" (None) covers the cases where no class is present in the text.
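For reference, the kind of output layer this implies looks roughly like the sketch below (not my exact architecture; layer sizes are placeholders): five independent sigmoid units trained with binary cross-entropy instead of a softmax, so the per-class scores do not sum to one.

# Minimal sketch of the assumed output head (illustrative sizes):
# five independent sigmoid units (four classes plus "None")
from tensorflow.keras import layers, models

sketch = models.Sequential([
    layers.Embedding(input_dim=20000, output_dim=128),
    layers.Conv1D(128, 5, activation='relu'),
    layers.GlobalMaxPooling1D(),
    layers.Dense(5, activation='sigmoid'),  # independent classes
])
sketch.compile(loss='binary_crossentropy', optimizer='adam')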

However, while I managed to explain a binary classification case with Keras and LIME, I just cannot get the multi-class case with independent classes to work. A first pointer was found here:

However, my code does not work; I get internal errors from LIME such as: "ValueError: Found input variables with inconsistent numbers of samples: [5000, 100000]"

import numpy as np
from lime.lime_text import LimeTextExplainer, TextDomainMapper

explainer = LimeTextExplainer(class_names=encoder.classes_)


chosen_text = 2

def flatten_predict(i):
    # model, preprocess and make_predictable are defined elsewhere in my notebook
    # catch single string inputs and convert them to a list
    if not isinstance(i, list):
        i = [i]
        print("## Caught and transformed single string.")
    # list for predictions
    predStorage = []
    # loop through the input list and predict
    for textInput in i:
        textInput = preprocess(textInput)
        textInput = make_predictable(textInput)
        pred = model.predict(textInput)
        # append the complementary scores (1 - p) as extra columns
        pred = np.append(pred, 1 - pred, axis=1)
        predStorage.extend(pred)
    return np.asarray(predStorage)


def get_predict_proba_fn_of_class(label):
    """Assuming flatten_predict outputs an (n, d) array of prediction
    probabilities, where d is the number of labels."""
    def rewrapped_predict(strings):
        preds = flatten_predict(strings)[:, np.where(flatten_predict(strings) == label)].reshape(-1, 1)
        ret = np.asarray(np.hstack([(1 - preds), preds]))
        return ret

    return rewrapped_predict

str = 'Ein sehr freundlicher Arzt.'  # German: "A very friendly doctor." (note: shadows the built-in str)
preds = flatten_predict(str)
labels_to_explain = preds
print(labels_to_explain)

explanation_for_label = {}
for label in labels_to_explain:
    wrapped = get_predict_proba_fn_of_class(label)
    explanation_for_label[label] = explainer.explain_instance(str, wrapped)
    explanation_for_label[label].show_in_notebook()

Error Message:

ValueError                                Traceback (most recent call last)
<ipython-input-26-8df61aaa23f4> in <module>()
     53 for label in labels_to_explain:
     54     wrapped = get_predict_proba_fn_of_class(label)
---> 55     explanation_for_label[label] = explainer.explain_instance(str, wrapped)
     56     explanation_for_label[label].show_in_notebook()
     57 

/usr/local/lib/python3.6/dist-packages/lime/lime_text.py in explain_instance(self, text_instance, classifier_fn, labels, top_labels, num_features, num_samples, distance_metric, model_regressor)
    405                 data, yss, distances, label, num_features,
    406                 model_regressor=model_regressor,
--> 407                 feature_selection=self.feature_selection)
    408         return ret_exp
    409 

/usr/local/lib/python3.6/dist-packages/lime/lime_base.py in explain_instance_with_data(self, neighborhood_data, neighborhood_labels, distances, label, num_features, feature_selection, model_regressor)
    155                                                weights,
    156                                                num_features,
--> 157                                                feature_selection)
    158 
    159         if model_regressor is None:

/usr/local/lib/python3.6/dist-packages/lime/lime_base.py in feature_selection(self, data, labels, weights, num_features, method)
    104                 n_method = 'highest_weights'
    105             return self.feature_selection(data, labels, weights,
--> 106                                           num_features, n_method)
    107 
    108     def explain_instance_with_data(self,

/usr/local/lib/python3.6/dist-packages/lime/lime_base.py in feature_selection(self, data, labels, weights, num_features, method)
     78             clf = Ridge(alpha=0, fit_intercept=True,
     79                         random_state=self.random_state)
---> 80             clf.fit(data, labels, sample_weight=weights)
     81             feature_weights = sorted(zip(range(data.shape[0]),
     82                                          clf.coef_ * data[0]),

/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/ridge.py in fit(self, X, y, sample_weight)
    678         self : returns an instance of self.
    679         """
--> 680         return super(Ridge, self).fit(X, y, sample_weight=sample_weight)
    681 
    682 

/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/ridge.py in fit(self, X, y, sample_weight)
    489 
    490         X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=_dtype,
--> 491                          multi_output=True, y_numeric=True)
    492 
    493         if ((sample_weight is not None) and

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
    764         y = y.astype(np.float64)
    765 
--> 766     check_consistent_length(X, y)
    767 
    768     return X, y

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    233     if len(uniques) > 1:
    234         raise ValueError("Found input variables with inconsistent numbers of"
--> 235                          " samples: %r" % [int(l) for l in lengths])
    236 
    237 

ValueError: Found input variables with inconsistent numbers of samples: [5000, 100000]

Does anyone know what I am doing wrong? I am pretty sure it has to do with the input format.
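To illustrate the suspicion with made-up shapes: indexing with the whole tuple that np.where returns triggers fancy indexing over both axes, so after reshape(-1, 1) there are more rows than input samples, which looks like the mismatch LIME reports.

import numpy as np

# Made-up shapes, not my real data
pred = np.arange(20, dtype=float).reshape(4, 5)  # 4 samples, 5 columns
label = pred[0, 2]                               # a score to look up
picked = pred[:, np.where(pred == label)]
print(picked.shape)                  # (4, 2, 1) instead of (4, 1)
print(picked.reshape(-1, 1).shape)   # (8, 1): 8 rows for 4 samples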




1 Answer

I managed to solve the problem. Here is the solution for anyone who might be interested.

In short, the trick was to pick the right columns from the numpy arrays in get_predict_proba_fn_of_class. Since I have five independent classification scores that do not add up to one, I had to append the complementary score for every label as a new column (e.g. for 0.67 I appended 1 - 0.67 = 0.33) and then pick the original column together with its complement, which gives LIME the binary [p, 1 - p] format it expects.
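Here is a toy illustration with made-up scores (five classes, so the complement of column i sits at column i + 5):

import numpy as np

pred = np.array([[0.67, 0.10, 0.80, 0.05, 0.02]])  # 5 raw scores for one text
pred = np.append(pred, 1 - pred, axis=1)           # now shape (1, 10)
index = 0                                          # explain the first class
print(pred[:, index::5])                           # [[0.67 0.33]] -> [p, 1 - p]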

import numpy as np
from lime.lime_text import LimeTextExplainer, TextDomainMapper
# https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_text

print(encoder.classes_)


def flatten_predict(i):
    # catch single string input and convert it to a list
    if not isinstance(i, list):
        i = [i]
    # list for predictions
    predStorage = []
    # loop through the input list and predict
    for textInput in i:
        textInput = preprocess(textInput)
        textInput = make_predictable(textInput)
        pred = model.predict(textInput)
        # append the complementary scores (1 - p), turning the
        # (n, 5) raw scores into an (n, 10) array
        pred = np.append(pred, 1 - pred, axis=1)
        predStorage.extend(pred)

    return np.asarray(predStorage)


def get_predict_proba_fn_of_class(label):
    def rewrapped_predict(strings):
        pred = flatten_predict(strings)
        # find the column that holds this label's score ...
        index = np.where(pred == label)[1][0]
        # ... and select it together with its complement five columns
        # to the right, i.e. the binary [p, 1 - p] pair LIME expects
        preds = pred[:, index::5]
        return preds

    return rewrapped_predict


string = "Der Arzt weiß, was er tut und hat mir alles genau erklärt."  # German: "The doctor knows what he is doing and explained everything to me in detail."
print("Simple model prediction:", model.predict(make_predictable(preprocess(string))))

labels_to_explain = flatten_predict(string)
print("labels_to_explain:", labels_to_explain)

explanation_for_label = {}
for index, label in enumerate(labels_to_explain[0]):
    # only the first half of the entries are raw scores; the rest are complements
    if index < (len(labels_to_explain[0]) / 2):
        # each explanation is binary: this class versus "None"
        actual_classes = [encoder.classes_[index], 'None']
        explainer = LimeTextExplainer(class_names=actual_classes)
        wrapped = get_predict_proba_fn_of_class(label)  # function returns a function!
        explanation_for_label[label] = explainer.explain_instance(string, wrapped)
        explanation_for_label[label].show_in_notebook()
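As a quick sanity check (my own sketch, assuming the functions above are defined), the wrapped predictor should return one row per input string, each row a [p, 1 - p] pair summing to one, which is exactly the binary format explain_instance expects:

# Sanity check: expect an (n, 2) array whose rows sum to 1
wrapped = get_predict_proba_fn_of_class(labels_to_explain[0][0])
out = wrapped([string, "Ein sehr freundlicher Arzt."])
print(out.shape)        # (2, 2)
print(out.sum(axis=1))  # [1. 1.]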