Implementation of Vision Transformer Paper (ViT)
"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", by Alexey Dosovitskiy et al., introduces the Vision Transformer (ViT), an approach to image classification built on the transformer architecture. The paper challenges the convolutional neural network (CNN) paradigm that has long dominated computer vision and extends the success of transformers from natural language processing to image analysis, reporting strong results on several image recognition tasks. This introduction gives an overview of the paper's key concepts and contributions before moving on to the implementation. The implementation follows the base architecture of the ViT paper but keeps the hyperparameters adjustable. It covers:
- Data Preparations
- Architecture Implementation
- Training
- Visualizing Results
The default hyperparameters correspond to the base architecture (ViT-Base): a hidden size of 768 and 12 transformer encoder layers.
1. Data Preparation
In the Vision Transformer (ViT) paper, the authors propose a novel approach to handling image data by transforming it into a sequence of fixed-size patches before feeding it into a transformer model. This is a departure from traditional convolutional neural networks (CNNs), where the input images are processed using convolutional layers. This is the first step that should be considered before implementing the architecture.
The specific image transformation steps performed before training the Vision Transformer include:
- Patch Extraction: The input image is divided into non-overlapping patches. Each patch is then treated as a token, forming a sequence of image patches.
- Linear Projection: Each patch is linearly projected into a high-dimensional space, typically referred to as the embedding space. This projection allows the model to capture meaningful representations of image content.
- Positional Encoding: To provide the transformer model with information about the spatial relationships between the patches, positional encodings are added. This helps the model understand the sequential order of the patches in the input sequence.
By converting the image into a sequence of patches and using a transformer architecture, the Vision Transformer effectively captures long-range dependencies and relationships within the image. This innovative approach has shown impressive results in image classification tasks, challenging the conventional wisdom that CNNs are the sole architecture suitable for computer vision applications.
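The patch extraction step can be illustrated with a small sketch in plain Python (no PyTorch), so that the indexing is visible. The function name `extract_patches` and the toy 4x4 image are assumptions made for illustration; they are not part of the paper or of the implementation below.

```python
def extract_patches(image, patch_size):
    """Split a 2D image (list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch size")
    patches = []
    # Walk over the patch grid row by row, left to right
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches

# Toy 4x4 single-channel image with pixel values 0..15
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = extract_patches(image, patch_size=2)
print(patches)  # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]

# For a 224x224 image with 16x16 patches (the ViT-Base setting):
num_patches = (224 // 16) ** 2      # 196 patches
sequence_length = num_patches + 1   # 197 once the class token is prepended
```

Each flattened patch plays the role of one token. In the full model every patch also spans all colour channels, so a 16x16 RGB patch flattens to 16 * 16 * 3 = 768 values before the linear projection.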
After importing the dataset and converting the images to tensors, the main part of the data preparation is patch extraction, followed by linear projection, then prepending a class token and adding the position embeddings. The following class transforms an image tensor into the input shape that the transformer expects.
# Patch Embedding Class
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Turns a batch of images into a sequence of patch embeddings (equation 1)."""
    def __init__(self,
                 img_size:int=224,
                 in_channels:int=3,
                 patch_size:int=16,
                 embedding_dim:int=768,
                 embedding_dropout:float=0.1):
        super().__init__()
        self.patch_size = patch_size
        # A convolution with kernel_size == stride == patch_size extracts and projects each patch in one step
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)
        # Flatten the patch grid (height, width) into a single sequence dimension
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)
        # Calculate number of patches (height * width / patch_size^2)
        self.num_patches = (img_size * img_size) // patch_size**2
        # Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)
        # Create learnable position embedding
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)
        # Create embedding dropout layer
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

    def forward(self, x):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, f"Input image size must be divisible by patch size, image shape: {image_resolution}, patch size: {self.patch_size}"
        class_token = self.class_embedding.expand(x.shape[0], -1, -1) # "-1" keeps the size of that dimension unchanged
        # (batch, embedding_dim, H/P, W/P) -> (batch, num_patches, embedding_dim)
        patches = self.flatten(self.patcher(x)).permute(0, 2, 1)
        return self.embedding_dropout(torch.cat((class_token, patches), dim=1) + self.position_embedding)
This class implements equation (1) of the paper. It is arguably the most important part of the architecture, although the other three equations matter as well.
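For reference, equation (1) of the paper can be written in LaTeX as follows, where N is the number of patches, P the patch size, C the number of channels and D the embedding dimension. Reading it against the class above, E corresponds to the patcher convolution, x_class to class_embedding and E_pos to position_embedding (this mapping is an interpretation of the code, not wording from the paper).

```latex
\mathbf{z}_0 = [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \cdots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos},
\qquad \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\quad \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}
```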
2. Architecture Implementation
This part is divided into two main blocks:
2.1 Multi-head Self-Attention Block
In the Vision Transformer (ViT) paper, the concept of a "Multi-head Self Attention Block" is a fundamental building block within the transformer architecture used for image recognition tasks. The ViT model employs multiple such attention blocks to capture rich contextual information from input sequences, in this case, sequences of image patches. Let's break down the key components of the Multi-head Self Attention Block based on the ViT paper:
- Self Attention Mechanism:
  - The core of the attention block is the self-attention mechanism, which allows the model to weigh different elements of the input sequence differently based on their relevance to each other.
  - Self attention calculates attention scores for each element in the sequence relative to every other element, enabling the model to focus on different parts of the input sequence during processing.
- Multi-head Attention:
  - ViT uses multiple attention heads within a single attention block, as in the original Transformer. Each attention head learns a different set of attention weights, capturing diverse relationships in the input.
  - The outputs from these multiple attention heads are concatenated and linearly projected to create a final set of representations.
- Parameterized Linear Projections:
  - The attention block includes linear projections to transform the input sequence into query, key, and value representations. These projections are learned during the training process.
  - A separate set of projections is used for each attention head, allowing the model to capture various aspects of the input information.
- Normalization and Feedforward Layer:
  - Layer normalization is applied to the input of the attention block (the paper applies LN before every block), and a residual connection is added after the block.
  - A feedforward neural network (the MLP block of section 2.2) is then applied to further process the information, introducing non-linearity and capturing complex patterns.
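The self-attention mechanism described above can be sketched for a single head in plain Python. This is a simplified illustration under stated assumptions: the learned query, key and value projections are left out (the tokens are used directly), and the function names and toy tokens are invented for this example.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(queries, keys, values):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V, computed row by row."""
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Weighted average of the value vectors
        output = [sum(w * v[j] for w, v in zip(weights, values))
                  for j in range(len(values[0]))]
        outputs.append(output)
        all_weights.append(weights)
    return outputs, all_weights

# Three toy "patch tokens" of dimension 2; self-attention uses them as Q, K and V
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = scaled_dot_product_attention(tokens, tokens, tokens)
```

Every row of `weights` sums to 1, so each output token is a weighted average of all value vectors. A multi-head block runs several such computations in parallel on differently projected inputs and concatenates the results.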
class MultiHeadSelfAttentionBlock(nn.Module):
    """
    Create a multi-head self attention block (MSA - Block)
    """
    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 attn_dropout:float=0.0):
        super().__init__()
        # Create a layer norm (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        # Create a multihead attention (MSA) layer
        self.multihead_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                         num_heads=num_heads,
                                                         dropout=attn_dropout,
                                                         batch_first=True) # batch dimension first: (batch, seq, feature) -> (batch, number_of_patches, embedding_dimension)

    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attention(query=x,
                                                  key=x,
                                                  value=x,
                                                  need_weights=False)
        # The residual connection is added in TransformerEncoderBlock, not here
        return attn_output
2.2 Multi Layer Perceptron
In the Vision Transformer (ViT) paper, the "MLP (Multi-Layer Perceptron) Block" is another crucial component of the transformer architecture. The MLP Block is employed after the multi-head self-attention mechanism in each transformer layer, contributing to the model's ability to capture and process hierarchical features from the input sequence of image patches. Let's delve into the key characteristics of the MLP Block based on the ViT paper:
- Position-wise Feedforward Networks:
  - The MLP Block is a position-wise feedforward network. It operates independently on each position in the sequence, allowing the model to capture local patterns and non-linear relationships within the data.
  - Position-wise feedforward networks usually consist of fully connected layers with a non-linear activation function (commonly ReLU) in between. In this architecture, according to the paper, the MLP contains two layers with a GELU non-linearity (section 3.1).
- Parameterized Linear Projections:
  - Similar to the multi-head self-attention block, the MLP Block includes linear projections to transform the input features. These projections are learned during the training process.
  - The linear projections are typically followed by activation functions, introducing non-linearity to the model.
- Normalization:
  - Layer normalization is applied to the input of the MLP Block (the paper applies LN before every block). This helps stabilize the training process and improve the model's generalization.
- Skip Connection (Residual Connection):
  - To facilitate the flow of information through the model, skip connections (also known as residual connections) are employed. These connections add the input of the MLP Block to its output, aiding the gradient flow during training.
  - In general, a residual connection adds the input of a block to the output of that block, so that information and gradients can bypass the block.
In summary, the Multi-head Self Attention Block in the ViT paper incorporates multiple attention heads to enable the model to attend to different aspects of the input image patches simultaneously. This parallel processing enhances the model's ability to capture diverse and hierarchical features in the data, contributing to its success in image recognition tasks. The role of the MLP Block is to capture and process local features within the input sequence. While the multi-head self-attention mechanism focuses on capturing global contextual information, the MLP Block helps the model capture and leverage detailed local patterns. The combination of these components contributes to the success of the Vision Transformer in image recognition tasks.
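The "position-wise" behaviour of the MLP Block can be made concrete with a small plain-Python sketch. The helper names and the toy weights are assumptions for illustration only. The GELU below is the exact erf-based form, which is also the default of `nn.GELU()`.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(vector, weight, bias):
    """Dense layer: one output per weight row."""
    return [sum(w * v for w, v in zip(row, vector)) + b
            for row, b in zip(weight, bias)]

def mlp_block(tokens, w1, b1, w2, b2):
    """Apply the same two-layer MLP to every token independently."""
    outputs = []
    for token in tokens:
        hidden = [gelu(h) for h in linear(token, w1, b1)]  # expand + non-linearity
        outputs.append(linear(hidden, w2, b2))             # project back
    return outputs

# Toy sizes: embedding dimension 2, hidden ("mlp_size") dimension 3
w1 = [[0.5, -0.25], [1.0, 0.0], [0.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
b2 = [0.0, 0.0]

tokens = [[1.0, 2.0], [1.0, 2.0], [0.0, 0.0]]
outputs = mlp_block(tokens, w1, b1, w2, b2)
```

Because the weights are shared and no token looks at any other token, the two identical input tokens produce identical outputs. Mixing information between positions is left entirely to the attention block.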
class MultiLayerPerceptron(nn.Module):
    """
    Creates a layer normalized multilayer perceptron block ("MLP block" for short).
    """
    # Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
        super().__init__()
        # Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        # Create the Multilayer perceptron (MLP) layer(s)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
                      out_features=embedding_dim), # take back to embedding_dim
            nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
        )

    # Create a forward() method to pass the data through the layers
    def forward(self, x):
        # Equivalent to: x = self.layer_norm(x); x = self.mlp(x)
        return self.mlp(self.layer_norm(x))
Transformer Encoder Block
The transformer encoder block consists of the Multi-head Self-Attention (MSA) block and the Multi-Layer Perceptron (MLP) block combined into one block.
class TransformerEncoderBlock(nn.Module):
    """
    Creates a Transformer Encoder block.
    """
    # Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
                 attn_dropout:float=0.0): # Amount of dropout for attention layers
        super().__init__()
        # Create MSA block (equation 2)
        self.msa_block = MultiHeadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)
        # Create MLP block (equation 3)
        self.mlp_block = MultiLayerPerceptron(embedding_dim=embedding_dim,
                                              mlp_size=mlp_size,
                                              dropout=mlp_dropout)

    # Create a forward() method
    def forward(self, x):
        # Create residual connection for MSA block (add the input to the output)
        y = self.msa_block(x) + x
        # Create residual connection for MLP block (the MLP block sees only y, and y is added back)
        z = self.mlp_block(y) + y
        return z
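For reference, the two equations of the paper that an encoder block implements, written in LaTeX. Here z_{l-1} is the input of layer l, z'_l is the intermediate result after the attention residual, and L is the number of encoder layers.

```latex
\mathbf{z}'_\ell = \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, \qquad \ell = 1 \ldots L \quad (2)

\mathbf{z}_\ell = \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, \qquad \ell = 1 \ldots L \quad (3)
```

Note that in equation (3) the MLP receives only the intermediate result z'_l, and the same z'_l is added back as the residual.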
Vision Transformer
After the first three parts of the architecture, it is time to put everything together, ending with the classifier head that maps the image representation to its label. The following class VisionTransformer is everything put together.
class VisionTransformer(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    # Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=224, # Training resolution from Table 3 in ViT paper
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0.0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=1000): # Default for ImageNet but can customize this
        super().__init__() # don't forget the super().__init__()!
        # Make sure the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
        # Create patch embedding layer
        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim,
                                              embedding_dropout=embedding_dropout)
        # Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: the "*" unpacks the list into separate positional arguments
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                           num_heads=num_heads,
                                                                           mlp_size=mlp_size,
                                                                           mlp_dropout=mlp_dropout,
                                                                           attn_dropout=attn_dropout) for _ in range(num_transformer_layers)])
        # Create classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    # Create a forward() method
    def forward(self, x):
        # Create patch embedding (equation 1)
        x = self.patch_embedding(x)
        # Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
        x = self.transformer_encoder(x)
        # Put the class token (index 0 of the sequence) through the classifier (equation 4)
        x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
        return x
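A quick way to sanity-check the assembled model is to count its parameters by hand and compare the total with the roughly 86M reported for ViT-Base in Table 1 of the paper. The sketch below is plain Python and makes a few assumptions, labeled in the comments: every dense and convolution layer has a bias, and attention uses one combined query/key/value projection plus one output projection, which is how `nn.MultiheadAttention` stores its weights by default. The function name is hypothetical.

```python
def vit_parameter_count(img_size=224, patch_size=16, in_channels=3,
                        embedding_dim=768, mlp_size=3072,
                        num_layers=12, num_classes=1000):
    """Count learnable parameters of the architecture described above."""
    num_patches = (img_size // patch_size) ** 2
    # Patch projection (a conv with kernel == stride == patch_size), with bias
    patch_projection = in_channels * patch_size ** 2 * embedding_dim + embedding_dim
    class_token = embedding_dim
    position_embedding = (num_patches + 1) * embedding_dim
    # LayerNorm has a weight and a bias per feature
    layer_norm = 2 * embedding_dim
    # Attention: combined q/k/v projection plus output projection, both with bias (assumption)
    attention = (3 * embedding_dim * embedding_dim + 3 * embedding_dim
                 + embedding_dim * embedding_dim + embedding_dim)
    # MLP: two dense layers with bias
    mlp = (embedding_dim * mlp_size + mlp_size
           + mlp_size * embedding_dim + embedding_dim)
    encoder_block = 2 * layer_norm + attention + mlp
    classifier = layer_norm + embedding_dim * num_classes + num_classes
    return (patch_projection + class_token + position_embedding
            + num_layers * encoder_block + classifier)

total = vit_parameter_count()
print(f"{total:,}")  # 86,567,656
```

The number of heads does not appear in the count, because the heads only split the embedding dimension between them. With PyTorch available, the same total should be obtainable from `sum(p.numel() for p in model.parameters())` on an instance of the class above.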
3. Training Vision Transformer
After implementing the architecture, it's time to train the model and look at the results. Keep in mind that with a small dataset the test accuracy will be low, because the Vision Transformer performs well mainly when it is trained on large datasets.
import torch
from modular_functions import engine
from modular_functions.utils import set_seeds

# Note: `vit`, `device`, `train_dataloader` and `test_dataloader` are assumed to be defined earlier

# Setup the optimizer to optimize our ViT model parameters using hyperparameters from the ViT paper
optimizer = torch.optim.Adam(params=vit.parameters(),
                             lr=3e-3, # Base LR from Table 3 for ViT-* ImageNet-1k
                             betas=(0.9, 0.999), # default values but also mentioned in ViT paper section 4.1 (Training & Fine-tuning)
                             weight_decay=0.3) # from the ViT paper section 4.1 (Training & Fine-tuning) and Table 3 for ViT-* ImageNet-1k
# Setup the loss function for multi-class classification
loss_fn = torch.nn.CrossEntropyLoss()
# Set the seeds
set_seeds(s=42, set_device='mps')
# Train the model and save the training results to a dictionary
results = engine.train(model=vit.to(device),
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=10,
                       device=device)
Epoch: 1 | train_loss: 4.0164 | train_acc: 0.2891 | test_loss: 3.3471 | test_acc: 0.2604
Epoch: 2 | train_loss: 1.8290 | train_acc: 0.4258 | test_loss: 1.7631 | test_acc: 0.2604
Epoch: 3 | train_loss: 1.1750 | train_acc: 0.4609 | test_loss: 1.4370 | test_acc: 0.2604
Epoch: 4 | train_loss: 1.2536 | train_acc: 0.2773 | test_loss: 1.4224 | test_acc: 0.1979
Epoch: 5 | train_loss: 1.1403 | train_acc: 0.4102 | test_loss: 1.2852 | test_acc: 0.2604
Epoch: 6 | train_loss: 1.1726 | train_acc: 0.2891 | test_loss: 1.3891 | test_acc: 0.1979
Epoch: 7 | train_loss: 1.2683 | train_acc: 0.2891 | test_loss: 1.3073 | test_acc: 0.1979
Epoch: 8 | train_loss: 1.1236 | train_acc: 0.2656 | test_loss: 1.2403 | test_acc: 0.2604
Epoch: 9 | train_loss: 1.1619 | train_acc: 0.3047 | test_loss: 1.0208 | test_acc: 0.5417
Epoch: 10 | train_loss: 1.2532 | train_acc: 0.2539 | test_loss: 1.4964 | test_acc: 0.2604
The test accuracy can be improved with transfer learning, that is, by starting from the same architecture pretrained on a larger dataset and fine-tuning it on this one. In the following we will discuss ways to improve the model.