Transfer Learning on the Vision Transformer Architecture
The results of training the model from scratch were poor, with a test accuracy of only 26%, because we had a small image dataset.
Transfer learning has proven to be a powerful technique in enhancing the performance of neural network models,
including the Vision Transformer (ViT). By leveraging models pre-trained on large datasets, transfer learning allows the ViT to benefit from knowledge gained in solving one task and apply it to a related task, even with limited labeled data. So, in order to improve the test accuracy of our model, we will use ViT_B_16_Weights,
since we are working with a 16 × 16 patch size.
1. Import Pretrained Weights
import torch
import torchvision

# Pick the best available device ("mps" on Apple silicon, otherwise "cuda" or CPU)
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

# Get pretrained weights for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT # requires torchvision >= 0.13, "DEFAULT" means best available

# Setup a ViT model instance with pretrained weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
pretrained_vit
After importing the pretrained weights, we need to replace the classifier layer (in other words, the last layer) to match the problem we are trying to solve. In our case we are classifying three classes, so we set the head up accordingly.
from torch import nn

# Freeze the base parameters
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Replace the classifier head with a new linear layer for our number of classes
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
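Before running the full summary, a quick sanity check (a minimal sketch, not part of the original code) confirms that only the new head is trainable:
# Sketch: count trainable vs. total parameters to confirm the freeze worked
trainable_params = sum(p.numel() for p in pretrained_vit.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in pretrained_vit.parameters())
print(f"Trainable: {trainable_params:,} / {total_params:,}")
# Expected: Trainable: 2,307 / 85,800,963 (matching the summary below)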
We can use torchinfo's summary() to visualize the architecture and confirm that only the head is trainable.
from torchinfo import summary

# Visualize the pretrained Vision Transformer architecture
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])
============================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
============================================================================================================================================
VisionTransformer (VisionTransformer) [32, 3, 224, 224] [32, 3] 768 Partial
├─Conv2d (conv_proj) [32, 3, 224, 224] [32, 768, 14, 14] (590,592) False
├─Encoder (encoder) [32, 197, 768] [32, 197, 768] 151,296 False
│ └─Dropout (dropout) [32, 197, 768] [32, 197, 768] -- --
│ └─Sequential (layers) [32, 197, 768] [32, 197, 768] -- False
│ │ └─EncoderBlock (encoder_layer_0) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_1) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_2) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_3) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_4) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_5) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_6) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_7) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_8) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_9) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_10) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_11) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ └─LayerNorm (ln) [32, 197, 768] [32, 197, 768] (1,536) False
├─Linear (heads) [32, 768] [32, 3] 2,307 True
============================================================================================================================================
Total params: 85,800,963
Trainable params: 2,307
Non-trainable params: 85,798,656
Total mult-adds (Units.GIGABYTES): 5.52
============================================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3330.74
Params size (MB): 229.20
Estimated Total Size (MB): 3579.21
============================================================================================================================================
2. Training the Pretrained Model
After preparing the pretrained model, it is time to see how it performs on our problem.
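The training call below uses train_dataloader_pretrained, test_dataloader_pretrained and the class_names referenced earlier, which are created outside this section. Here is a minimal sketch of how they might be built, assuming the images live under hypothetical data/train and data/test folders and using the preprocessing the pretrained weights expect:
from torch.utils.data import DataLoader
from torchvision import datasets

# Use the same preprocessing (resize, crop, normalize) the pretrained weights were trained with
pretrained_transforms = pretrained_vit_weights.transforms()

# "data/train" and "data/test" are placeholder paths for our three-class image folders
train_dataset = datasets.ImageFolder(root="data/train", transform=pretrained_transforms)
test_dataset = datasets.ImageFolder(root="data/test", transform=pretrained_transforms)
class_names = train_dataset.classes

train_dataloader_pretrained = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader_pretrained = DataLoader(test_dataset, batch_size=32, shuffle=False)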
from modular_functions.utils import set_seeds
from modular_functions import engine # assuming engine lives alongside utils in modular_functions

# Create optimizer and loss function according to the paper
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

set_seeds(42, "mps") # <- We trained on Apple silicon ("mps"); change this to match your device (e.g. "cuda")

# Train the classifier head of the pretrained ViT feature extractor model
pretrained_vit_results = engine.train(model=pretrained_vit.to(device),
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=10,
                                      device=device)
After training, the test accuracy jumps to over 90%, a big improvement over the 26% we got when training from scratch.
Epoch: 1 | train_loss: 0.0598 | train_acc: 0.9922 | test_loss: 0.2011 | test_acc: 0.9384
Epoch: 2 | train_loss: 0.0406 | train_acc: 0.9922 | test_loss: 0.1897 | test_acc: 0.9280
Epoch: 3 | train_loss: 0.0302 | train_acc: 1.0000 | test_loss: 0.1914 | test_acc: 0.9280
Epoch: 4 | train_loss: 0.0229 | train_acc: 1.0000 | test_loss: 0.1901 | test_acc: 0.9280
Epoch: 5 | train_loss: 0.0181 | train_acc: 1.0000 | test_loss: 0.1895 | test_acc: 0.9280
Epoch: 6 | train_loss: 0.0174 | train_acc: 1.0000 | test_loss: 0.1907 | test_acc: 0.9280
Epoch: 7 | train_loss: 0.0130 | train_acc: 1.0000 | test_loss: 0.2071 | test_acc: 0.9384
Epoch: 8 | train_loss: 0.0109 | train_acc: 1.0000 | test_loss: 0.2050 | test_acc: 0.9384
Epoch: 9 | train_loss: 0.0127 | train_acc: 1.0000 | test_loss: 0.1992 | test_acc: 0.9384
Epoch: 10 | train_loss: 0.0086 | train_acc: 1.0000 | test_loss: 0.1945 | test_acc: 0.9176
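To see the trend more clearly, we can plot the curves. This is a minimal sketch, assuming pretrained_vit_results is a dictionary of per-epoch lists keyed by "train_loss", "train_acc", "test_loss" and "test_acc" (matching the printed log above):
import matplotlib.pyplot as plt

# Plot loss and accuracy curves side by side from the training results dictionary
epochs = range(1, len(pretrained_vit_results["train_loss"]) + 1)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs, pretrained_vit_results["train_loss"], label="train_loss")
plt.plot(epochs, pretrained_vit_results["test_loss"], label="test_loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, pretrained_vit_results["train_acc"], label="train_acc")
plt.plot(epochs, pretrained_vit_results["test_acc"], label="test_acc")
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend()
plt.show()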
There are other ways to improve the model from here, but I will not be covering them in this post. I hope this helps you understand how paper replicating works and why it matters on a machine learning engineer's journey.