Transfer learning is a machine learning technique in which a model trained on one task is applied to a different but related problem. For example (from Wikipedia), knowledge gained while learning to recognize cars could apply when trying to recognize trucks. The pre-trained model is retrained on the similar problem, which drastically cuts down training time. Training a deep network from scratch can take days, whereas transfer learning can get the work done in far less time.
When does Transfer Learning make sense?
Suppose you want to learn from Task A and transfer some of that knowledge to Task B. Transfer learning makes sense when:
- Task A and Task B have the same kind of input x, e.g. images or audio.
- You have a lot more data for Task A than for Task B. Task B data is scarce, so each Task B example is worth more than a Task A example.
- Low-level features learned from Task A could be helpful for learning Task B.
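The conditions above can be illustrated with a toy sketch in NumPy. This is not a real pre-trained network: `W_frozen` is a hypothetical stand-in for layers learned on Task A, and the tiny Task B dataset is synthetic. The point is only to show the pattern of keeping the transferred features fixed and training a small new head on the scarce Task B data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for layers already learned on Task A.
# In practice these weights would come from a pre-trained model.
W_frozen = rng.normal(size=(4, 8))
W_before = W_frozen.copy()


def extract_features(x):
    """Frozen low-level layer: a fixed linear map followed by ReLU."""
    return np.maximum(x @ W_frozen, 0.0)


def log_loss(features, labels, w, b):
    """Mean binary cross-entropy of the head on the given features."""
    probs = 1.0 / (1.0 + np.exp(-(features @ w + b)))
    probs = np.clip(probs, 1e-12, 1 - 1e-12)
    return -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))


def train_head(features, labels, lr=0.05, epochs=300):
    """Fit only the new head (logistic regression) with gradient descent."""
    w = np.zeros(features.shape[1])
    b = 0.0
    for _ in range(epochs):
        probs = 1.0 / (1.0 + np.exp(-(features @ w + b)))
        error = probs - labels
        w -= lr * features.T @ error / len(labels)
        b -= lr * error.mean()
    return w, b


# Small synthetic Task B dataset: far fewer examples than Task A would have.
x_b = rng.normal(size=(20, 4))
y_b = (x_b[:, 0] > 0).astype(float)

features_b = extract_features(x_b)
w_head, b_head = train_head(features_b, y_b)

initial_loss = log_loss(features_b, y_b, np.zeros(8), 0.0)
final_loss = log_loss(features_b, y_b, w_head, b_head)
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

Only `w_head` and `b_head` are updated during training; `W_frozen` is never touched, which is what keeps training on Task B cheap.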
Below is an example of image classification using a model pre-trained on ImageNet. The weights are downloaded automatically when the model is instantiated. The trained model architecture is downloaded from Amazon AWS. They are stored at
To predict with this model:
- First, make sure Python, TensorFlow, Keras, and their dependencies are installed on your machine. Then,
- Download an elephant image from Google and name it elephant.jpg
- Keep the image in the same project folder as the code (be careful with the path)
- Copy the code below
- Run the code with your favorite editor
- See the prediction. Experiment with different elephant images.
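Since a wrong path is the most common stumbling block in the steps above, a small check before loading the image can help. The sketch below is only a suggestion: `find_image` is a hypothetical helper, not part of Keras, and it assumes the image sits in the folder you pass in (the current folder by default).

```python
from pathlib import Path


def find_image(name="elephant.jpg", folder="."):
    """Return the path to the image, or fail with a clear message."""
    path = Path(folder) / name
    if not path.is_file():
        raise FileNotFoundError(
            f"Could not find {name!r} in {Path(folder).resolve()}. "
            "Put the image in the project folder or fix the path."
        )
    return path
```

Calling `find_image()` before running the prediction code gives an explicit error message when the file is missing or misnamed.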
Classify ImageNet classes with ResNet50
    from keras.applications.resnet50 import ResNet50
    from keras.preprocessing import image
    from keras.applications.resnet50 import preprocess_input, decode_predictions
    import numpy as np

    # Load ResNet50 with weights pre-trained on ImageNet
    model = ResNet50(weights='imagenet')

    # Load the image and resize it to the 224x224 input size ResNet50 expects
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add a batch dimension
    x = preprocess_input(x)

    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch, so take the first one)
    print('Predicted:', decode_predictions(preds, top=3)[0])
    # Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
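The (class, description, probability) tuples can be hard to read at a glance. As a small optional sketch, the helper below turns them into readable lines. `format_predictions` is a hypothetical name, not a Keras function, and the sample values are copied from the example output above so the snippet runs without Keras installed.

```python
def format_predictions(predictions):
    """Turn (class_id, description, probability) tuples into readable lines."""
    lines = []
    for class_id, description, probability in predictions:
        label = description.replace('_', ' ')
        lines.append(f"{label}: {probability:.1%} ({class_id})")
    return lines


# Sample values taken from the example output in this post
sample = [
    ('n02504013', 'Indian_elephant', 0.82658225),
    ('n01871265', 'tusker', 0.1122357),
    ('n02504458', 'African_elephant', 0.061040461),
]

for line in format_predictions(sample):
    print(line)
# Indian elephant: 82.7% (n02504013)
# tusker: 11.2% (n01871265)
# African elephant: 6.1% (n02504458)
```

With a real prediction, you would pass in the list of tuples for a single image instead of `sample`.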