Problem Statement
The goal of this project is to classify images according to their content. The training data is organized by labels, with each directory containing multiple images belonging to its respective class.
A snippet of the structure of /data looks like this:
.
├── cat
│   ├── 0.jpg
│   ├── 1.jpg
│   ├── 2.jpg
│   └── 3.jpg
└── dog
    ├── 0.jpg
    ├── 1.jpg
    ├── 2.jpg
    └── 3.jpg
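A directory layout like this can be turned into (image path, integer label) pairs with a few lines of standard-library Python. The sketch below is only an illustration: the function name `index_image_folder` and the accepted file extensions are assumptions, not part of the original project.

```python
from pathlib import Path

# Assumed set of image suffixes; adjust to whatever the dataset actually contains
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def index_image_folder(root):
    """Return (samples, class_to_idx) for a root/<class>/<image> layout."""
    root = Path(root)
    # Sort class directories so that label ids are stable across runs
    class_names = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    samples = []
    for name in class_names:
        for path in sorted((root / name).iterdir()):
            if path.suffix.lower() in IMAGE_EXTENSIONS:
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx
```

With the layout above, `class_to_idx` would map `cat` to 0 and `dog` to 1, and each entry of `samples` pairs an image path with its class id.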
Methodology ☀️
To tackle this image classification, we can leverage contrastive learning.
Contrastive learning is a technique in which a model learns to tell whether images are similar. Images of the same class form positive pairs, while images of different classes form negative pairs. In the self-supervised setting, where no labels are available, positive pairs are instead created by applying different data augmentations (e.g., rotations, crops, color distortions) to the same image.
In contrastive learning, the model contrasts positive and negative pairs, where similar images are pulled closer together and dissimilar images are pushed further apart from one another in the embedding space.
The embedding space is a learned representation space where images are mapped to vectors. The contrastive loss function, such as the InfoNCE loss, aims to minimize the distance between positive pairs and maximize the distance between negative pairs.
Contrastive learning is particularly effective in self-supervised learning, where no labels are available; however, it also applies to supervised learning, which is our setting because we have the labels of the images.
Figure: an impression of contrastive learning, generated with imageGen
Contrastive Loss Functions Comparison
There are several variants of the contrastive loss. Each defines positives and negatives differently and therefore shapes the embedding space in a different way.
Loss Function | Positives (Closeness) | Negatives (Distance) | Use Cases |
---|---|---|---|
InfoNCE | 1 Augmented View | All Other Images (Batch) | Self-Supervised Learning |
SupCon | All Images of Same Class | All Other Classes | Supervised Learning |
Triplet Loss | 1 Same Class Image | 1 Different Class Image | Metric Learning |
Data Preprocessing ⏳
For a machine to "see" the images, they must first be converted into numerical representations called embeddings. To illustrate the concept, we will initially represent each image as a simple 2-dimensional vector. In practice, however, models like CLIP generate high-dimensional embeddings (e.g., 512 dimensions) that capture richer information.
To illustrate with a simple example, we have two classes, cat and dog.
After converting the images and normalizing the resulting embedding vectors, the data looks like this:
Class | Image | Raw Embedding (x, y) | Normalized Embedding (x, y) |
---|---|---|---|
cat | 0.jpg | [1.2, 0.9] | [0.800, 0.600] |
cat | 1.jpg | [0.8, 0.3] | [0.936, 0.351] |
cat | 2.jpg | [1.0, 1.0] | [0.707, 0.707] |
cat | 3.jpg | [1.7, 1.1] | [0.840, 0.543] |
dog | 0.jpg | [-1.0, 1.5] | [-0.555, 0.832] |
dog | 1.jpg | [-0.7, 0.7] | [-0.707, 0.707] |
dog | 2.jpg | [-0.5, 0.2] | [-0.928, 0.371] |
dog | 3.jpg | [-1.3, 0.9] | [-0.822, 0.569] |
Normalization is a crucial step: it bounds every dot product to the range [-1, 1], which helps training stability by reducing the likelihood of numerical overflow or underflow when the loss applies the exponential function.
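As a minimal sketch of this step, the snippet below L2-normalizes four of the raw 2-D embeddings from the table with `torch.nn.functional.normalize`. The variable names `raw` and `features` are chosen for illustration only.

```python
import torch
import torch.nn.functional as F

# Raw 2-D embeddings from the table: cat/0.jpg, cat/1.jpg, dog/0.jpg, dog/1.jpg
raw = torch.tensor([[1.2, 0.9],
                    [0.8, 0.3],
                    [-1.0, 1.5],
                    [-0.7, 0.7]])

# Divide each row by its L2 norm so that every embedding has unit length
features = F.normalize(raw, p=2, dim=1)

# Every row norm is now 1.0 (up to floating-point error)
norms = features.norm(dim=1)

print(features)
print(norms)
```

Because every vector has unit length, the dot product of any two rows is their cosine similarity, and the scaled logits can never exceed 1/τ in magnitude.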
Similarity Metrics 🍃
Next, we calculate the similarity between all pairs of embeddings using the cosine similarity metric.
Assuming our normalized embeddings are stored in a tensor features of shape (N, D) (where N is the number of images and D is the embedding dimension), we can compute the pairwise cosine similarity matrix using PyTorch:
similarity_matrix = torch.matmul(features, features.T)
or the equivalent similarity_matrix = features @ features.T.
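To check that the matrix product really is the cosine similarity, the following sketch compares it with `torch.nn.functional.cosine_similarity` computed on the raw, unnormalized vectors. The example embeddings come from the table; the temperature value 0.7 matches the one used later in this post, and the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

raw = torch.tensor([[1.2, 0.9],
                    [0.8, 0.3],
                    [-1.0, 1.5],
                    [-0.7, 0.7]])
features = F.normalize(raw, p=2, dim=1)

# Pairwise dot products of unit vectors: shape (N, N)
similarity_matrix = features @ features.T

# Reference: cosine similarity of the raw vectors, broadcast to all pairs
reference = F.cosine_similarity(raw.unsqueeze(1), raw.unsqueeze(0), dim=-1)

# Temperature scaling is a plain element-wise division
temperature = 0.7
scaled = similarity_matrix / temperature

print(torch.allclose(similarity_matrix, reference, atol=1e-6))
```

The matrix is symmetric with ones on the diagonal, so only the off-diagonal entries carry information about how pairs relate.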
Subsequently, a scaling hyperparameter called temperature is often applied to this similarity matrix, typically by dividing the matrix elements by it.
For images cat/0.jpg, cat/1.jpg, dog/0.jpg, dog/1.jpg, the corresponding labels would be:
labels = torch.tensor([0, 0, 1, 1])
where class 0 is cat and class 1 is dog.
The temperature hyperparameter (τ) adjusts the model's sensitivity to the similarity scores. By scaling the similarity scores, it modifies how strongly the model differentiates between pairs of embeddings based on their similarity.
When a smaller value of τ (nearer to 0) is used, the loss function is more sensitive to small differences in similarity scores. Dividing by a small τ widens the gaps between scores, and the exponential exaggerates these gaps further, leading to a sharper softmax distribution inside the logarithmic term.
As a result, the model is penalised more strongly for even slightly incorrect similarities. In other words, it focuses more on fine-grained distinctions in the training data.
Lower Temperature:
- Sharper similarity distribution: The model becomes more confident about its assignments based on similarity scores
- Amplifies differences between similarity scores: Small differences in high similarity scores lead to larger differences in the resulting probabilities (after softmax)
- Increases penalty for hard negatives: The model focuses more strongly on separating embeddings that are not in the same class but have relatively high similarity
A higher temperature "melts" the probability distribution: differences between similarity scores matter less, so the probability mass is spread more evenly across the candidate pairs.
Deciding Factor:
- Low temperature (e.g., 0.1): Use when you want the model to learn strong separation between classes and create tightly clustered embeddings for items within the same class. Emphasizes discriminating between positives and hard negatives. Can sometimes lead to instability if too low.
- High temperature (e.g., 1.0): Use when you want a softer probability distribution. This treats negative pairs more uniformly, potentially preventing the model from focusing too narrowly on only the hardest negatives early in training. It allows for smoother gradients from a wider range of pairs. Can be useful if the task involves subtle intra-class variations or if low temperature causes training instability.
Looking back at the similarity scores derived from our 2D example embeddings, we observe high cosine similarity values for embeddings within the same class (cat-cat, dog-dog), and lower scores for cross-class pairs (cat-dog). This indicates a good initial separation between the classes in the embedding space, even before fine-tuning.
The diagonal elements of the computed similarity matrix correspond to the similarity of each embedding with itself. Since the embeddings are normalized, these values should theoretically be exactly 1.0 (cosine_similarity(v, v) = 1 for ||v||=1). In practice, floating-point precision limits mean the computed diagonal values may be extremely close to, but not exactly, 1.0. In the matrix below, the slightly larger deviations (e.g., 0.9993 and 1.0002) arise because the embeddings were rounded to three decimals before the dot products were taken.
After applying the temperature to the similarity matrix, the matrices look like this,
Before Temperature:
tensor([[ 1.0000, 0.9594, 0.0552, -0.1414],
[ 0.9594, 0.9993, -0.2274, -0.4136],
[ 0.0552, -0.2274, 1.0002, 0.9806],
[-0.1414, -0.4136, 0.9806, 0.9997]])
After Temperature:
tensor([[ 1.4286, 1.3706, 0.0789, -0.2020],
[ 1.3706, 1.4276, -0.3249, -0.5908],
[ 0.0789, -0.3249, 1.4289, 1.4009],
[-0.2020, -0.5908, 1.4009, 1.4281]])
We can see that applying a temperature below one (e.g., 0.7) exaggerates the original similarity scores, moving the values further away from 0. The amplification is even more pronounced after the exponential in the softmax: exp(1.4286) is significantly larger relative to exp(0.0789) or exp(-0.2020) than exp(1.0000) is relative to exp(0.0552) or exp(-0.1414).
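The sharpening effect can be made concrete with a small standard-library sketch. It applies a softmax at several temperatures to the off-diagonal similarities of the first anchor (cat/0.jpg) taken from the matrix above. The helper name `softmax_with_temperature` is illustrative only.

```python
import math


def softmax_with_temperature(scores, temperature):
    """Softmax over scores / temperature, shifted by the max for stability."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Similarities of cat/0.jpg to cat/1.jpg, dog/0.jpg and dog/1.jpg
row = [0.9594, 0.0552, -0.1414]

for tau in (1.0, 0.7, 0.1):
    probs = softmax_with_temperature(row, tau)
    print(f"tau={tau}: {[round(p, 4) for p in probs]}")
```

As τ shrinks, the probability assigned to the most similar candidate (the other cat image) grows towards 1, while the two dog images receive almost no mass.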
The Loss Formula 🌾
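Here we are using the supervised contrastive (SupCon) loss:

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \neq i} \exp(z_i \cdot z_a / \tau)}
$$

Where $z_i$ is the normalized embedding of anchor $i$, $P(i)$ is the set of positives for $i$ in the batch, $N$ is the batch size, and $\tau$ is the temperature.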
The numerator:
- Focuses only on the similarity between anchor i and one specific positive
- Measures the similarity between the anchor sample (z_i) and one specific positive sample (z_p) that belongs to the same class
The denominator:
- Considers the anchor i and all other samples a in the batch
- Represents the sum of similarities between the anchor (z_i) and all other samples (z_a) in the batch (both positives and negatives)
The fraction:
- Like a softmax function, it calculates a normalised score (a probability-like value) representing how well the positive sample p is the correct match for anchor sample i, given the similarities of i to all available candidates a
- The loss function encourages the model to make this fraction as large as possible for every positive p. It can only approach 1 when the anchor has a single positive, because the fractions for all positives of one anchor share the same denominator and sum to at most 1
Thinking further about the loss.. 💭
That means similar images should have a higher numerator, with a cosine similarity closer to 1. The fraction is therefore larger for a positive pair and, together with the negative logarithm, the loss becomes a smaller value closer to 0.
However, if the scaled similarities are large positive numbers, exponentiating them can exceed the limits of standard floating-point numbers, resulting in numerical overflow and breaking the calculation.
Therefore, we can apply a constant shift to this softmax-like calculation to obtain a more stable version before the exponentiation.
logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
logits = similarity_matrix - logits_max.detach()
## OUTPUT OF LOGITS (computed here from the unscaled similarity matrix, i.e. before dividing by the temperature)
tensor([[ 0.0000, -0.0406, -0.9448, -1.1414],
[-0.0399, 0.0000, -1.2267, -1.4129],
[-0.9450, -1.2277, 0.0000, -0.0196],
[-1.1411, -1.4133, -0.0191, 0.0000]])
As a result, the largest value in each row of the logits should be 0 and all other values should be less than or equal to 0.
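The effect of the shift can be demonstrated with a standard-library sketch. Subtracting the row maximum does not change the softmax result, but it keeps every argument of `exp` at or below 0. The function names and the exaggerated `large` logits are made up for illustration. In pure Python an overflow raises `OverflowError`, whereas PyTorch would produce `inf` or `nan` instead.

```python
import math


def naive_log_softmax(logits):
    exps = [math.exp(x) for x in logits]  # overflows for large x
    total = sum(exps)
    return [math.log(e / total) for e in exps]


def stable_log_softmax(logits):
    peak = max(logits)
    shifted = [x - peak for x in logits]  # the largest entry becomes 0
    log_total = math.log(sum(math.exp(x) for x in shifted))
    return [x - log_total for x in shifted]


small = [1.4286, 1.3706, 0.0789, -0.2020]  # first row of the scaled matrix
large = [x * 1000 for x in small]          # exaggerated, for illustration only

print(stable_log_softmax(small))
print(stable_log_softmax(large))  # fine: exp() only sees values <= 0

try:
    naive_log_softmax(large)
except OverflowError as err:
    print("naive version failed:", err)
```

On the small inputs both versions agree, which confirms that the shift only changes the numerics and not the value of the loss.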
Numerator & Denominator Calculation
The core calculation within the Supervised Contrastive loss involves a fraction resembling a Softmax function.
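For a given anchor item i and a positive item p (where label[i] == label[p]), the term inside the logarithm is a fraction with:
Numerator: $\exp(s_{i,p})$
Denominator: $\sum_{k \neq i} \exp(s_{i,k})$
Here, $s_{i,k} = z_i \cdot z_k / \tau$ represents the scaled similarity between items i and k.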
Mask Generation 🎭
To correctly calculate the Supervised Contrastive loss components using matrix operations on the full logits matrix, we typically need two masks based on item indices and labels within the batch:
1. Self-Mask (Exclusion Mask for Denominator):
The denominator requires excluding the anchor's similarity with itself (k=i).
We create a self_mask that is 1 everywhere except for the diagonal, which is 0.
batch_size = features.shape[0]  # Or logits.shape[0]
device = features.device  # Or logits.device
# Creates a mask with 0 on diagonal, 1 elsewhere
self_mask = torch.ones_like(logits, device=device)
self_mask.fill_diagonal_(0)
# Alternative using scatter (produces the same mask):
self_mask = torch.scatter(
    torch.ones_like(logits, device=device),
    1,
    torch.arange(batch_size, device=device).view(-1, 1),
    0
).to(device)
# -- SELF MASK OUTPUT --
tensor([[0., 1., 1., 1.],
[1., 0., 1., 1.],
[1., 1., 0., 1.],
[1., 1., 1., 0.]])
2. Positive Pair Mask (Label Mask for Loss Calculation):
To calculate the final loss, we need to identify all (i, p) pairs in the batch that belong to the same class.
This mask is generated by comparing labels. It is 1 where label[i] == label[k] and 0 otherwise.
It must also incorporate the self-mask logic to ensure i ≠ k. So, positive_mask[i, k] = 1 if and only if label[i] == label[k] AND i ≠ k.
This mask selects the terms corresponding to positive pairs p needed for the inner summation Σ_{p ∈ P(i)} and helps calculate the number of positives |P(i)|.
We leverage the labels tensor for this.
# Assumes 'labels' is a 1D tensor of size [batch_size]
labels_reshaped = labels.contiguous().view(-1, 1)
# Create a mask that is 1 where labels match (broadcasting [B,1] vs [1,B])
label_mask = torch.eq(labels_reshaped, labels_reshaped.T).float().to(device)
# -- OUTPUT OF LABEL MASK (self-pairs still included) --
tensor([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]])
# Exclude self-comparisons (diagonal) from being positive pairs
positive_mask = label_mask * self_mask
# Alternatively, the positive mask can be created with:
positive_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
# Exclude self-comparisons (diagonal) from being positive pairs
positive_mask.fill_diagonal_(0)
# -- FINAL POSITIVE MASK EXCLUDING SELF --
tensor([[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.],
[0., 0., 1., 0.]])
Continuing from the mask generation, let's see how the self_mask is used to calculate the denominator term defined earlier: the sum of exp(logits[i, k]) over all k ≠ i.
Recall that the self_mask is designed to ignore the self-similarity term (k=i). We can apply this mask after exponentiating the entire logits matrix.
Multiplying the exponentiated logits by the self_mask zeros out the diagonal (i=k) terms.
Summing the result across the k dimension then yields the correct denominator for each anchor i.
# Calculate exp(logits[i, k]) for all k != i
# Note: 'logits' are the scaled and stabilized similarities
exp_logits_all_non_self = torch.exp(logits) * self_mask
# Calculate Denominator: Sum over k where k != i for each anchor i
sum_exp_logits_all_non_self = exp_logits_all_non_self.sum(dim=1, keepdim=True)
Log Probability Calculation
Instead of calculating the fraction (Numerator / Denominator) and then taking -log, the code uses the more stable log_prob = logits - log(denominator).
This reduces the risk of intermediate overflow/underflow during the sum and log steps compared to doing it naively.
# Computes log-probabilities: logits[i, k] - log(denominator[i])
log_prob = logits - torch.log(sum_exp_logits_all_non_self + 1e-12)
The log_prob is a [batch_size, batch_size] tensor that contains log probabilities for every possible pair (i, k).
It shows how likely the item k is to be considered the correct match for anchor i relative to the other items.
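A quick sanity check on `log_prob` is that, for each anchor, the probabilities over all non-self candidates sum to 1. The sketch below uses the logits printed earlier in this post and builds the self-mask with `torch.eye`, which is just a compact stand-in for the mask construction shown above.

```python
import torch

logits = torch.tensor([[ 0.0000, -0.0406, -0.9448, -1.1414],
                       [-0.0399,  0.0000, -1.2267, -1.4129],
                       [-0.9450, -1.2277,  0.0000, -0.0196],
                       [-1.1411, -1.4133, -0.0191,  0.0000]])
self_mask = 1.0 - torch.eye(4)  # 0 on the diagonal, 1 elsewhere

# Denominator: sum of exp(logits[i, k]) over k != i
denominator = (torch.exp(logits) * self_mask).sum(dim=1, keepdim=True)
log_prob = logits - torch.log(denominator)

# Probabilities over the non-self candidates of each anchor
probs = torch.exp(log_prob) * self_mask
row_sums = probs.sum(dim=1)

# The same quantity computed as log(numerator / denominator)
naive = torch.log(torch.exp(logits) / denominator)

print(row_sums)
```

The diagonal of `log_prob` is computed too, but it is never used: the positive mask removes it in the next step.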
Loss Calculation
Finally, we are ready to put everything into the loss function. We use the positive mask, which identifies the positive pairs, to calculate the loss.
The loss is the average contribution across all positives and all anchors in the batch.
# positive_mask * log_prob zeros out log_prob values for all non-positive pairs
# (negatives and self-pairs), keeping only log_prob[i, p] where p is a positive for i
# positive_mask.sum(1) calculates |P(i)|
mean_log_prob_positive = (positive_mask * log_prob).sum(1) / (positive_mask.sum(1).clamp(min=1e-12))
# Final loss is the mean over all anchors
loss = -mean_log_prob_positive.mean()
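Putting the individual steps together, a minimal sketch of the whole loss could look like the function below. It reuses the name `supervised_contrastive_loss` that appears in the training setup. Normalizing the embeddings inside the function is an assumption made here for self-containment, since the post normalizes them beforehand.

```python
import torch
import torch.nn.functional as F


def supervised_contrastive_loss(embeddings, labels, temperature):
    """SupCon loss for one batch, assembled from the steps described above."""
    # Assumption: normalize here so that callers may pass raw embeddings
    features = F.normalize(embeddings, p=2, dim=1)

    # Scaled cosine similarities, shifted per row for numerical stability
    logits = features @ features.T / temperature
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()

    # Masks: exclude self-pairs, keep same-label pairs
    self_mask = torch.ones_like(logits)
    self_mask.fill_diagonal_(0)
    positive_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() * self_mask

    # log_prob[i, k] = logits[i, k] - log(sum over k != i of exp(logits[i, k]))
    exp_logits = torch.exp(logits) * self_mask
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)

    # Average over the positives of each anchor, then over all anchors
    num_positives = positive_mask.sum(dim=1).clamp(min=1e-12)
    mean_log_prob_positive = (positive_mask * log_prob).sum(dim=1) / num_positives
    return -mean_log_prob_positive.mean()


embeddings = torch.tensor([[1.2, 0.9], [0.8, 0.3], [-1.0, 1.5], [-0.7, 0.7]])
labels = torch.tensor([0, 0, 1, 1])
loss = supervised_contrastive_loss(embeddings, labels, temperature=0.7)
print(loss.item())
```

Swapping the labels to `[0, 1, 0, 1]` on the same embeddings yields a larger loss, because each anchor's positive is then a dissimilar image.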
Examine Positive Pair
To observe the number of positive pairs in each batch, we can log
positive_mask.sum(1)
# -- NUMBER OF POSITIVE PAIRS E.G --
NUM POSITIVE PAIRS tensor([1., 1., 1., 1.])
For labels [0, 0, 1, 1], the number of positive pairs is [1, 1, 1, 1], which means that each sample has exactly 1 other sample with the same label.
For labels [0, 0, 1, 2], the number of positive pairs would be [1, 1, 0, 0], as class 1 and class 2 have no positive pairs, and each class 0 sample has one other sample of the same class in this batch.
For labels [0, 1, 2, 3], the number of positive pairs would be [0, 0, 0, 0], as all classes have only one sample.
For samples with no positive pairs in the batch (singletons), the supervised attractive component of the loss is zero,
meaning they are not explicitly pulled towards other samples of the same class within that batch.
However, they still contribute to the loss as negative examples for other anchors.
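The three label configurations discussed above can be reproduced with a short helper. The name `count_positives` is illustrative; it counts same-label matches per sample and subtracts the self-match.

```python
import torch


def count_positives(labels):
    """Number of other samples in the batch sharing each sample's label."""
    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Subtract 1 to drop the self-match on the diagonal
    return same_label.sum(dim=1) - 1


for batch in ([0, 0, 1, 1], [0, 0, 1, 2], [0, 1, 2, 3]):
    print(batch, count_positives(torch.tensor(batch)).tolist())
# [0, 0, 1, 1] [1, 1, 1, 1]
# [0, 0, 1, 2] [1, 1, 0, 0]
# [0, 1, 2, 3] [0, 0, 0, 0]
```

Logging this count during training gives an early warning when most anchors in a batch have no positives.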
For my dataset, the average number of images per class is ~5. Therefore, to increase the chance that a batch contains positive pairs, the batch size needs to be increased.
With an average of 5 images per class, the approximate probability of singletons (based on a Poisson distribution approximation) is:
- Batch Size 4: ~40% chance of singleton classes
- Batch Size 8: ~10% chance
- Batch Size 16: ~1% chance
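The exact figures depend on how many classes the dataset has, which is not stated here. A small Monte Carlo simulation is one way to estimate the singleton rate for a specific dataset. The sketch below assumes uniformly random batches drawn without replacement; the function name and the value `num_classes=20` are made up for illustration.

```python
import random
from collections import Counter


def singleton_fraction(num_classes, images_per_class, batch_size, trials=2000, seed=0):
    """Estimate the fraction of samples in a random batch that have no positive."""
    rng = random.Random(seed)
    # One label per image: class c appears images_per_class times
    dataset_labels = [c for c in range(num_classes) for _ in range(images_per_class)]

    singles = 0
    for _ in range(trials):
        batch = rng.sample(dataset_labels, batch_size)  # without replacement
        counts = Counter(batch)
        singles += sum(1 for label in batch if counts[label] == 1)
    return singles / (trials * batch_size)


# num_classes=20 is a hypothetical value; substitute the dataset's real class count
for batch_size in (4, 8, 16):
    print(batch_size, round(singleton_fraction(20, 5, batch_size), 3))
```

The singleton rate falls as the batch size grows, but with many classes and few images per class it can stay high even for fairly large batches.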
Time to Finetune 🔆
The image encoder in CLIP follows the vision transformer (ViT) architecture. The transformer encoder core consists of multiple identical transformer blocks stacked sequentially.
After passing through the transformer blocks, the output state corresponding to the initial [class] token position is used as the aggregated representation of the image. The vector then undergoes these final processing layers:
- Final layer normalisation ln_post: This is an nn.LayerNorm that helps stabilise the statistics (e.g., mean and variance) of the vector representation.
- Projection layer proj: This is a linear projection that maps the normalised features from ln_post into the target joint embedding space. Depending on the CLIP implementation, it is stored either as an nn.Linear module or as a bare weight matrix (an nn.Parameter), which is why the unfreezing code below handles both cases. The weights of proj are optimised during pre-training to align image embeddings with the corresponding text embeddings under a contrastive loss.
In our scenario, the task is repurposed to image classification driven by the supervised contrastive loss, operating on the existing class labels. As such, introducing a new MLP projector is suitable, and the loss is calculated on the output of this new projector.
While the original image-text mapping objective is discarded, visual_model.proj still serves as the final learned transformation applied by the pre-trained network components before the features reach our task-specific self.projector.
By unfreezing proj, we allow gradients originating from the SupCon loss (calculated further down the network) to flow back and update proj's weights. This enables the layer to adapt its mapping, refining the features it outputs to be more discriminative or better suited for the class-based contrastive task, thereby providing a higher quality input to self.projector.
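The post does not show the projector itself, so the following is only a sketch of what such an MLP head and the combined model might look like. The class names `ProjectionHead` and `ContrastiveModel`, the layer sizes, and the stand-in encoder are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    """Hypothetical two-layer MLP projector placed after the image encoder."""

    def __init__(self, in_dim=512, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        # L2-normalize so that dot products are cosine similarities
        return F.normalize(self.net(x), p=2, dim=1)


class ContrastiveModel(nn.Module):
    """Image encoder followed by the task-specific projector."""

    def __init__(self, visual_model, projector):
        super().__init__()
        self.visual_model = visual_model
        self.projector = projector

    def forward(self, images):
        return self.projector(self.visual_model(images))


# Stand-in encoder for demonstration only; in practice this is the CLIP image encoder
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512))
model = ContrastiveModel(encoder, ProjectionHead(in_dim=512, out_dim=128))
out = model(torch.randn(4, 3, 8, 8))
print(out.shape)  # torch.Size([4, 128])
```

Keeping the projector as a separate module makes it easy to discard it after training and evaluate the encoder features directly, if that turns out to work better.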
import torch
import torch.nn as nn

def unfreeze_visual_encoder_layers(visual_model: nn.Module, unfreeze_last_n_blocks: int):
    # Locate the transformer blocks and work out where unfreezing starts
    resblocks = None
    if hasattr(visual_model, 'transformer'):
        resblocks = visual_model.transformer.resblocks
        num_blocks = len(resblocks)
        if unfreeze_last_n_blocks == 0:
            unfreeze_start_index = num_blocks
        elif unfreeze_last_n_blocks >= num_blocks:
            unfreeze_start_index = 0
        else:
            unfreeze_start_index = num_blocks - unfreeze_last_n_blocks
    # Freeze every parameter first
    for param in visual_model.parameters():
        param.requires_grad = False
    # Unfreeze the last N transformer blocks
    if resblocks is not None:
        for i, block in enumerate(resblocks):
            if i >= unfreeze_start_index:
                for param in block.parameters():
                    param.requires_grad = True
    # Unfreeze the final layer norm
    ln_post_unfrozen = False
    if hasattr(visual_model, 'ln_post') and isinstance(visual_model.ln_post, nn.LayerNorm):
        for param in visual_model.ln_post.parameters():
            param.requires_grad = True
        ln_post_unfrozen = True
    if ln_post_unfrozen:
        print("Unfroze ln_post.")
    else:
        print("Cannot find ln_post.")
    # Unfreeze the projection, stored either as a bare tensor or as a module
    proj_unfrozen = False
    if hasattr(visual_model, 'proj'):
        if isinstance(visual_model.proj, torch.Tensor) and visual_model.proj.is_leaf:
            visual_model.proj.requires_grad = True
            proj_unfrozen = True
        elif isinstance(visual_model.proj, nn.Module):
            for param in visual_model.proj.parameters():
                param.requires_grad = True
                proj_unfrozen = True
        else:
            print(f"proj type {type(visual_model.proj)} not handled.")
    if proj_unfrozen:
        print("Unfroze proj.")
    else:
        print("Cannot find proj.")
    # Report how many parameters will be trained
    trainable = sum(p.numel() for p in visual_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in visual_model.parameters())
    print(f"Total number of trainable parameters: {trainable}/{total}")
Training Setup
Optimizer Configuration
We will be using AdamW as the optimizer, initialising it to handle only the parameters marked as trainable and setting the learning rate.
# Filter parameters that require gradients
# Use model.parameters() to include both encoder (if parts are unfrozen) and projector
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
# Instantiate the optimizer
optimizer = torch.optim.AdamW(params_to_train, lr=lr, weight_decay=weight_decay)
Dataset & Dataloader
A custom PyTorch Dataset is required to load images and their integer class labels. Use torch.utils.data.DataLoader to create batches, ensuring a sufficiently large batch_size and setting shuffle=True for training.
Loss Function
Call the Supervised Contrastive loss function with the desired temperature hyperparameter:
loss = supervised_contrastive_loss(embeddings, labels, temperature)
With the optimizer, data pipeline, and loss function ready, the next step is the core training process.
Training Loop
The core fine-tuning process iterates through the dataset over multiple epochs. Each epoch processes the entire training dataset in batches.
import logging
logger = logging.getLogger(__name__)

# Typical training loop
for epoch in range(num_epochs):
    model.train()  # set model to training mode
    total_train_loss = 0.0
    batches_processed = 0
    for batch_idx, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()  # reset gradients from the previous step
        images, labels = images.to(device), labels.to(device)
        # Forward pass through the encoder and the projector
        embeddings = model(images)
        loss = supervised_contrastive_loss(embeddings, labels, temperature)
        # --- Handle potential NaN/Inf loss ---
        if torch.isnan(loss) or torch.isinf(loss):
            logger.warning(f"Epoch {epoch+1}, Batch {batch_idx+1}: NaN or Inf loss! Skipping.")
            continue
        # Backward pass
        loss.backward()
        # Update the trainable model weights based on the computed gradients
        optimizer.step()
        total_train_loss += loss.item()
        batches_processed += 1
    if batches_processed > 0:
        avg_epoch_loss = total_train_loss / batches_processed
        logger.info(f"Epoch [{epoch+1}/{num_epochs}] Average Training Loss: {avg_epoch_loss:.4f}")
    else:
        logger.warning(f"Epoch [{epoch+1}/{num_epochs}]: No batches processed successfully.")
Evaluation Metrics
Now that we have our weights, how do we assess the quality of embeddings learned during the finetuning process?
Since the model ends with the projector, it does not include a classification head trained during the contrastive phase, so classification accuracy cannot be derived directly. Alternative evaluations include k-NN classification and linear probing, both of which examine the quality of the learned embedding space.
k-NN Classification
This metric evaluates the clustering quality of the embedding space: each held-out image is assigned the majority label of its k nearest training embeddings. Embeddings of images from the same class should be nearer to one another after finetuning.
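A minimal k-NN evaluation can be written directly in PyTorch. This is a sketch under assumptions: the function name `knn_accuracy`, cosine similarity as the distance measure, and majority voting via `torch.mode` are choices made here rather than details taken from the project.

```python
import torch
import torch.nn.functional as F


def knn_accuracy(train_emb, train_labels, test_emb, test_labels, k=5):
    """Classify each test embedding by majority vote of its k nearest neighbours."""
    train_emb = F.normalize(train_emb, dim=1)
    test_emb = F.normalize(test_emb, dim=1)

    sims = test_emb @ train_emb.T                  # (num_test, num_train)
    k = min(k, train_emb.shape[0])
    neighbor_idx = sims.topk(k, dim=1).indices     # (num_test, k)
    neighbor_labels = train_labels[neighbor_idx]   # (num_test, k)

    # Most frequent label among the neighbours (ties are resolved arbitrarily)
    predictions = torch.mode(neighbor_labels, dim=1).values
    return (predictions == test_labels).float().mean().item()


# Toy check with the 2-D cat/dog embeddings used earlier in this post
train_emb = torch.tensor([[1.2, 0.9], [0.8, 0.3], [1.0, 1.0], [1.7, 1.1],
                          [-1.0, 1.5], [-0.7, 0.7], [-0.5, 0.2], [-1.3, 0.9]])
train_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
test_emb = torch.tensor([[1.0, 0.5], [-1.0, 1.0]])
test_labels = torch.tensor([0, 1])
accuracy = knn_accuracy(train_emb, train_labels, test_emb, test_labels, k=3)
print(accuracy)
```

For real use, the embeddings would be computed once with the model in eval mode and gradients disabled, then passed to this function.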
Linear Probing Accuracy
Linear probing assesses how well the learned features serve a downstream task: a single linear classifier is trained on top of the frozen embeddings, so its accuracy measures the linear separability of the classes in the embedding space.
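A linear probe can be sketched as a single `nn.Linear` layer trained with cross-entropy on precomputed, frozen embeddings. The function name `linear_probe_accuracy` and the hyperparameters (full-batch SGD, 200 epochs, learning rate 0.5) are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_probe_accuracy(train_emb, train_labels, test_emb, test_labels,
                          num_classes, epochs=200, lr=0.5, seed=0):
    """Train a linear classifier on frozen embeddings and report test accuracy."""
    torch.manual_seed(seed)
    # Embeddings are treated as constants: no gradient reaches the encoder
    train_emb, test_emb = train_emb.detach(), test_emb.detach()

    probe = nn.Linear(train_emb.shape[1], num_classes)
    optimizer = torch.optim.SGD(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(probe(train_emb), train_labels)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        predictions = probe(test_emb).argmax(dim=1)
    return (predictions == test_labels).float().mean().item()


# Toy check with the normalized 2-D cat/dog embeddings used earlier in this post
train_emb = F.normalize(torch.tensor([[1.2, 0.9], [0.8, 0.3], [1.0, 1.0], [1.7, 1.1],
                                      [-1.0, 1.5], [-0.7, 0.7], [-0.5, 0.2], [-1.3, 0.9]]), dim=1)
train_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
test_emb = F.normalize(torch.tensor([[1.0, 0.5], [-1.0, 1.0]]), dim=1)
test_labels = torch.tensor([0, 1])
accuracy = linear_probe_accuracy(train_emb, train_labels, test_emb, test_labels, num_classes=2)
print(accuracy)
```

Because only the probe's weights are trained, a high accuracy here reflects the quality of the embeddings rather than the capacity of the classifier.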
These metrics are calculated at the end of each epoch or every few epochs to provide insights into the finetuning process.