In standard supervised machine learning, we need training data, i.e. a set of data points with known labels, and we build a model that learns the properties separating data points with different labels. The trained model can then be used to predict labels for new data points. If we want to make predictions for another task (with different labels) in a different domain, we cannot directly reuse the model trained previously. We need to gather training data for the new task and train a separate model.
Transfer learning provides a framework to leverage the already existing model (based on some training data) in a related domain. We can transfer the knowledge gained in the previous model to the new domain (and data).
For example, if we have built a model to detect pedestrians and vehicles in traffic images, and we wish to build a model for detecting pedestrians, cycles, and vehicles in the same data, we will have to train a new model with three classes because the previous model was trained to make two-class predictions. But we have clearly learned something in the two-class situation, e.g. how to tell people walking apart from moving vehicles. In the transfer learning paradigm, we can apply what we learned in the two-label classifier to the three-label classifier that we intend to construct.
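To make the pedestrian/cycle/vehicle example concrete, here is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the method behind any particular model: it assumes the hidden layer of the two-class model is available as a fixed weight matrix (`FROZEN`, filled with made-up values here), leaves that layer unchanged, and trains only a new three-class softmax layer on top of it. The toy data, the function names, and the hyperparameters are all hypothetical.

```python
import math
import random


def features(x, frozen_weights):
    """Hidden layer reused from the earlier model; these weights are never updated."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in frozen_weights]


def softmax(scores):
    """Turn raw class scores into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict(x, frozen_weights, head):
    """Class probabilities: frozen features followed by the new output layer."""
    h = features(x, frozen_weights)
    return softmax([sum(w * hi for w, hi in zip(row, h)) for row in head])


def train_head(data, frozen_weights, n_classes, epochs=200, lr=0.1):
    """Train only the new output layer with gradient descent on cross-entropy."""
    n_features = len(frozen_weights)
    head = [[0.0] * n_features for _ in range(n_classes)]
    for _ in range(epochs):
        for x, label in data:
            h = features(x, frozen_weights)
            probs = softmax([sum(w * hi for w, hi in zip(row, h)) for row in head])
            for c in range(n_classes):
                grad = probs[c] - (1.0 if c == label else 0.0)
                for j in range(n_features):
                    head[c][j] -= lr * grad * h[j]
    return head


random.seed(0)
# Assumption: stand-in for weights learned by the two-class model (made-up values).
FROZEN = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(4)]
# Toy data: (inputs, label) with 0 = pedestrian, 1 = cycle, 2 = vehicle.
DATA = [
    ([0.2, 0.1], 0), ([0.3, 0.2], 0),
    ([0.5, 0.5], 1), ([0.6, 0.4], 1),
    ([0.9, 0.9], 2), ([1.0, 0.8], 2),
]

head = train_head(DATA, FROZEN, n_classes=3)
probs = predict([0.95, 0.85], FROZEN, head)
print(probs)  # three probabilities, one per class
```

Only `head` is learned from the three-class data. Everything the earlier model knew is carried over through `FROZEN`, which is why far less new data and computation are needed than when training from scratch.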
Even this small example shows the potential of transfer learning. In the words of Andrew Ng, a leading expert in machine learning, in his extremely popular NIPS 2016 tutorial, “Transfer learning will be the next driver of machine learning success.”
Transfer learning is particularly popular in deep learning. The reason is that deep neural networks are very expensive to train, and they require huge amounts of data to achieve their full potential. In fact, the recent successes of deep learning can be attributed to the availability of large amounts of data and stronger computational resources. But outside a few large companies like Google, Facebook, IBM, and Microsoft, it’s very difficult to accrue the data and the computing power required for training strong deep learning models. In such a situation, transfer learning comes to the rescue. Many pre-trained models, trained on large amounts of data, have been made publicly available, along with the values of their millions of learned parameters. You can start from a model pre-trained on a large dataset and rely on transfer learning to build a model for your specific case.
The most popular application of transfer learning is image classification using deep convolutional neural networks (ConvNets). Several high-performing, state-of-the-art ConvNet image classifiers, trained on ImageNet data (1.2 million images in 1,000 categories), are publicly available. Examples include AlexNet, VGG16, VGG19, InceptionV3, and more, which can take days to weeks to train from scratch. I have personally used transfer learning to build image classifiers on top of VGG19 and InceptionV3. Another popular application is pre-trained distributed word embeddings for millions of words, e.g. word2vec, GloVe, and FastText. These are trained on large corpora such as Wikipedia and Google News, and provide vector representations for a huge number of words. The vectors can then be used as input features in a text classification model.
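The word-embedding case can be sketched in a few lines of plain Python. This is only an illustration: the tiny two-dimensional `EMBEDDINGS` table below is made up and stands in for real pre-trained vectors such as word2vec or GloVe, and the nearest-centroid classifier is a deliberately simple choice for the example, not something prescribed by those embeddings. Each text is represented by the average of its word vectors, and a new text gets the label whose training centroid is closest.

```python
# Assumption: hypothetical 2-dimensional vectors standing in for pre-trained embeddings.
EMBEDDINGS = {
    "goal": [1.0, 0.0],
    "match": [0.9, 0.1],
    "team": [0.8, 0.2],
    "stock": [0.0, 1.0],
    "market": [0.1, 0.9],
    "shares": [0.2, 0.8],
}
DIM = 2


def embed(text):
    """Average the vectors of the known words; unknown words are skipped."""
    vectors = [EMBEDDINGS[w] for w in text.lower().split() if w in EMBEDDINGS]
    if not vectors:
        return [0.0] * DIM
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(DIM)]


def centroid(texts):
    """Mean document vector of all training texts for one label."""
    vecs = [embed(t) for t in texts]
    return [sum(v[i] for v in vecs) / len(vecs) for i in range(DIM)]


def classify(text, centroids):
    """Return the label whose centroid is nearest (squared Euclidean distance)."""
    vec = embed(text)

    def sq_dist(label):
        return sum((a - b) ** 2 for a, b in zip(vec, centroids[label]))

    return min(centroids, key=sq_dist)


# Toy labelled texts (hypothetical).
TRAIN = {
    "sports": ["team match", "goal"],
    "finance": ["stock market", "shares"],
}
centroids = {label: centroid(texts) for label, texts in TRAIN.items()}

print(classify("the team scored a goal", centroids))     # sports
print(classify("shares fell on the market", centroids))  # finance
```

The classifier itself is trained on only four short texts. What it knows about word meaning comes entirely from the embedding table, which in practice would have been learned on a much larger corpus.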
Transfer learning can be used in one of the following four ways:
Transfer learning allows someone without a large amount of data or computational capabilities to take advantage of the deep learning paradigm. It is an exciting research and application direction to use off-the-shelf pre-trained models and transfer them to novel domains.
Janu Verma is a Researcher at the IBM T.J. Watson Research Center, New York. His research interests are in mathematics, machine learning, information visualization, computational biology, and healthcare analytics. He has held research positions at Cornell University, Kansas State University, Tata Institute of Fundamental Research, Indian Institute of Science, and Indian Statistical Institute. He has written papers for IEEE Vis, KDD, International Conference on HealthCare Informatics, Computer Graphics and Applications, Nature Genetics, IEEE Sensors Journals, etc. His current focus is on the development of visual analytics systems for prediction and understanding. He advises startups and companies on data science and machine learning in the Delhi-NCR area; email him to schedule a meeting. Check out his personal website at http://jverma.github.io/.