Advances in machine learning have greatly improved products, processes, and research, and have changed how people interact with computers. One thing machine learning systems still lack is the ability to explain their predictions. When a system cannot properly explain its results, end users lose trust in it, which ultimately acts as a barrier to the adoption of machine learning. Hence, along with the impressive results from machine learning, it is also important to understand why and where it works, and when it won’t. In this article, we will talk about some ways to increase machine learning interpretability and make predictions from machine learning models understandable.
3 interesting methods for interpreting Machine Learning predictions
According to Miller, interpretability is the degree to which a human can understand the cause of a decision. Interpretable predictions lead to better trust and provide insight into how the model may be improved. Many recent advances in machine learning rely on complex models, which lack interpretability. Simpler models (e.g. linear models), on the other hand, are usually easy to interpret, but they are often less accurate than complex models. This creates a tension between accuracy and interpretability.
Complex models are less interpretable because the relationships they learn generally cannot be summarized concisely. However, if we focus on a prediction made on a particular sample, we can describe the relationships more easily. Balancing the trade-off between model complexity and interpretability lies at the heart of research on interpretable deep learning and machine learning models.
We will discuss a few methods to increase the interpretability of complex ML models by summarizing model behavior with respect to a single prediction.
LIME, or Local Interpretable Model-Agnostic Explanations, is a method introduced in the paper “Why Should I Trust You?” that interprets individual model predictions by locally approximating the model around a given prediction. LIME uses two steps to explain a specific prediction: perturbation and linear approximation. In the perturbation step, LIME takes a prediction that requires explanation and systematically perturbs its inputs. These perturbed inputs, labeled with the model’s outputs, become new training data for a simpler approximate model. LIME then performs a local linear approximation by fitting a linear model to describe the relationship between the (perturbed) inputs and outputs. Thus a simple linear model approximates the more complex, nonlinear function in the neighborhood of that prediction.
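The perturbation and local linear approximation steps can be sketched in a few lines of NumPy. This is a simplified illustration of the idea, not the API of the `lime` package: the `black_box` function, the Gaussian perturbation scale, and the exponential proximity kernel with its width are all assumptions chosen for the example.

```python
import numpy as np

def black_box(X):
    """Hypothetical nonlinear model standing in for any complex predictor."""
    return np.sin(X[:, 0]) + X[:, 1] ** 2

def lime_explain(predict_fn, x, num_samples=2000, scale=0.1,
                 kernel_width=0.25, seed=0):
    """Fit a proximity-weighted linear model around the instance x."""
    rng = np.random.default_rng(seed)
    # 1. Perturb the instance with small Gaussian noise.
    Z = x + rng.normal(0.0, scale, size=(num_samples, x.shape[0]))
    # 2. Label the perturbed points with the black-box model.
    y = predict_fn(Z)
    # 3. Weight each sample by its proximity to x (assumed exponential kernel).
    dist = np.linalg.norm(Z - x, axis=1)
    w = np.exp(-(dist ** 2) / kernel_width ** 2)
    # 4. Weighted least squares on inputs centered at x, plus an intercept.
    A = np.hstack([Z - x, np.ones((num_samples, 1))])
    sw = np.sqrt(w)[:, None]
    coef, *_ = np.linalg.lstsq(A * sw, y * sw[:, 0], rcond=None)
    return coef[:-1], coef[-1]

x0 = np.array([0.0, 1.0])
weights, intercept = lime_explain(black_box, x0)
print("local feature weights:", weights)
print("local intercept:", intercept)
```

The fitted weights act as the explanation: each one estimates how strongly the corresponding feature moves the prediction near `x0`. For this toy function they should land close to the local slopes of `black_box` at that point.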
DeepLIFT (Deep Learning Important FeaTures) is a recursive prediction explanation method for deep learning. This method decomposes the output prediction of a neural network on a specific input by backpropagating the contributions of all neurons in the network to every feature of the input. DeepLIFT assigns contribution scores based on the difference between the activation of each neuron and its ‘reference activation’. DeepLIFT can also reveal dependencies which are missed by other approaches by optionally giving separate consideration to positive and negative contributions.
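To make the idea of contribution scores relative to a reference concrete, here is a minimal sketch for a tiny one-hidden-layer ReLU network. It uses a rescale-style multiplier (difference in activation divided by difference in pre-activation). The network weights, the all-zero reference input, and the function names are assumptions made for illustration, not the interface of any DeepLIFT library.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def forward(x, W1, b1, w2, b2):
    """Return hidden pre-activations and the scalar network output."""
    z = x @ W1 + b1
    return z, relu(z) @ w2 + b2

def deeplift_rescale(x, x_ref, W1, b1, w2, b2, tol=1e-9):
    """Contribution of each input feature to (output - reference output)."""
    z, out = forward(x, W1, b1, w2, b2)
    z_ref, out_ref = forward(x_ref, W1, b1, w2, b2)
    dz = z - z_ref                  # change in pre-activation vs. reference
    da = relu(z) - relu(z_ref)      # change in activation vs. reference
    # Multiplier through the ReLU: da / dz, falling back to the plain
    # gradient where dz is numerically zero.
    safe_dz = np.where(np.abs(dz) > tol, dz, 1.0)
    m = np.where(np.abs(dz) > tol, da / safe_dz, (z > 0).astype(float))
    # Chain the multipliers back through both linear layers to the inputs.
    contributions = (x - x_ref) * (W1 @ (m * w2))
    return contributions, out - out_ref

rng = np.random.default_rng(1)
W1, b1 = rng.normal(size=(3, 4)), rng.normal(size=4)   # hypothetical weights
w2, b2 = rng.normal(size=4), 0.5
x = np.array([1.0, -2.0, 0.5])
x_ref = np.zeros(3)                                    # assumed reference input

contributions, delta_out = deeplift_rescale(x, x_ref, W1, b1, w2, b2)
print("contributions:", contributions)
print("sum of contributions:", contributions.sum(), "delta output:", delta_out)
```

A useful sanity check is that the contributions add up to the difference between the output on the input and the output on the reference, which is why the last line prints both values.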
Layer-wise relevance propagation is another method for interpreting the predictions of deep learning models. It determines which features in a particular input vector contribute most strongly to a neural network’s output. It defines a set of constraints to derive a number of different relevance propagation functions.
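As a rough illustration, the sketch below applies one commonly used propagation rule, the epsilon rule, to a small bias-free ReLU network. The weights and input are hypothetical, and biases are left out on purpose so that the total relevance stays approximately equal to the network output as it is passed back through the layers.

```python
import numpy as np

def lrp_epsilon(x, W1, w2, eps=1e-6):
    """Redistribute the output score back to the inputs, layer by layer."""
    # Forward pass through a bias-free network: linear -> ReLU -> linear.
    z1 = x @ W1
    a1 = np.maximum(z1, 0.0)
    out = a1 @ w2
    # Output -> hidden: each hidden unit gets its share of the output score.
    denom_out = out + eps * (1.0 if out >= 0 else -1.0)
    R_hidden = a1 * w2 / denom_out * out
    # Hidden -> input: split each hidden unit's relevance over the inputs
    # in proportion to their contribution x_i * W1[i, j] to z1[j].
    denom_hidden = z1 + eps * np.where(z1 >= 0, 1.0, -1.0)
    R_input = (x[:, None] * W1) @ (R_hidden / denom_hidden)
    return R_input, out

rng = np.random.default_rng(2)
W1 = rng.normal(size=(3, 5))   # hypothetical input-to-hidden weights
w2 = rng.normal(size=5)        # hypothetical hidden-to-output weights
x = np.array([0.8, -1.2, 0.3])

relevance, output = lrp_epsilon(x, W1, w2)
print("input relevance:", relevance)
print("sum of relevance:", relevance.sum(), "network output:", output)
```

The features with the largest absolute relevance are the ones this rule identifies as contributing most strongly to the output for this particular input.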
Thus we saw 3 different ways of summarizing model behavior with respect to a single prediction to increase model interpretability. Another important avenue for interpreting machine learning models is to understand (and rethink) generalization.
What is generalization and how it affects Machine Learning interpretability
Machine learning algorithms are trained on certain datasets, called training sets. During training, a model learns intrinsic patterns in the data and updates its internal parameters to fit the data better. Once training is over, the model is evaluated on test data to predict results based on what it has learned. In an ideal scenario, the model would always accurately predict the results for the test data. In reality, the model may capture all the relevant information in the training data, but sometimes fail when presented with new data. This difference between “training error” and “test error” is called the generalization error. The ultimate aim of turning a machine learning system into a scalable product is generalization. Every task in ML aims to create a generalized algorithm that behaves in the same way across all kinds of distributions. The ability to distinguish models that generalize well from those that do not will not only help to make ML models more interpretable, but might also lead to more principled and reliable model architecture design.
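The gap between training error and test error is easy to see with a model that simply memorizes its training set. The sketch below uses a hypothetical one-dimensional dataset with 20% label noise and a 1-nearest-neighbour classifier written from scratch. The data-generating rule, the noise level, and the sample sizes are all assumptions chosen for the example.

```python
import random

rng = random.Random(0)

def sample(n):
    """Draw n points: the label is 1 when x > 0.5, with 20% of labels flipped."""
    data = []
    for _ in range(n):
        x = rng.random()
        y = int(x > 0.5)
        if rng.random() < 0.2:
            y = 1 - y
        data.append((x, y))
    return data

def nearest_label(point, train):
    """1-nearest-neighbour prediction: copy the label of the closest training point."""
    return min(train, key=lambda item: abs(item[0] - point))[1]

def error_rate(data, train):
    """Fraction of points in data that the memorizing model gets wrong."""
    wrong = sum(1 for x, y in data if nearest_label(x, train) != y)
    return wrong / len(data)

train_set, test_set = sample(200), sample(200)
train_error = error_rate(train_set, train_set)
test_error = error_rate(test_set, train_set)
generalization_gap = test_error - train_error

print("training error:", train_error)
print("test error:", test_error)
print("generalization gap:", generalization_gap)
```

Because each training point is its own nearest neighbour, the model reproduces the training labels, noise included, while its error on fresh samples is noticeably higher. The difference between the two numbers is the quantity the paragraph above calls the generalization error.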
According to conventional statistical learning theory, small generalization error is due either to properties of the model family or to the regularization techniques used during training. A paper presented at ICLR 2017, “Understanding deep learning requires rethinking generalization”, shows that current theoretical frameworks in machine learning fail to explain the impressive results of deep learning approaches, and argues that understanding deep learning therefore requires rethinking generalization. The authors support their findings through extensive systematic experiments.
Developing human understanding through visualizing ML models
Interpretability also means creating models that support human understanding of machine learning. Human interpretation is enhanced when visual and interactive diagrams and figures are used to explain the results of ML models. This is why a tight interplay between UX design and machine learning is essential for increasing machine learning interpretability.
Along the lines of Human-Centered Machine Learning, researchers at Google, OpenAI, DeepMind, YC Research and others have launched Distill. This open science journal features articles that give a clear exposition of machine learning concepts using excellent interactive visualization tools. Most of these articles are aimed at understanding the inner workings of various machine learning techniques. Some of these include:
- An article on attention and Augmented Recurrent Neural Networks, which has a beautiful visualization of the attention distribution in RNNs.
- Another one on feature visualization, which talks about how neural networks build up their understanding of images.
Google has also launched the PAIR initiative to study and design the most effective ways for people to interact with AI systems. It helps researchers understand ML systems through work on interpretability and expanding the community of developers.
R2D3 is another website that provides an excellent visual introduction to machine learning.
Facets is another tool for visualizing and understanding training datasets to provide a human-centered approach to ML engineering.
Human-Centered Machine Learning is all about increasing the interpretability of ML systems and developing human understanding of them. It is also about ML and AI systems understanding how humans reason, communicate and collaborate. As algorithms are used to make decisions in more areas of everyday life, it’s important for data scientists to train them thoughtfully to ensure the models make decisions for the right reasons.
As more progress is made in this area, ML systems should become less likely to make commonsense errors, violate user expectations, or place themselves in situations that can lead to conflict and harm, making such systems safer to use. As research continues, machines may become able to explain their decisions and results in ways that humans can readily understand.