What is Transfer Learning?

Transfer learning is a machine learning technique in which a model trained on one task is reused for a different but related task. For example (from Wikipedia), knowledge gained while learning to recognize cars could apply when trying to recognize trucks. The pre-trained model is retrained on the new, similar problem, which can drastically cut training time: training a deep network from scratch can take days, whereas starting from a pre-trained model usually takes far less.

When does Transfer Learning make sense?

Suppose you want to learn from Task A and transfer some of that knowledge to Task B. Transfer learning makes sense when:

  • Task A and Task B have the same type of input x, e.g. images or audio.
  • You have a lot more data for Task A than for Task B, so each example from Task B is more valuable than an example from Task A.
  • Low-level features learned from Task A could be helpful for learning Task B.

Below is an example of image classification with a model pre-trained on the ImageNet dataset. The weights are downloaded automatically when the model is instantiated. The pre-trained model files are fetched from Amazon AWS and stored at ~/.keras/models/.

To predict with this model:

  1. Make sure Python, TensorFlow, Keras and their dependencies are installed on your machine.
  2. Download an elephant image from Google and name it elephant.jpg.
  3. Keep the image in the same project folder as your script (be careful with the path).
  4. Copy the ResNet50 code below.
  5. Run the code with your favorite editor.
  6. Look at the prediction, then experiment with different elephant images.
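Since step 3 warns about the path, an optional check can be run before the ResNet50 script. This is a minimal sketch using only the standard library. The helper name find_image is hypothetical, and it assumes the image sits in the current working directory unless another folder is given.

```python
from pathlib import Path


def find_image(name="elephant.jpg", folder="."):
    """Return the absolute path to the image, or raise a clear error."""
    path = Path(folder) / name
    if not path.is_file():
        raise FileNotFoundError(
            f"{name} not found in {Path(folder).resolve()}; "
            "put the image next to your script or pass the right folder"
        )
    return path.resolve()
```

The value returned by find_image() can be passed wherever the script expects img_path, which avoids surprises when the editor runs the script from a different working directory.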

Classify ImageNet classes with ResNet50

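from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

# Load ResNet50 with weights pre-trained on ImageNet (downloaded on first use)
model = ResNet50(weights='imagenet')

# Load the image and resize it to the 224x224 input size ResNet50 expects
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))

# Convert to an array, add a batch dimension, and apply ResNet50 preprocessing
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# Run the model on the batch of one image
preds = model.predict(x)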
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
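
The decoded result is a list of (class ID, label, probability) tuples, which is hard to read when printed raw. The following is a minimal sketch of a formatter and is not part of the Keras example. The helper name format_predictions is hypothetical, and the sample list simply reuses the values from the output shown above.

```python
def format_predictions(decoded):
    """Turn (class_id, label, probability) tuples into readable lines."""
    lines = []
    for class_id, label, prob in decoded:
        # Labels use underscores instead of spaces, e.g. 'Indian_elephant'
        lines.append(f"{label.replace('_', ' ')} ({class_id}): {prob:.1%}")
    return lines


# Sample values copied from the example output above
sample = [
    ('n02504013', 'Indian_elephant', 0.82658225),
    ('n01871265', 'tusker', 0.1122357),
    ('n02504458', 'African_elephant', 0.061040461),
]
for line in format_predictions(sample):
    print(line)
# Indian elephant (n02504013): 82.7%
# tusker (n01871265): 11.2%
# African elephant (n02504458): 6.1%
```

In the script itself, the same helper could be applied to decode_predictions(preds, top=3)[0].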

REFERENCES: The ResNet50 classification code above is taken from the Keras website, which has more information. Some of the content on this page is drawn from Andrew Ng's teaching.

