Traditional machine learning workflows focus heavily on model training and optimization. The best model is usually chosen via performance measures like accuracy or error, and we tend to assume that a model is good enough for deployment if it passes certain thresholds of these performance criteria. Why a model makes the predictions it makes, however, is generally neglected. But being able to understand and interpret such models can be immensely important for improving model quality, increasing trust and transparency, and reducing bias. Complex machine learning models are essentially black boxes and too complicated to understand directly, so we need to use approximations to get a better sense of how they work. One such approach is LIME, which stands for Local Interpretable Model-agnostic Explanations. It is a tool that helps us understand and explain the decisions made by complex machine learning models.
Accuracy and Error in Machine Learning
A general Data Science workflow in machine learning consists of the following steps: gather data, clean and prepare data, train models, and choose the best model based on validation and test errors or other performance criteria. Usually we tend to stop at this point, particularly those of us Data Scientists or Statisticians who live for numbers, small errors and high accuracy. Let’s say we found a model that predicted 99% of our test cases correctly. In and of itself, that is very good performance, and we happily present this model to colleagues, team leaders, decision makers or whoever else might be interested in our great model. And finally, we deploy the model into production. We assume that our model is trustworthy because we have seen it perform well, but we don’t know why it performed well.
In machine learning we generally see a trade-off between accuracy and model complexity: the more complex a model is, the more difficult it is to explain. A simple linear model is easy to explain because it only considers linear relationships between the predictor variables and the response. But because it only considers linear relationships, it cannot model more complex ones, and its prediction accuracy on test data will likely be lower. Deep Neural Nets are at the other end of the spectrum. They can learn multiple levels of abstraction, so they can model extremely complex relationships and achieve very high accuracy. But their complexity also essentially makes them black boxes. We cannot grasp the intricate relationships between all the features that lead to the model’s predictions. We therefore have to use performance criteria, like accuracy and error, as a proxy for how trustworthy we believe the model is.
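The following minimal Python sketch shows why a linear model counts as "easy to explain". Its score is just an intercept plus one additive term per feature, so each term can be read off directly as that feature's contribution. The model, its weights, the feature names and the patient values are all hypothetical and made up for illustration.

```python
import math

# Hypothetical fitted logistic regression: weights are invented for illustration only.
INTERCEPT = -1.0
WEIGHTS = {"age": 0.02, "blood_pressure": 0.03, "hemoglobin": -0.25}

def explain_linear(x):
    """Split the model's log-odds score into one additive term per feature."""
    contributions = {name: w * x[name] for name, w in WEIGHTS.items()}
    log_odds = INTERCEPT + sum(contributions.values())
    probability = 1.0 / (1.0 + math.exp(-log_odds))
    return probability, contributions

patient = {"age": 50, "blood_pressure": 80, "hemoglobin": 10}  # hypothetical values
probability, contributions = explain_linear(patient)
for name, c in sorted(contributions.items(), key=lambda kv: -abs(kv[1])):
    print(f"{name:>15}: {c:+.2f} to the log-odds")
print(f"predicted probability: {probability:.3f}")
```

A deep neural net offers no such additive breakdown. LIME tries to recover one locally, around a single instance.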
Trying to understand the decisions made by our seemingly perfect model usually isn’t part of the machine learning workflow.
So why would we want to invest the additional time and effort to understand the model if it’s not technically necessary?
One way to better understand and explain complex machine learning models is to use so-called explainer functions. In my opinion, there are several reasons why model understanding and explanation should become part of the machine learning workflow for every classification problem:
- model improvement
- trust and transparency
- identifying and preventing bias
Model Improvement
Understanding the relationship between features, classes and predictions can help us decide whether the model makes intuitive sense. That means understanding why a machine learning model made the decisions it made and which features were most important in those decisions.
Let’s consider the following striking example from the literature. A deep neural net learned to distinguish images of wolves from huskies. It was trained on a number of images and tested on an independent set of images, 90 % of which were predicted correctly. We could be happy with that! But without running an explainer function, we would not know that the model based its decisions primarily on the background. Wolf images usually had a snowy background, while husky images rarely did. So we unwittingly trained a snow detector… Just by looking at performance measures like accuracy, we would not have been able to catch that!
With this additional knowledge about how the model made its predictions, and based on which features, we can intuitively judge two things. First, whether our model is picking up on meaningful patterns. Second, whether it will be able to generalize to new instances.
Trust and Transparency
Understanding our machine learning models is also necessary to improve trust and provide transparency regarding their predictions and decisions. This is especially relevant given the new General Data Protection Regulation (GDPR) that will go into effect in May 2018. It is still hotly debated whether its Article 22 includes a “right to explanation” of algorithmically derived decisions. Even so, black box models making decisions that directly affect people’s lives and livelihoods, like loans or prison sentences, probably won’t be acceptable for much longer.
Another area where trust is particularly critical is medicine, where decisions can have life-or-death consequences for patients. Machine learning models have been impressively accurate at distinguishing malignant from benign tumors of different types. But as a basis for deciding for or against medical intervention, we still require a professional’s explanation of the diagnosis. Explaining why a machine learning model classified a certain patient’s tumor as benign or malignant would go a long way toward helping doctors trust and use machine learning models that support them in their work.
Even in everyday business, where we are not dealing with quite so dire consequences, a machine learning model can have very serious repercussions if it doesn’t perform as expected. A better understanding of machine learning models can save a lot of time and prevent lost revenue in the long run: if a model doesn’t make sensible decisions, we can catch that before it goes into deployment and wreaks havoc there.
Identifying and Preventing Bias
Fairness and bias in machine learning models are widely discussed topics [5, 6]. Biased models often result from biased ground truths. If the data we use to train our models contains even subtle biases, our models will learn them and thus propagate a self-fulfilling prophecy! One (in)famous example is the machine learning model used to suggest sentence lengths for prisoners, which reflects the racial inequality inherent in the justice system. Other examples are models used for recruiting. These often show the biases our society still harbors in terms of gender associations with specific jobs, like male software engineers and female nurses.
Machine learning models are a powerful tool in many areas of our lives, and they will become ever more prevalent. Therefore, it is our responsibility as Data Scientists and decision makers to understand how the models we develop and deploy make their decisions. Only then can we proactively prevent bias from being reinforced and work on removing it!
LIME, short for Local Interpretable Model-agnostic Explanations, was developed by Marco Ribeiro, Sameer Singh and Carlos Guestrin in 2016. It can be used to explain the decisions of any classification model, whether it is a Random Forest, Gradient Boosting model, Neural Net, etc. It also works on different types of input data, like tabular data (data frames), images or text.
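"Model-agnostic" means that the explainer never looks inside the model. It only needs a function that maps inputs to predictions. The Python sketch below illustrates this for tabular data only.

It is a simplified, hypothetical implementation, not the lime package's API: `explain_locally` and `HypotheticalClassifier` are illustrative names. It assumes Gaussian perturbations, an exponential similarity kernel and a weighted ridge regression as the simple model, and it skips feature selection and discretization.

```python
import numpy as np

def explain_locally(predict_fn, x, feature_sd, n_samples=2000,
                    kernel_width=1.0, alpha=1.0, seed=0):
    """Minimal model-agnostic local explanation of predict_fn around instance x.

    predict_fn: any callable mapping an (n, p) array to n scores/probabilities;
    it is the ONLY way the model is accessed. Returns the coefficients of a
    locally weighted linear surrogate, per standard deviation of each feature.
    """
    rng = np.random.default_rng(seed)
    # 1. permute: jitter x with Gaussian noise scaled per feature (an assumption)
    Z = x + rng.normal(size=(n_samples, x.size)) * feature_sd
    # 2. query the black box on every permuted instance
    y = predict_fn(Z)
    # 3. similarity to x via an exponential kernel on the standardized distance
    S = (Z - x) / feature_sd
    w = np.exp(-np.sum(S ** 2, axis=1) / kernel_width ** 2)
    # 4. weighted ridge regression of y on S (intercept handled by weighted centering)
    Sc = S - np.average(S, axis=0, weights=w)
    yc = y - np.average(y, weights=w)
    StW = Sc.T * w  # equivalent to Sc.T @ diag(w)
    return np.linalg.solve(StW @ Sc + alpha * np.eye(x.size), StW @ yc)

class HypotheticalClassifier:
    """Stand-in for a trained model with a predict_proba-style method."""
    def predict_proba(self, X):
        p = 1.0 / (1.0 + np.exp(-4.0 * (X[:, 0] - X[:, 2])))
        return np.column_stack([1.0 - p, p])

x = np.array([0.2, 5.0, 0.1])
sd = np.ones(3)
# The same explainer works for a plain function ...
coef_fn = explain_locally(lambda X: np.tanh(X[:, 1] - 5.0), x, sd)
# ... and for an object-style model, via a small wrapper around predict_proba
model = HypotheticalClassifier()
coef_clf = explain_locally(lambda X: model.predict_proba(X)[:, 1], x, sd)
print(np.round(coef_fn, 2), np.round(coef_clf, 2))
```

Both calls use exactly the same explainer; only the thin wrapper around the model differs. For the first model, only the second feature gets a sizeable weight. For the second model, the first and third features get weights of opposite sign.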
At its core, LIME follows three concepts:
- explanations are not given globally for the entire machine learning model, but locally and for every instance separately
- explanations are given on original input features, even though the machine learning model might work on abstractions
- explanations are given for the most important features by locally fitting a simple model to the prediction
This gives us an approximate understanding of three things for a single instance: which features contributed most strongly to its classification, which features contradicted it, and how they influenced the prediction.
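For reference, the original LIME paper by Ribeiro, Singh and Guestrin combines these three concepts into a single objective. The explanation ξ(x) for an instance x is the interpretable model g from a family G (e.g. linear models) that best mimics the complex model f in a neighborhood of x. That neighborhood is defined by a proximity kernel π_x, and a penalty Ω(g) keeps g simple, e.g. by limiting the number of features. With a linear g and an exponential kernel, this reads:

```latex
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)

\mathcal{L}(f, g, \pi_x) = \sum_{z,\, z' \in \mathcal{Z}} \pi_x(z)\,\bigl(f(z) - g(z')\bigr)^2,
\qquad g(z') = w_g \cdot z',
\qquad \pi_x(z) = \exp\!\left(-\frac{D(x, z)^2}{\sigma^2}\right)
```

Here z is a permuted sample in the original feature space, z' its interpretable representation, D a distance function and σ the kernel width.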
The following example showcases how LIME can be used:
I built a Random Forest model on a data set about Chronic Kidney Disease. The model was trained to predict whether a patient had chronic kidney disease (ckd) or not (notckd). It achieved 99 % accuracy on validation data and 95 % on test data. Technically, we could stop here and declare victory. But we want to understand why certain patients were diagnosed with chronic kidney disease and why others weren’t. A medical professional would then be able to assess whether what the model learned makes intuitive sense and can be trusted. To achieve this, we can apply LIME.
As described above, LIME works on each instance individually and separately. So first, we take one instance (in this case the data from one patient) and permute it, i.e. the data is replicated with slight modifications. This generates a new data set of similar instances, all based on one original instance. For every instance in this permuted data set, we also calculate how similar it is to the original instance, i.e. how strong the modifications made during permutation were. Basically, any statistical distance or similarity metric can be used in this step. One example is Euclidean distance converted to a similarity with an exponential kernel of a specified width.
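Here is a minimal Python sketch of this permutation and weighting step. It assumes numeric features, Gaussian noise scaled by each feature's (hypothetical) training-set standard deviation, and an exponential kernel whose width is a free parameter, chosen arbitrarily here. It is an illustration, not the sampling scheme of any particular LIME implementation. The patient values are made up.

```python
import numpy as np

def permute_instance(x, feature_sd, n_samples=1000, seed=42):
    """Return n_samples slightly modified copies of instance x.

    Assumption: every feature is jittered with Gaussian noise whose standard
    deviation equals that feature's standard deviation in the training data.
    Row 0 is kept as the unmodified original instance.
    """
    rng = np.random.default_rng(seed)
    samples = x + rng.normal(size=(n_samples, x.size)) * feature_sd
    samples[0] = x
    return samples

def similarity(samples, x, feature_sd, kernel_width):
    """Exponential kernel on the Euclidean distance between standardized rows."""
    distance = np.linalg.norm((samples - x) / feature_sd, axis=1)
    return np.exp(-distance ** 2 / kernel_width ** 2)

# Hypothetical patient (age, blood pressure, hemoglobin) and training-set SDs
x = np.array([48.0, 80.0, 11.2])
sd = np.array([17.0, 13.0, 2.9])
Z = permute_instance(x, sd)
weights = similarity(Z, x, sd, kernel_width=1.5)
print(Z.shape, weights[0], weights.min())
```

The original instance always gets similarity 1, and the similarity decays smoothly as the modifications grow. A smaller kernel width makes the explanation more local.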
Next, our complex machine learning model, which was trained before, makes predictions on every permuted instance. Each permuted instance differs only slightly from the original, so we can keep track of how these changes affect the predictions.
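A self-contained sketch of this step follows. `black_box_proba` is a hypothetical, hand-made stand-in for the trained Random Forest; any callable that returns class probabilities would do. The correlation printout is only a crude illustration of "keeping track" of how changes move the predictions. It is not part of LIME itself, which instead fits the weighted simple model described next.

```python
import numpy as np

def black_box_proba(X):
    """Hypothetical stand-in for a trained complex model (e.g. a random forest).

    Returns P("ckd") for every row of X = (age, blood pressure, hemoglobin).
    The non-linear rule is invented purely for illustration.
    """
    age, bp, hemo = X[:, 0], X[:, 1], X[:, 2]
    score = 0.8 * (bp > 90) + 1.5 * np.tanh((12.0 - hemo) / 2.0) + 0.01 * (age - 50)
    return 1.0 / (1.0 + np.exp(-3.0 * score))

# Permuted data around one hypothetical patient (row 0 is the original)
rng = np.random.default_rng(0)
x = np.array([48.0, 80.0, 11.2])
sd = np.array([17.0, 13.0, 2.9])
Z = x + rng.normal(size=(500, 3)) * sd
Z[0] = x

predictions = black_box_proba(Z)
print(f"original instance: P(ckd) = {predictions[0]:.2f}")
# A crude first look: how does each feature's change co-vary with the prediction?
correlations = {
    name: np.corrcoef(Z[:, j], predictions)[0, 1]
    for j, name in enumerate(["age", "bp", "hemoglobin"])
}
print(correlations)
```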
And finally, we fit a simple model (usually a linear model) to the permuted data and its predictions using the most important features. There are different ways to determine the most important features: we typically define the number of features we want to include in our explanations (usually around 5 to 10) and then either
- choose the features with highest weights in the regression fit on the predictions made by the complex machine learning model
- apply forward selection, where features are added to improve the regression fit on the predictions made by the complex machine learning model
- choose the features with the least shrinkage in a lasso fit on the predictions made by the complex machine learning model, i.e. the features that are shrunk to zero last along the regularization path
- or fit a decision tree with at most as many branch splits as the number of features we have chosen
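Two of these feature-selection strategies, highest weights and forward selection, can be sketched in a few lines of Python. This is an illustrative implementation under simplifying assumptions:

- weighted least squares serves as the simple model
- weighted R² serves as the forward-selection criterion
- the data are synthetic, and only features 1 and 4 matter

The lasso and decision-tree variants are not shown.

```python
import numpy as np

def _weighted_lstsq(X, y, w):
    """Weighted least squares with intercept; returns (coefficients, fitted values)."""
    A = np.column_stack([np.ones(len(y)), X])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef, A @ coef

def weighted_r2(X, y, w):
    """Weighted R^2 of the weighted least-squares fit of y on X."""
    _, fitted = _weighted_lstsq(X, y, w)
    y_bar = np.average(y, weights=w)
    return 1.0 - np.sum(w * (y - fitted) ** 2) / np.sum(w * (y - y_bar) ** 2)

def forward_selection(X, y, w, n_features):
    """Greedily add the feature that most improves the weighted fit."""
    selected, remaining = [], list(range(X.shape[1]))
    while len(selected) < n_features:
        best = max(remaining, key=lambda j: weighted_r2(X[:, selected + [j]], y, w))
        selected.append(best)
        remaining.remove(best)
    return selected

def highest_weights(X, y, w, n_features):
    """Fit on all standardized features and keep the largest |coefficients|."""
    Xs = (X - X.mean(axis=0)) / X.std(axis=0)
    coef, _ = _weighted_lstsq(Xs, y, w)
    return [int(j) for j in np.argsort(-np.abs(coef[1:]))[:n_features]]

# Synthetic check: the "black-box predictions" y depend on features 1 and 4 only
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 6))
y = 2.0 * X[:, 1] - 1.5 * X[:, 4] + 0.1 * rng.normal(size=300)
w = rng.uniform(0.2, 1.0, size=300)  # stand-in similarity weights
print(forward_selection(X, y, w, 2), highest_weights(X, y, w, 2))
```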
The similarity between each permuted instance and the original instance feeds into the simple model as a weight, so that instances more similar to the original get higher importance. Consequently, any simple model that can take weighted input, e.g. a ridge regression, can serve as the explainer.
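Weighted ridge regression has a closed-form solution, so the explainer step itself is short. The sketch below is a minimal Python version, assuming numeric features and an unpenalized intercept. The non-linear function `f` plays the role of the black box and is purely hypothetical. With a narrow kernel, the fitted slopes approximate how `f` changes in the immediate neighborhood of `x`, which is the kind of local information LIME reports.

```python
import numpy as np

def weighted_ridge(X, y, w, alpha=1.0):
    """Weighted ridge regression with an unpenalized intercept.

    Minimizes  sum_i w_i * (y_i - b0 - x_i . b)**2 + alpha * ||b||**2,
    where w_i is the similarity of permuted instance i to the original one.
    """
    # Centering with *weighted* means lets us leave the intercept unpenalized
    x_mean = np.average(X, axis=0, weights=w)
    y_mean = np.average(y, weights=w)
    Xc, yc = X - x_mean, y - y_mean
    XtW = Xc.T * w  # equivalent to Xc.T @ diag(w)
    b = np.linalg.solve(XtW @ Xc + alpha * np.eye(X.shape[1]), XtW @ yc)
    return y_mean - x_mean @ b, b

# Local surrogate for a hypothetical non-linear black box f around x = (1, 0)
rng = np.random.default_rng(0)
x = np.array([1.0, 0.0])
Z = x + rng.normal(scale=0.5, size=(2000, 2))  # permuted instances
f = Z[:, 0] ** 2 + np.sin(Z[:, 1])  # black-box output
similarity = np.exp(-np.sum((Z - x) ** 2, axis=1) / 0.5 ** 2)
intercept, slopes = weighted_ridge(Z, f, similarity, alpha=0.01)
print(slopes)  # roughly the local gradient of f at x, i.e. about (2, 1)
```

Widening the kernel gives more weight to distant permuted instances. For a non-linear model, the slopes then reflect less local behavior; here, the sine slope would shrink further below 1.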
Now, we can interpret the prediction made for the original instance. With the example model described above, you can see the LIME output for the eight most important features for six patients/instances in the figure below:
Each of the six facets shows the explanation for the prediction of an individual patient or instance. The header of each facet gives the case number (here the patient ID), the predicted class label and its probability. For example, the top left instance describes case number 4, which was classified as “ckd” with 98 % probability. Below the header is a bar plot of the eight most important features. The length of each bar shows the weight of the feature: positive weights support a prediction, negative weights contradict it. Take the top left instance again. The bar plot shows that the hemoglobin value was between 0.388 and 0.466, which supports the classification as “ckd”. Packed cell volume (pcv), serum creatinine (sc), etc. similarly support the classification as “ckd” (for a full list of feature abbreviations, see http://archive.ics.uci.edu/ml/datasets/Chronic_Kidney_Disease). This patient’s age and white blood cell count (wbcc), on the other hand, are more characteristic of a healthy person and therefore contradict the classification as “ckd”.
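The interval labels in such plots (e.g. hemoglobin between 0.388 and 0.466) presumably come from binning continuous features, and the bar lengths are the simple model's weights. The following Python sketch shows both ideas. It assumes quartile bins computed from training data, and every value and weight below is hypothetical, not taken from the figure. It only reproduces a text summary, not the plotting of any LIME package.

```python
import numpy as np

def quartile_label(name, training_values, value):
    """Describe `value` by the quartile bin of `training_values` it falls into."""
    q1, q2, q3 = np.percentile(training_values, [25, 50, 75])
    if value <= q1:
        return f"{name} <= {q1:.3g}"
    if value <= q2:
        return f"{q1:.3g} < {name} <= {q2:.3g}"
    if value <= q3:
        return f"{q2:.3g} < {name} <= {q3:.3g}"
    return f"{name} > {q3:.3g}"

def describe(explanation, predicted_label, top_n=8):
    """Summarize (feature label, weight) pairs, largest |weight| first.

    Positive weights support the predicted label, negative weights contradict it.
    """
    ranked = sorted(explanation, key=lambda pair: -abs(pair[1]))[:top_n]
    return [
        f"{label}: {weight:+.3f} ({'supports' if weight > 0 else 'contradicts'} {predicted_label})"
        for label, weight in ranked
    ]

# Hypothetical, already-scaled training values for one feature and made-up weights
hemo_train = np.linspace(0.0, 1.0, 101)
explanation = [
    (quartile_label("hemo", hemo_train, 0.40), 0.21),
    ("sc > 0.3", 0.12),
    ("age > 0.6", -0.05),
]
for line in describe(explanation, "ckd"):
    print(line)
```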