Transfer learning in deep learning takes a model pre-trained on a source task and adapts it to a related target task. A common approach is fine-tuning: the pre-trained model is trained further on the target task, usually with a smaller learning rate, so that the updates do not catastrophically overwrite the knowledge learned on the source task.
Let’s denote:
- Ds: Source dataset
- Dt: Target dataset
- Ms: Source model (pre-trained on Ds)
- Mt: Target model (initialized from Ms and fine-tuned on Dt)
- θs: Parameters of Ms
- θt: Parameters of Mt
- Ls: Loss function on the source task
- Lt: Loss function on the target task
The transfer learning process typically involves the following steps:
- Initialization: Pre-train Ms on Ds to obtain θs, then initialize Mt with these weights (θt ← θs).
- Fine-tuning: Train Mt on Dt by updating its parameters θt to minimize the loss Lt.
- Evaluation: Measure the performance of Mt on the target task using a validation or test set held out from Dt.
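The three steps can be sketched end to end on a deliberately tiny problem. This is a minimal illustration under stated assumptions, not a deep-learning implementation: the "model" is a one-dimensional linear function w·x + b standing in for a network, the source and target tasks (y = 2x + 1 and y = 2x + 1.5) are hypothetical noise-free datasets, and the helper names `predict`, `mse` and `sgd`, as well as the learning rates and epoch counts, are illustrative choices.

```python
import random

def predict(theta, x):
    # Toy stand-in for a network: a 1-D linear function with theta = (w, b)
    w, b = theta
    return w * x + b

def mse(theta, data):
    # Mean squared error of the model over a list of (x, y) pairs
    return sum((predict(theta, x) - y) ** 2 for x, y in data) / len(data)

def sgd(theta, data, lr, epochs, seed=0):
    # Plain per-example SGD on the squared error (prediction - y)^2
    rng = random.Random(seed)
    w, b = theta
    data = list(data)  # copy so shuffling does not mutate the caller's list
    for _ in range(epochs):
        rng.shuffle(data)
        for x, y in data:
            err = (w * x + b) - y
            w -= lr * 2 * err * x  # d/dw of err^2 is 2 * err * x
            b -= lr * 2 * err      # d/db of err^2 is 2 * err
    return (w, b)

# Hypothetical source task with a large dataset D_s: y = 2x + 1
D_s = [(x, 2 * x + 1) for x in (-1 + 2 * i / 199 for i in range(200))]

# Hypothetical related target task with a small dataset D_t: y = 2x + 1.5
D_t = [(x, 2 * x + 1.5) for x in (-1 + 2 * i / 19 for i in range(20))]
D_t_train, D_t_test = D_t[0::2], D_t[1::2]  # held-out split for evaluation

# 1. Initialization: pre-train M_s on D_s from scratch to obtain theta_s
theta_s = sgd((0.0, 0.0), D_s, lr=0.05, epochs=50)

# 2. Fine-tuning: start theta_t at theta_s and train on D_t with a smaller lr
theta_t = sgd(theta_s, D_t_train, lr=0.01, epochs=50)

# 3. Evaluation: compare both parameter sets on the held-out target split
test_mse_source_only = mse(theta_s, D_t_test)
test_mse_finetuned = mse(theta_t, D_t_test)
print(f"theta_s = {theta_s}, theta_t = {theta_t}")
print(f"target test MSE: source-only {test_mse_source_only:.4f}, "
      f"fine-tuned {test_mse_finetuned:.6f}")
```

Because the two tasks are related, θs already has the right slope and fine-tuning mainly corrects the intercept. The fine-tuning learning rate (0.01) is deliberately five times smaller than the pre-training one (0.05), mirroring the practice of fine-tuning with a smaller learning rate.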
The mathematical model for fine-tuning can be represented as an optimization problem:
$$\min_{\theta_t} L_t(\theta_t) = \min_{\theta_t} \mathcal{L}\big(D_t, M_t(\theta_t)\big)$$
where L is the loss function, typically cross-entropy for classification tasks or mean squared error for regression tasks. The minimization is usually performed with stochastic gradient descent (SGD) or one of its variants (for example, SGD with momentum or Adam), using a smaller learning rate than when training from scratch.
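The two losses and the SGD update can be written out for a single example in plain Python. This is a hedged sketch using only the standard library; the function names are hypothetical, and deep-learning frameworks provide batched, differentiable versions of the same computations.

```python
import math

def cross_entropy(logits, label):
    # Softmax cross-entropy for one example: -log softmax(logits)[label],
    # computed with the log-sum-exp trick for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def mean_squared_error(preds, targets):
    # Average squared residual, the usual regression loss
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

def sgd_step(theta, grad, lr):
    # One SGD update: theta <- theta - lr * grad
    return [p - lr * g for p, g in zip(theta, grad)]

print(cross_entropy([2.0, 0.0], label=0))          # ≈ 0.1269, i.e. log(1 + e^-2)
print(mean_squared_error([1.0, 2.0], [1.0, 4.0]))  # 2.0
print(sgd_step([1.0, 2.0], [0.5, -1.0], lr=0.1))   # ≈ [0.95, 2.1]
```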
A smaller learning rate is used during fine-tuning because θt starts at θs, which already encodes useful knowledge from the source task. Small steps let the model adapt to the specifics of the target task while staying close to θs, which preserves that knowledge and reduces the risk of overfitting when Dt is small.
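One way to make this intuition concrete, assuming plain gradient descent with a fixed learning rate η and gradient norms bounded by G along the fine-tuning trajectory (both simplifying assumptions), is to unroll the updates starting from θs:

$$\theta_t^{(k+1)} = \theta_t^{(k)} - \eta \, \nabla L_t\big(\theta_t^{(k)}\big), \qquad \theta_t^{(0)} = \theta_s$$

$$\theta_t^{(K)} - \theta_s = -\eta \sum_{k=0}^{K-1} \nabla L_t\big(\theta_t^{(k)}\big)$$

$$\big\|\theta_t^{(K)} - \theta_s\big\| \le \eta \sum_{k=0}^{K-1} \big\|\nabla L_t\big(\theta_t^{(k)}\big)\big\| \le \eta K G$$

The bound on the distance between θt and θs after K steps grows linearly in η, so halving the learning rate halves it for the same number of steps. The same unrolling applies to SGD if each gradient is replaced by its minibatch estimate, with G bounding those estimates instead.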
Regularization terms can also be added to the loss function to prevent overfitting during fine-tuning, for example L1 or L2 regularization:
$$\min_{\theta_t} \big[ L_t(\theta_t) + \lambda R(\theta_t) \big]$$
where R(θt) is the regularization term penalizing large parameter values (the L1 norm ‖θt‖₁ for L1 regularization, or the squared L2 norm ‖θt‖₂² for L2 regularization) and λ ≥ 0 is the regularization strength.
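The regularized objective and the corresponding gradient step can be sketched for both penalties. This is a minimal plain-Python illustration; the helper names are hypothetical, the L1 case uses the subgradient sign(θ) (taking 0 at θ = 0), and the numbers in the usage lines are made up.

```python
def l2_penalty(theta):
    # R(theta) = ||theta||_2^2, the squared L2 norm
    return sum(p * p for p in theta)

def l2_grad(theta):
    # Gradient of ||theta||_2^2 is 2 * theta
    return [2 * p for p in theta]

def l1_penalty(theta):
    # R(theta) = ||theta||_1, the L1 norm
    return sum(abs(p) for p in theta)

def l1_subgrad(theta):
    # A subgradient of ||theta||_1: sign(p), using 0 at p == 0
    return [(p > 0) - (p < 0) for p in theta]

def regularized_objective(task_loss, theta, lam, penalty):
    # L_t(theta) + lambda * R(theta), given the already-computed task loss
    return task_loss + lam * penalty(theta)

def regularized_sgd_step(theta, task_grad, lr, lam, penalty_grad):
    # theta <- theta - lr * (grad L_t(theta) + lambda * grad R(theta))
    reg = penalty_grad(theta)
    return [p - lr * (g + lam * r) for p, g, r in zip(theta, task_grad, reg)]

# Made-up example: current weights, a task loss of 1.0, and a zero task gradient
theta = [3.0, -4.0]
print(regularized_objective(1.0, theta, lam=0.1, penalty=l2_penalty))  # ≈ 3.5
print(regularized_objective(1.0, theta, lam=0.1, penalty=l1_penalty))  # ≈ 1.7
print(regularized_sgd_step(theta, [0.0, 0.0], lr=0.1, lam=0.1,
                           penalty_grad=l2_grad))     # ≈ [2.94, -3.92]
print(regularized_sgd_step(theta, [0.0, 0.0], lr=0.1, lam=0.1,
                           penalty_grad=l1_subgrad))  # ≈ [2.99, -3.99]
```

With plain SGD, the L2 penalty multiplies each weight by (1 − 2ηλ) on top of the task-gradient step, which is why it is often called weight decay. The L1 penalty instead shrinks every nonzero weight by the same amount ηλ, which tends to favor sparse solutions.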
This mathematical model captures the essence of transfer learning in deep learning, where knowledge from a source task is utilized to improve learning on a related target task.