Introduction and motivation
In standard supervised machine learning, we need training data, i.e., a set of data points with known labels, and we build a model that learns the distinguishing properties separating data points with different labels. This trained model can then be used to predict labels for new data points. If we want to make predictions for another task (with different labels) in a different domain, we cannot use the previously trained model. We need to gather training data for the new task and train a separate model.
Transfer learning provides a framework for leveraging an existing model (trained on some data) in a related domain. We can transfer the knowledge gained by the previous model to the new domain (and data).
For example, suppose we have built a model to detect pedestrians and vehicles in traffic images, and we now wish to build a model that detects pedestrians, bicycles, and vehicles in the same data. We will have to train a new model with three classes, because the previous model was trained to make two-class predictions. But the two-class model has clearly learned something useful, e.g., how to tell people walking apart from moving vehicles. In the transfer learning paradigm, we can apply what the two-class classifier has learned to the three-class classifier we intend to build.
As such, we can already see that transfer learning has great potential. In the words of Andrew Ng, a leading expert in machine learning, in his extremely popular NIPS 2016 tutorial, “Transfer learning will be the next driver of machine learning success.”
Transfer learning in deep learning
Transfer learning is particularly popular in deep learning. The reason is that deep neural networks are very expensive to train and need huge amounts of data to reach their full potential. In fact, much of the recent success of deep learning can be attributed to the availability of large datasets and more powerful computational resources. But apart from a few large companies like Google, Facebook, IBM, and Microsoft, it is very difficult to gather the data and computing hardware required to train strong deep learning models. In such a situation, transfer learning comes to the rescue. Many models pre-trained on large amounts of data have been made publicly available, along with their learned parameters, which often number in the millions. You can take these pre-trained models and rely on transfer learning to build models for your specific use case.
The most popular application of transfer learning is image classification using deep convolutional neural networks (ConvNets). Several high-performing, state-of-the-art ConvNet image classifiers trained on ImageNet data (about 1.2 million images in 1,000 categories) are publicly available. Examples include AlexNet, VGG16, VGG19, and InceptionV3, which can take weeks to train from scratch. I have personally used transfer learning to build image classifiers on top of VGG19 and InceptionV3. Another popular family of pre-trained models is distributed word embeddings for millions of words, e.g., word2vec, GloVe, and FastText. These are trained on large corpora such as Wikipedia and Google News, and provide vector representations for a huge number of words, which can then be used in a text classification model.
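One simple way to use such pre-trained word vectors in a text classifier is to average the vectors of a document's words into a single fixed-length feature vector. The Python sketch below assumes a tiny, made-up embedding table, `TOY_EMBEDDINGS`, in place of real word2vec, GloVe, or FastText vectors (which typically have 50 to 300 dimensions and are loaded from a downloaded file); `document_vector` is a hypothetical helper, not part of any library.

```python
# Hypothetical, made-up 3-dimensional vectors standing in for a pre-trained
# embedding table (real word2vec/GloVe/FastText vectors are loaded from a file).
TOY_EMBEDDINGS = {
    "good": [0.9, 0.1, 0.0],
    "great": [0.8, 0.2, 0.1],
    "bad": [-0.7, 0.1, 0.0],
    "movie": [0.0, 0.5, 0.5],
}


def document_vector(text, embeddings, dim):
    """Average the pre-trained vectors of the known words in `text`.

    Words missing from `embeddings` are skipped; if no word is known,
    an all-zero vector of length `dim` is returned.
    """
    vectors = [embeddings[w] for w in text.lower().split() if w in embeddings]
    if not vectors:
        return [0.0] * dim
    # zip(*vectors) groups the i-th component of every word vector together.
    return [sum(components) / len(vectors) for components in zip(*vectors)]


# Fixed-length features that a downstream text classifier could consume.
print(document_vector("A great movie", TOY_EMBEDDINGS, dim=3))
print(document_vector("A bad movie", TOY_EMBEDDINGS, dim=3))
```

The averaged vector can be fed to any standard classifier. Averaging throws away word order, which is the trade-off for its simplicity.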
Strategies for transfer learning
Transfer learning can be used in one of the following four ways:
- Directly use a pre-trained model: The pre-trained model can be used directly for a similar task. For example, you can use Google's InceptionV3 model to predict the categories of images. These models have already been shown to achieve high accuracy.
- Fixed features: The knowledge gained by one model can be used to build features for the data points, and these (fixed) features are then fed to new models. For example, you can run the new images through a pre-trained ConvNet and use the output of any layer as a feature vector for each image. The features built this way can then be used in a classifier for the task at hand. Similarly, you can directly use word vectors as features in a text classification model.
- Fine-tuning the model: In this strategy, you use the pre-trained network as your model but allow its weights to be fine-tuned. For example, for an image classifier, you can feed your images to the InceptionV3 model and use the pre-trained weights as the initialization (rather than a random initialization). The model is then trained on the much smaller user-provided dataset. The advantage of this strategy is that training starts from weights that already capture useful features, so the network can reach a good solution with far less data and training than it would need from scratch. You can also freeze a portion of the network (usually the early layers) and fine-tune only the remaining layers.
- Combining models: Instead of re-training the top few layers of a pre-trained model, you can replace the top few layers with a new classifier and train this combined network while keeping the pre-trained portion fixed.
Which strategy to choose depends mainly on how much new data you have and how similar it is to the data the model was originally trained on:
- If the new dataset is small and similar to the original data, it is not a good idea to fine-tune the pre-trained model, because fine-tuning on so little data is likely to overfit. Instead, you can directly feed the data to the pre-trained model, or train a simple classifier on fixed features extracted from it.
- If the new dataset is large, it is a good idea to fine-tune the pre-trained model. If the data is similar to the original, fine-tuning only the top few layers can be enough, and with this much data we can be more confident that fine-tuning will not overfit. If the data is very different, we will have to fine-tune the whole network.
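The fixed-features strategy can be sketched in plain Python as follows. This is a toy illustration under stated assumptions, not how a real ConvNet is loaded: `PRETRAINED_W` and `PRETRAINED_B` are made-up weights standing in for one layer of a network such as VGG19, the two-number inputs stand in for images, and a nearest-centroid classifier stands in for whatever simple classifier (e.g., logistic regression) you would train on the extracted features. All function names are hypothetical.

```python
# Hypothetical, made-up weights standing in for one layer of a pre-trained
# network. Shape: 3 hidden units x 2 input features.
PRETRAINED_W = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]
PRETRAINED_B = [0.0, 0.0, 0.0]


def extract_features(x):
    """Frozen feature extractor: linear layer + ReLU; weights are never updated."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(PRETRAINED_W, PRETRAINED_B)]


def fit_nearest_centroid(samples, labels):
    """Train the new classifier: the mean extracted feature vector of each class."""
    sums, counts = {}, {}
    for x, y in zip(samples, labels):
        f = extract_features(x)
        acc = sums.setdefault(y, [0.0] * len(f))
        for i, v in enumerate(f):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}


def predict(centroids, x):
    """Assign x to the class whose centroid is closest in feature space."""
    f = extract_features(x)
    return min(centroids,
               key=lambda y: sum((a - c) ** 2 for a, c in zip(f, centroids[y])))


# Hypothetical toy data: two-number inputs standing in for labelled images.
train_x = [[2.0, 0.0], [1.5, 0.2], [0.0, 2.0], [0.1, 1.8]]
train_y = ["vehicle", "vehicle", "pedestrian", "pedestrian"]
centroids = fit_nearest_centroid(train_x, train_y)
print(predict(centroids, [1.8, 0.1]))  # vehicle
```

Only `fit_nearest_centroid` learns from the new data; the pre-trained layer inside `extract_features` is evaluated but never updated, which is what makes the features fixed.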
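The difference between keeping the pre-trained part fixed (as in the combining-models strategy) and fine-tuning it can be illustrated with a tiny two-layer network trained by stochastic gradient descent in plain Python. In this sketch, `PRETRAINED` holds made-up weights standing in for a pre-trained backbone, the new head is a single sigmoid unit for a hypothetical two-class task, and the `freeze_backbone` flag decides whether the backbone is updated. Giving the pre-trained layers a smaller learning rate than the new head (`backbone_lr` vs. `head_lr`) is a common fine-tuning heuristic, assumed here for illustration.

```python
import copy
import math

# Hypothetical, made-up weights standing in for a pre-trained backbone
# (2 inputs -> 2 ReLU units).
PRETRAINED = {"W1": [[1.0, -1.0], [-1.0, 1.0]], "b1": [0.0, 0.0]}


def forward(W1, b1, w2, b2, x):
    """Backbone (linear + ReLU) followed by the new head (linear + sigmoid)."""
    pre = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, b1)]
    h = [max(0.0, v) for v in pre]
    logit = sum(w * hi for w, hi in zip(w2, h)) + b2
    return pre, h, 1.0 / (1.0 + math.exp(-logit))


def train(data, freeze_backbone, epochs=200, head_lr=0.5, backbone_lr=0.01):
    """SGD on binary cross-entropy; the backbone is updated only if not frozen."""
    W1 = copy.deepcopy(PRETRAINED["W1"])  # start from the pre-trained weights
    b1 = list(PRETRAINED["b1"])
    w2, b2 = [0.0, 0.0], 0.0              # new task-specific head
    for _ in range(epochs):
        for x, y in data:
            pre, h, p = forward(W1, b1, w2, b2, x)
            dz = p - y                    # d(loss)/d(logit) for sigmoid + BCE
            dh = [dz * w for w in w2]     # gradient w.r.t. hidden units (before head update)
            w2 = [w - head_lr * dz * hi for w, hi in zip(w2, h)]
            b2 -= head_lr * dz
            if freeze_backbone:
                continue                  # pre-trained part stays fixed
            for j in range(len(W1)):
                g = dh[j] if pre[j] > 0 else 0.0  # ReLU passes gradient only when active
                W1[j] = [w - backbone_lr * g * xi for w, xi in zip(W1[j], x)]
                b1[j] -= backbone_lr * g
    return W1, b1, w2, b2


# Hypothetical two-class data for the new task.
DATA = [([1.0, 0.0], 1), ([0.0, 1.0], 0), ([0.9, 0.2], 1), ([0.2, 0.9], 0)]

frozen = train(DATA, freeze_backbone=True)
tuned = train(DATA, freeze_backbone=False)
print("frozen backbone:", frozen[0])
print("fine-tuned backbone:", tuned[0])
```

With `freeze_backbone=True`, the returned backbone weights are identical to `PRETRAINED`; with `freeze_backbone=False`, they move slightly away from their pre-trained initialization while the new head is learned.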
Transfer learning allows someone without a large amount of data or computational capabilities to take advantage of the deep learning paradigm. Using off-the-shelf pre-trained models and transferring them to novel domains is an exciting direction for both research and applications.
About the Author
Janu Verma is a Researcher in the IBM T.J. Watson Research Center, New York. His research interests are in mathematics, machine learning, information visualization, computational biology and healthcare analytics. He has held research positions at Cornell University, Kansas State University, Tata Institute of Fundamental Research, Indian Institute of Science, and Indian Statistical Institute. He has written papers for IEEE Vis, KDD, International Conference on HealthCare Informatics, Computer Graphics and Applications, Nature Genetics, IEEE Sensors Journal, etc. His current focus is on the development of visual analytics systems for prediction and understanding. He advises startups and companies on data science and machine learning in the Delhi-NCR area; email him to schedule a meeting. Check out his personal website at http://jverma.github.io/.