Understanding Paper: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" with ChatGPT
Prompt
- Role: You are an experienced AI/ML researcher and teaching assistant.
- Objective: Help me understand the attached research paper through an interactive, step-by-step dialogue.
- Workflow & Ground Rules: When I ask “EXPLAIN” or similar words, break down concepts as if teaching a graduate CS student.
- Define all symbols: Use analogies to common models (e.g., CNN, Transformer) if helpful.
- When I ask “CODE”, provide clean PyTorch-style pseudocode with comments and tensor shapes.
- If a question is ambiguous, ask a clarifying follow-up instead of guessing.
- Cite page or figure numbers from the PDF whenever you reference specific content.
- If you are unsure, say “I’m not certain—please verify in the original text.”
- Begin by replying only with: “Ready to analyze the paper. Awaiting your first instruction.”
Prompt 1
At first, I want an elaborate summary of the entire article. Return a bullet list with:
- Main problem addressed
- Core idea (3-4 sentences)
- Claimed contributions (short phrases)
Output 1
Vision Transformer (ViT) – High-level Take-aways
Main problem addressed
Convolutional Neural Networks (CNNs) dominate vision, yet they embed hand-crafted inductive biases (locality, translation equivariance) that may limit scalability. The paper asks: Can a standard NLP Transformer, with minimal changes, match or surpass state-of-the-art CNNs on image classification when trained at scale?
Core idea (3–4 sentences)
- Treat an image as a sequence of non-overlapping $$(P \times P)$$ patches (e.g., 16×16) and linearly project each flattened patch to a D-dimensional token embedding (Eq. 1, p. 3).
- Prepend a learnable “[class]” token, add 1-D positional embeddings, and feed the resulting sequence into a standard Transformer encoder of the kind used in BERT, with no image-specific changes inside the encoder (Fig. 1, p. 2).
- Supervised pre-training on very large image corpora (ImageNet-21k, 14 M images; JFT-300M, 303 M images) compensates for the lack of CNN inductive bias.
- After fine-tuning (often at higher resolution) the resulting Vision Transformer attains competitive or superior accuracy to strong CNN baselines while using 2–4× less pre-training compute (Table 2, p. 5; Fig. 5, p. 7).
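To make the patch arithmetic concrete, here is a minimal sketch of how many tokens the encoder sees at two resolutions. The function name `vit_sequence_stats` and the 224/384 resolutions are illustrative assumptions, not values taken from the summary above; only the formula N = HW / P² comes from the paper.

```python
def vit_sequence_stats(img_size: int, patch_size: int, channels: int = 3) -> dict:
    """Token counts for a square image cut into non-overlapping square patches."""
    if img_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    per_side = img_size // patch_size
    num_patches = per_side ** 2                            # N = HW / P^2
    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,                        # +1 for the [class] token
        "patch_dim": patch_size * patch_size * channels,   # values per flattened patch
    }


pretrain = vit_sequence_stats(224, 16)   # assumed pre-training resolution
finetune = vit_sequence_stats(384, 16)   # assumed higher fine-tuning resolution
print(pretrain)
print(finetune)

# Self-attention compares every token with every other token,
# so its cost grows roughly with the square of the sequence length.
ratio = (finetune["seq_len"] / pretrain["seq_len"]) ** 2
print(f"attention cost ratio: about {ratio:.1f}x")
```

The patch size stays fixed when the resolution goes up, so the sequence simply gets longer; this is why fine-tuning at higher resolution needs interpolated position embeddings and more attention compute.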
Claimed contributions
- Introduce Vision Transformer (pure Transformer on image patches)
- Show scaling law: larger datasets > inductive bias for vision
- Achieve 88.55 % ImageNet top-1 with ViT-H/14 (state-of-the-art at submission)
- Demonstrate strong transfer: 94.55 % CIFAR-100, 77.63 % VTAB mean
- Provide compute-efficient training (≥2× less TPU core-days than prior SOTA)
- Release code & pre-trained checkpoints for reproducibility
Prompt 2
- Summarize the Introduction section in ≤200 words.
- Highlight:
- Practical or theoretical gap the authors claim.
- Prior work they say is insufficient.
- Specific real-world stakes (datasets, applications, societal impact).
Output 2
The authors confront a widening gap between the scalability of Transformers in NLP and their modest traction in computer-vision pipelines. They argue that most vision research still leans on Convolutional Neural Networks (CNNs) whose built-in inductive biases (locality, translation equivariance) become liabilities when models and datasets scale.
Practical/Theoretical gap claimed
- No compelling evidence yet shows that a pure Transformer—devoid of convolutional components—can match leading CNNs on large-scale image classification.
- Existing attention-augmented CNN hybrids inherit CNN constraints, limiting architectural simplicity and hardware efficiency.
Why prior work is insufficient
- Earlier “fully-attention” models either keep some CNN stages, use specialized sparse or local attention patterns that complicate accelerator deployment, or were only tested on small images/datasets (e.g., 2×2-patch models, Sparse Transformer and axial-attention variants).
- Self-supervised iGPT reaches 72% ImageNet top-1—far below CNN SOTA—demonstrating the shortfall of prior Transformer attempts when data or compute are limited.
Real-world stakes highlighted
- Benchmarks: ImageNet-1k (1.3 M images), ImageNet-21k (14 M), and Google’s JFT-300M (303 M) are emphasized as proving grounds; performance on smaller transfer sets (CIFAR-100, VTAB suite) gauges generalization.
- Applications: Efficient training on these massive corpora matters for industrial-scale visual recognition systems where compute budgets and hardware utilization dictate feasibility.
- Societal impact: Demonstrating that inductive biases can be learned rather than hard-coded informs future resource allocation—shifting focus from manual architecture design to data curation and availability.
Prompt 3
Create a hierarchical outline of the paper. For every top-level section include a one-line purpose statement. For sub-sections, add a short note (~5 words) on what is done or proved there. Return in Markdown outline format.
Output 3
Hierarchical Outline of “An Image Is Worth 16×16 Words”
- Abstract – Purpose: Summarises the problem, method and key results.
- Introduction – Purpose: Motivates replacing CNNs with pure Transformers for vision.
  - 2.1 NLP success of Transformers – scalability noted.
  - 2.2 CNN dominance in vision – limits.
  - 2.3 Prior hybrid/self-attention models – scaling issues.
  - 2.4 Proposed ViT idea – treat patches as tokens.
  - 2.5 Data-scale hypothesis – large data beats inductive bias.
  - 2.6 Headline results – ImageNet 88.55%, etc.
- Related Work – Purpose: Positions ViT against prior attention and large-scale vision efforts.
  - 3.1 Transformers in NLP – pre-training paradigm.
  - 3.2 Scaling attention to images – local/sparse variants.
  - 3.3 Closest prior (2×2 patches) – resolution limit.
  - 3.4 CNN-attention hybrids – diverse tasks.
  - 3.5 iGPT comparison – 72% top-1.
  - 3.6 Large-data CNN studies – motivate ViT scale.
- Method – Purpose: Describes the vanilla Vision Transformer architecture and training tricks.
  - 4.1 ViT overview – patch embedding + class token.
  - 4.2 Transformer encoder math – Eqs 1-4.
  - 4.3 Inductive bias discussion – minimal locality.
  - 4.4 Hybrid variant – CNN stem option.
  - 4.5 Fine-tuning & resolution – pos-embed interpolation.
- Experiments – Purpose: Empirically evaluates ViT versus CNNs, data size, compute.
  - 5.1 Setup – datasets, baselines, hyper-params.
  - 5.2 SOTA comparison – ViT-H/14 beats BiT.
  - 5.3 Data requirement study – bigger data helps ViT.
  - 5.4 Scaling study – accuracy vs exaFLOPs.
  - 5.5 Model inspection – filters, position sims, attention distance.
  - 5.6 Self-supervision pilot – masked patch 79.9%.
- Conclusion – Purpose: Recaps findings, notes future research paths.
  - 6.1 Pure Transformer viable – cheaper compute.
  - 6.2 Open challenges – detection, SSL gap, further scaling.
- Acknowledgements – Purpose: Credits contributors and infrastructure support.
- References – Purpose: Lists cited prior literature.
- Appendices
Prompt 4
Task:
- Explain Figure 1 of the attached paper in detail, focusing on the model architecture it illustrates.
- Produce a numbered list that traces the data/control flow, one step per diagram arrow:
  - Use the exact component name or label shown in the figure.
  - State the main operation performed in ≤50 words.
  - Specify the input(s) and output(s) for that step.
Output 4
1 Detailed Explanation of Figure 1
Figure 1 presents a bird’s-eye view of the Vision Transformer (ViT) classifier. The key idea is to treat an image as a sequence of patch tokens, feed those tokens—plus a special class token—into a standard Transformer encoder, and then read out the class prediction from the final representation of that class token. Nothing in the core encoder is vision-specific; all image handling happens in the patch-embedding “stem” and in the learned positional embeddings that preserve spatial order. The architecture therefore re-uses mature NLP Transformer implementations while replacing CNN inductive biases with scale-driven learning.
Major blocks shown:
- Linear Projection of Flattened Patches – cuts the image into non-overlapping P×P patches, flattens each, and maps it to a D-dimensional vector.
- [class] Embedding – a learned vector prepended to the patch sequence whose final state will carry the whole-image representation.
- Patch + Position Embedding – adds learnable 1-D positional embeddings to every token (patch or class).
- Transformer Encoder (stack repeated L times) – alternates Multi-Head Self-Attention (MSA), LayerNorm and MLP sub-blocks with residual connections.
- MLP Head – a task-specific classifier (1-layer during fine-tuning) that transforms the final [class] representation into logits over K classes.
Flow: Image → patch tokens → positional addition → Transformer stack → [class] vector → logits.
2 Arrow-by-Arrow Data/Control Flow
1. Input Image → Linear Projection of Flattened Patches
- Operation: Slice image into P×P non-overlapping patches, flatten each, apply a linear map.
- Inputs: H×W×C image.
- Outputs: N patch embeddings ∈ ℝ^(N×D).
2. [class] Embedding → Sequence Concatenation
- Operation: Prepend a learned [class] vector to patch embeddings.
- Inputs: [class] token (1×D), patch embeddings (N×D).
- Outputs: (N+1)-token sequence ∈ ℝ^((N+1)×D).
3. Position Embedding Addition → Patch + Position Embedding
- Operation: Add learnable 1-D positional vectors element-wise to each token.
- Inputs: token sequence, positional table ((N+1)×D).
- Outputs: position-encoded sequence ∈ ℝ^((N+1)×D).
4. Sequence → Transformer Encoder (Layer 1: Norm → Multi-Head Attention)
- Operation: Apply LayerNorm, then let each token attend to all tokens, producing context-mixed representations.
- Inputs: position-encoded sequence.
- Outputs: attention output (same shape).
5. Residual Add (Layer 1)
- Operation: Add the skip connection, i.e. the un-normalized sub-layer input, to the attention output.
- Inputs: sub-layer input and attention output.
- Outputs: attention-updated sequence z′.
6. Norm → MLP Block (Layer 1)
- Operation: Apply LayerNorm, then two dense layers with GELU, projecting D→D_MLP→D.
- Inputs: attention-updated sequence z′.
- Outputs: feed-forward output (same shape).
7. Residual Add (Layer 1)
- Operation: Second skip connection: add z′ to the MLP output.
- Inputs: z′ and MLP output.
- Outputs: sequence for next layer.
8. Steps 4–7 repeat L − 1 times in Transformer Encoder
- Operation: Deeper context mixing via identical layers.
- Inputs: sequence from previous layer.
- Outputs: final encoded sequence z_L ∈ ℝ^((N+1)×D).
9. Extract [class] Token → LN
- Operation: Apply LayerNorm to final [class] vector.
- Inputs: first token of z_L.
- Outputs: image representation y ∈ ℝ^D.
10. Representation y → MLP Head
- Operation: Single linear layer (fine-tuning) mapping D→K logits.
- Inputs: y.
- Outputs: class-score vector logits ∈ ℝ^K.
11. Logits → Softmax (not drawn)
- Operation: Convert logits to class probabilities.
- Inputs: logits.
- Outputs: probability distribution over K classes.
The eleven numbered steps follow every arrow in Figure 1 and collectively describe the complete forward path from raw image pixels to predicted class probabilities.
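As a sanity check on the blocks described in this output, the sketch below counts the learnable parameters of each component for a ViT-Base-sized encoder. The function name `vit_param_count` is hypothetical, and the count assumes a bias on every linear layer and a scale plus shift in every LayerNorm; the classification head is left out.

```python
def vit_param_count(img_size=224, patch_size=16, channels=3,
                    dim=768, depth=12, mlp_ratio=4):
    """Learnable parameters of the encoder (no classification head)."""
    n_tokens = (img_size // patch_size) ** 2 + 1        # patches + [class]
    patch_dim = patch_size * patch_size * channels      # P^2 * C

    patch_embed = patch_dim * dim + dim                 # linear projection E (+ bias)
    cls_token = dim                                     # one learned D-vector
    pos_embed = n_tokens * dim                          # one D-vector per token

    layer_norm = 2 * dim                                # scale + shift
    attention = 4 * (dim * dim + dim)                   # Q, K, V and output projections
    hidden = mlp_ratio * dim
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)
    block = layer_norm + attention + layer_norm + mlp   # one encoder layer

    final_norm = layer_norm
    return patch_embed + cls_token + pos_embed + depth * block + final_norm


total = vit_param_count()
print(f"{total:,} parameters (about {total / 1e6:.0f}M)")
```

Under these assumptions the total lands at roughly 86M, in line with the size usually quoted for ViT-Base; almost all of it sits in the repeated encoder layers, while the patch projection, [class] token and position table together are well under 1M.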
Prompt 5
Can you explain a bit more about the extra class token, especially what its input is and how it learns? Also suggest a better version of the question and ask me if I would like to use it instead.
Output 5
Clarifying the “extra [class] token” in Vision Transformer
Below is a deeper look at what the [class] token is, why it is added, and how it learns during training.
1 What exactly is fed in as the [class] token?
- It is not derived from the image.
- It is a single, learned D-dimensional vector (same width as any patch embedding) that is randomly initialized along with the rest of the model parameters.
- During every forward pass the token is simply prepended to the patch-embedding sequence, so the input length becomes N + 1.
2 Why add it at all?
- Global pooling substitute: Transformers output a vector per token; one must be chosen to represent the whole image. The [class] token gives the model a designated slot whose final hidden state becomes that representation.
- Information sink: Because this token attends to (and is attended by) all patch tokens in each self-attention layer, it can collect a summary of the entire image content.
- Compatibility with NLP tooling: BERT uses an identical mechanism for sentence-level tasks; ViT inherits that convention and re-uses the same code path.
3 How does the token “learn”?
- Parameter learning
  - Its initial embedding vector is optimized like any other weight through back-propagation.
  - With L layers of self-attention, the token is repeatedly updated by mixing with patch tokens; gradients therefore reflect the classification loss.
- Representation learning
  - Early layers let the token gather coarse image context.
  - Deeper layers refine it into a discriminative vector.
  - The final LayerNorm output (y = LN(z_L^0), the [class] token after the last layer) is what the MLP head reads.
- No special loss term: Learning is entirely driven by the downstream cross-entropy (or whatever task loss) applied to the head’s logits.
4 Couldn’t we just use global average pooling instead?
Yes—and the authors tried it. A GAP-based head works almost identically if you tune the learning rate, but the [class] token keeps the architecture identical to BERT and avoids extra pooling code. It also lets the model decide how much emphasis to place on each patch rather than averaging blindly.
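The contrast between the two readouts can be sketched with a single attention step. This is a deliberately simplified toy, not the paper's model: one head, no learned Q/K/V projections, the query standing in for the [class] token, and hand-picked 2-D patch embeddings. It only illustrates that attention pooling reduces to global average pooling when the query has no preference.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(query, keys, values):
    """One query attends over all keys; returns the weighted average of values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    weights = softmax(scores)
    pooled = [sum(w * v[d] for w, v in zip(weights, values))
              for d in range(len(values[0]))]
    return pooled, weights


def global_average_pool(values):
    n = len(values)
    return [sum(v[d] for v in values) / n for d in range(len(values[0]))]


patches = [[1.0, 0.0], [0.0, 1.0], [4.0, 4.0]]   # three toy patch embeddings (D = 2)

# An all-zero query scores every patch equally: uniform weights, same as GAP.
flat, w_flat = attention_pool([0.0, 0.0], patches, patches)

# A query aligned with the third patch concentrates the weight on it.
sharp, w_sharp = attention_pool([2.0, 2.0], patches, patches)

print("GAP        :", global_average_pool(patches))
print("flat query :", flat)
print("sharp query:", sharp, "weights:", [round(w, 4) for w in w_sharp])
```

In the real model the query is produced from the [class] token's current state by learned projections, so which patches get emphasized is itself something training adjusts.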
Suggested clearer question
“Can you explain how the Vision Transformer’s learned [class] token is initialized, updated during training, and used for classification? If you prefer, I can rephrase the question this way—would you like me to?”
Prompt 6
Can you suggest a better version of the question and ask me if I would like to use it instead?
- Question = "How this class token result in representing complete image"
Output 6
Here is a cleaner, more specific way to phrase your question:
“How does the Vision Transformer’s learnable [class] token end up summarizing the entire image for classification?”
Would you like to use this wording?
Prompt 7
How the Vision Transformer’s [class] token learns to summarize the whole image
Output 7
1. Where the [class] token fits in the ViT pipeline
- ViT first chops an image into fixed-size patches (e.g., 16×16 px).
- Each patch is linearly projected to a D-dimensional embedding, and a single extra learnable vector dubbed [class] is prepended to the resulting sequence of patch embeddings.
- The resulting sequence (length = number of patches + 1) is fed through L Transformer encoder layers, each containing multi-head self-attention (MSA) and an MLP.
- After the final layer, the model extracts only the updated [class] token (written $z_L^0$ in the paper) and passes it to a lightweight classifier head.
Thus, the network’s entire downstream decision hinges on the information that flows into that single vector.
2. Why information concentrates into the [class] token
Self-attention lets every token exchange information with every other token. During training, the classification loss is computed only from the final [class] vector, so every gradient enters the network through that vector and then flows back to the patch tokens via the attention weights. To minimize the loss, the model learns two complementary behaviours:
- Patch-to-class attention grows. Early layers already show heads where patch tokens assign high attention weights to the [class] token, creating a funnel that channels image content toward it.
- Class-to-patch attention becomes selective. Later layers feature heads where the [class] token attends back to spatial tokens that are semantically relevant (e.g., bird silhouette, steering wheel), effectively reading the scene. Visualizations in ViT-L/16 reveal that by layer 10+, many heads attend almost globally, but with strongest weights on object regions.
Together, these two flows repeatedly mix and refine features, so by the final layer the [class] embedding is a nonlinear, global aggregation of the image.
3. Empirical evidence
| Observation | Paper section | Evidence |
|---|---|---|
| Some heads attend to most of the image already in low layers | §4.5, Fig. 7 (right) | Average attention distance of several heads ≈ image size |
| [class] attention maps highlight the target object | §4.5, Fig. 6 and Appendix Fig. 14 | Qualitative rollout shows focus on semantically relevant pixels |
| Removing the [class] token and using global average pooling works only after retuning LR, confirming it’s a learnable, not structural, choice | Appendix D.3, Fig. 9 | Need different LR; otherwise large drop |
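The "average attention distance" in the first table row can be sketched as follows: for each query patch, weight the pixel distance to every key patch by the attention it receives, then average over queries. This is one plausible reading of the metric, written for a single head with the [class] token left out; the helper name and the 2×2 toy grid are assumptions for illustration.

```python
import math


def mean_attention_distance(attn, grid, patch_size):
    """attn[i][j]: weight query patch i puts on key patch j (rows sum to 1).

    Patches are indexed row-major on a grid x grid layout; distances are
    measured in pixels between patch positions.
    """
    n = grid * grid
    total = 0.0
    for i in range(n):
        yi, xi = divmod(i, grid)
        for j in range(n):
            yj, xj = divmod(j, grid)
            dist = math.hypot(yi - yj, xi - xj) * patch_size
            total += attn[i][j] * dist
    return total / n


grid, patch = 2, 16
n = grid * grid

# A "local" head: every patch attends only to itself.
local_head = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
# A "global" head: every patch spreads attention uniformly over the image.
global_head = [[1.0 / n] * n for _ in range(n)]

print("local :", mean_attention_distance(local_head, grid, patch))
print("global:", mean_attention_distance(global_head, grid, patch))
```

A head behaving like a small convolution kernel would score near zero on this measure, whereas a head that integrates information across the whole image scores on the order of the image size.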
4. Intuition vs. CNN global pooling
A CNN funnels spatial information by gradually shrinking feature maps and finally performing global average pooling. ViT achieves a similar information bottleneck, but:
- The bottleneck exists from the start (single extra token),
- It relies on learned attention patterns rather than fixed convolutional kernels,
- It keeps full-resolution patch tokens to enable long-range reasoning until the last layer.
5. Practical take-aways
- The [class] token is not magic; it works because self-attention plus the loss force it to become a global summary.
- Visualization or probing (e.g., attention rollout) can help verify your fine-tuned ViT is focusing on the right regions.
- If you switch to alternatives like mean-pooled patch tokens, re-tune optimization hyper-parameters—performance can match the [class] approach but is sensitive to learning rate.
The [class] token therefore acts as a learned “collector” that repeatedly gathers, mixes and refines patch information through attention until it holds a representation rich enough for accurate image-level decisions.
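Attention rollout, mentioned above as a probing tool, can be sketched in a few lines. The version below follows the commonly described recipe: average each layer's attention over heads, mix it half-and-half with the identity to account for the residual connection, and multiply the layers together. The matrices here are invented toy values, and token 0 is assumed to be [class].

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def attention_rollout(layers):
    """layers: head-averaged attention matrices (rows sum to 1), first layer first."""
    n = len(layers[0])
    rollout = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for attn in layers:
        # Blend with the identity so the skip connection is represented;
        # every row still sums to 1 afterwards.
        mixed = [[0.5 * attn[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
                 for i in range(n)]
        rollout = matmul(mixed, rollout)
    return rollout


# Toy model: [class] + 2 patches, 2 layers. The [class] token mostly attends
# to patch 2; the patch tokens attend only to themselves.
layer = [[0.2, 0.1, 0.7],
         [0.0, 1.0, 0.0],
         [0.0, 0.0, 1.0]]
rollout = attention_rollout([layer, layer])

class_row = rollout[0]   # how much each input token contributes to the final [class]
print([round(v, 2) for v in class_row])
```

Reshaping the patch entries of `class_row` back onto the patch grid gives the kind of heat map used to check whether a fine-tuned model is looking at the object of interest.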
Prompt 8
For Equations 1 to 4, explain all the mathematical functions in the text to an undergraduate student who knows basic deep learning and machine learning but not this technique.
Do NOT omit symbols: rewrite each equation, define every variable, and describe each step in plain language.
Where relevant, draw analogies to known models (e.g., Transformer, CNN). Use numbered steps for algorithm flow.
Output 8
Equations (1) – (4) in Vision Transformer (ViT)
Below each formula you’ll find
• the exact equations
• definitions of every symbol
• a plain-language explanation
• short analogies to standard NLP Transformers or CNNs
Equation (1)
\boxed{\,z_{0}
= \bigl[\,x_{\text{class}};\; x^{1}_{p}E;\; x^{2}_{p}E;\; \cdots;\; x^{N}_{p}E\,\bigr] + E_{\text{pos}}}
Variable glossary
- $$(x\in\mathbb{R}^{H\times W\times C})$$ – full RGB image (height (H), width (W), channels (C)).
- Patch size $$(P\times P)$$; number of patches $$(N=\tfrac{HW}{P^{2}})$$.
- $$(x^{i}_{p}\in\mathbb{R}^{P^{2}C})$$ – flattened pixel vector of patch $$(i)$$.
- $$(E\in\mathbb{R}^{(P^{2}C)\times D})$$ – learnable linear projection (turns a patch into a $$(D)$$-dimensional “word embedding”).
- $$(x_{\text{class}}\in\mathbb{R}^{1\times D})$$ – learnable [CLS] token that will gather global image information.
- $$(E_{\text{pos}}\in\mathbb{R}^{(N+1)\times D})$$ – learnable positional embeddings (one per token, including the class token).
- $$(z_{0}\in\mathbb{R}^{(N+1)\times D})$$ – sequence injected into the Transformer (layer index 0).
Plain-language explanation
- Cut the image into non-overlapping tiles (patches).
- Flatten each tile’s pixels into a long vector.
- Project every vector through the same linear layer $$(E)$$ to get a patch embedding.
- Prepend a special learnable token $$(x_{\text{class}})$$.
- Add positional embeddings so the model knows where each patch came from.
Analogy: identical to mapping words→embeddings in BERT, just replacing “words” with image patches.
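The five steps above can be traced by hand on a tiny example. The sketch below uses plain Python lists instead of tensors so every number is easy to follow; the 4×4 single-channel image, the hand-picked matrix `E`, and the integer position table are made-up values, and the row-major flattening order is an assumption (any fixed order works as long as it is applied consistently).

```python
def patchify(image, patch):
    """image[h][w][c] -> list of N flattened patches, row-major over the patch grid."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            flat = [value
                    for r in range(top, top + patch)
                    for c in range(left, left + patch)
                    for value in image[r][c]]
            patches.append(flat)
    return patches


def embed(image, patch, E, x_class, E_pos):
    """z_0 = [x_class; x_p^1 E; ...; x_p^N E] + E_pos"""
    dim = len(E[0])
    tokens = [[sum(x * E[k][d] for k, x in enumerate(p)) for d in range(dim)]
              for p in patchify(image, patch)]
    sequence = [list(x_class)] + tokens                      # prepend [class]
    return [[t + e for t, e in zip(tok, pos)]                # add position embeddings
            for tok, pos in zip(sequence, E_pos)]


# 4x4 image, 1 channel, pixel value = r * 4 + c; P = 2 gives N = 4 patches of 4 values.
image = [[[r * 4 + c] for c in range(4)] for r in range(4)]
E = [[1, 0], [0, 1], [1, 0], [0, 1]]        # (P^2 * C) x D with D = 2
x_class = [0.5, -0.5]
E_pos = [[i, 0] for i in range(5)]          # (N + 1) x D

z0 = embed(image, 2, E, x_class, E_pos)
for token in z0:
    print(token)
```

Every patch goes through the same matrix `E`, which is what makes this a shared "word embedding" for patches; the only thing distinguishing two identical patches at different locations is the position row added at the end.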
Equation (2)
\boxed{\,z'_{\ell}
= \mathrm{MSA}\!\bigl(\mathrm{LN}(z_{\ell-1})\bigr) + z_{\ell-1}
\quad\text{for } \ell = 1\ldots L}
Symbols
- $\ell$ – layer index.
- $\mathrm{LN}(\cdot)$ – Layer Normalization (normalizes features within a token).
- $\mathrm{MSA}(\cdot)$ – Multi-Head Self-Attention (standard Transformer block).
- Residual “$+$” – adds the attention output back to the input $z_{\ell-1}$ (skip connection).
- $z'_{\ell}$ – intermediate sequence after the attention sub-layer.
Plain-language view
- Normalize current token embeddings.
- Let every patch attend to every other (plus the class token) via multi-head attention, mixing global information.
- Add the original embeddings back in (skip-connection helps learning).
Analogy: Same attention step as BERT; unlike CNNs, receptive field is instantly global.
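A stripped-down version of this sub-layer is sketched below. It is a toy under stated assumptions: a single head instead of multi-head attention, no learned Q/K/V or output projections (queries, keys and values are the normalized tokens themselves), and a LayerNorm without learnable scale or shift. What it keeps is the order of operations in Eq. (2): normalize, attend, then add the un-normalized input back.

```python
import math


def layer_norm(vec, eps=1e-6):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def self_attention(tokens):
    """Single head with Q = K = V = tokens (no learned projections)."""
    scale = math.sqrt(len(tokens[0]))
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in tokens]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]              # one row of the attention matrix
        out.append([sum(w * v[d] for w, v in zip(weights, tokens))
                    for d in range(len(q))])
    return out


def attention_sublayer(z_prev):
    """z'_l = Attn(LN(z_{l-1})) + z_{l-1}, with the toy head standing in for MSA."""
    normed = [layer_norm(tok) for tok in z_prev]
    mixed = self_attention(normed)
    return [[m + r for m, r in zip(m_tok, r_tok)]
            for m_tok, r_tok in zip(mixed, z_prev)]


z_prev = [[1.0, 2.0, 3.0, 4.0],     # [class] token
          [0.0, 0.0, 1.0, 5.0],     # patch 1
          [2.0, 2.0, 2.0, 6.0]]     # patch 2
z_prime = attention_sublayer(z_prev)
for token in z_prime:
    print([round(v, 3) for v in token])
```

Note that the residual adds `z_prev`, not its normalized copy; this "pre-norm" arrangement is what Eq. (2) specifies.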
Equation (3)
\boxed{\,z_{\ell}
= \mathrm{MLP}\!\bigl(\mathrm{LN}(z'_{\ell})\bigr) + z'_{\ell}
\quad\text{for } \ell = 1\ldots L}
Symbols
- $\mathrm{MLP}(\cdot)$ – two fully-connected layers with a GELU activation in between (also called “feed-forward network”). Dimensionality: $D \rightarrow 4D \rightarrow D$.
- Remaining symbols as before.
Plain-language view
- Normalize again.
- Process each token independently with a small neural network (adds non-linearity and channel-wise mixing).
- Add skip connection.
Analogy: Comparable to a 1×1 convolution (channel mixing) applied to every location in a CNN.
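The "each token independently" point is easy to check in code. The sketch below implements the sub-layer of Eq. (3) with plain lists, using arbitrary deterministic weights (they are placeholders, not trained values), the exact erf-based GELU, and a LayerNorm without learnable scale or shift.

```python
import math


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))   # exact GELU


def layer_norm(vec, eps=1e-6):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def linear(vec, weight, bias):
    # weight is stored as [out_features][in_features]
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def mlp_sublayer(tokens, w1, b1, w2, b2):
    """z_l = MLP(LN(z'_l)) + z'_l, applied to every token separately."""
    out = []
    for tok in tokens:
        hidden = [gelu(h) for h in linear(layer_norm(tok), w1, b1)]   # D -> 4D
        update = linear(hidden, w2, b2)                               # 4D -> D
        out.append([u + t for u, t in zip(update, tok)])              # residual
    return out


D, HIDDEN = 2, 8                                       # hidden width = 4 * D
w1 = [[0.1 * (h - d) for d in range(D)] for h in range(HIDDEN)]
b1 = [0.0] * HIDDEN
w2 = [[0.05 * (d + 1) for _ in range(HIDDEN)] for d in range(D)]
b2 = [0.0] * D

a, b, c = [1.0, 3.0], [-2.0, 0.5], [4.0, 4.5]
out_ab = mlp_sublayer([a, b], w1, b1, w2, b2)
out_ac = mlp_sublayer([a, c], w1, b1, w2, b2)
print(out_ab[0] == out_ac[0])   # token a's result ignores its neighbours
```

Because the same weights are applied at every position and no information crosses between tokens here, all spatial mixing in the encoder has to come from the attention sub-layer.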
Equation (4)
\boxed{\,y = \mathrm{LN}\bigl(z^{0}_{L}\bigr)}
Symbols
- $z^{0}_{L}\in\mathbb{R}^{D}$ – the class-token embedding after the final (L-th) layer; we take index 0 along the sequence (because we prepended it first).
- $y\in\mathbb{R}^{D}$ – final image representation after LayerNorm; fed to a softmax or MLP head for classification.
Plain-language view
The final, normalized vector of the special class token is the model’s distilled understanding of the whole image, just like the “[CLS]” vector in text BERT.
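To show where Eq. (4) sits relative to the classifier, here is a small sketch of the readout: pick token 0, apply LayerNorm, then a linear head and softmax. The tokens, the two-class head weights and the `eps` value are invented for illustration; `gamma` and `beta` stand for LayerNorm's learnable scale and shift.

```python
import math


def layer_norm(vec, gamma, beta, eps=1e-6):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(vec, gamma, beta)]


def classify(z_L, gamma, beta, head_w, head_b):
    y = layer_norm(z_L[0], gamma, beta)               # Eq. (4): token 0 is [class]
    logits = [sum(w * v for w, v in zip(row, y)) + b
              for row, b in zip(head_w, head_b)]      # linear head, D -> K
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return y, [e / total for e in exps]               # softmax over K classes


z_L = [[2.0, 0.0, -2.0, 4.0],     # [class] token after the last layer
       [9.0, 9.0, 9.0, 9.0],      # patch tokens: not read by the head
       [-5.0, 1.0, 3.0, 7.0]]
gamma, beta = [1.0] * 4, [0.0] * 4
head_w = [[1.0, 0.0, 0.0, 0.0],   # K = 2 classes
          [0.0, 0.0, 0.0, 1.0]]
head_b = [0.0, 0.0]

y, probs = classify(z_L, gamma, beta, head_w, head_b)
print([round(v, 3) for v in y], [round(p, 3) for p in probs])
```

The final patch tokens are discarded at this point; whatever they contributed must already have been pulled into the [class] token by the attention layers before the readout.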
Putting it Together – Algorithm Flow
1. Patchify & Embed: Apply Eq. (1) to build the token sequence with positional info.
2. Repeat for ℓ = 1…L:
   - 2.1 Self-Attention with skip → Eq. (2)
   - 2.2 MLP with skip → Eq. (3)
3. Readout: Take the layer-normalized class token (Eq. (4)).
4. Head: Pass $y$ through a linear or MLP layer to predict class logits.
Key Intuitions
- Global context early: Unlike CNNs, attention lets even low-level layers see the entire image, not just a small kernel.
- Minimal vision bias: The only “inductive bias” is square patching and adding position embeddings; everything else is generic Transformer machinery.
- Same math as text Transformers: If you understand BERT’s forward pass, swap “words” for “image patches” and you’ve got ViT’s core.
Prompt 9.1
CODE: Convert Equations (1) to (4), combined with the steps in Figure 1, into clean, runnable PyTorch pseudocode. Focus on clarity, variable shapes, and comments. Return ONLY the code block.
# Vision Transformer – PyTorch-style pseudocode
# Implements Equations (1)–(4) + the flow from Figure 1
# Shapes: B = batch, C = channels, H,W = image size, N = #patches, D = embed dim
import torch
import torch.nn as nn


class ViTEncoder(nn.Module):
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,   # D
        depth: int = 12,        # L
        num_heads: int = 12,
        mlp_ratio: int = 4,
    ):
        super().__init__()
        assert img_size % patch_size == 0, "image size must be divisible by patch size"
        self.N = (img_size // patch_size) ** 2   # patches per image

        # (Eq. 1) linear projection of flattened patches; a conv with
        # kernel = stride = P is equivalent to flatten + one shared linear layer
        self.patch_embed = nn.Conv2d(in_chans, embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size)   # → (B, D, H/P, W/P)

        # learnable class token + 1-D position embeddings
        self.cls_tok = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.N + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)   # small random init

        # Transformer encoder layers (Eqs. 2–3)
        self.layers = nn.ModuleList([
            nn.ModuleDict(dict(
                ln1=nn.LayerNorm(embed_dim),
                attn=nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),
                ln2=nn.LayerNorm(embed_dim),
                mlp=nn.Sequential(
                    nn.Linear(embed_dim, embed_dim * mlp_ratio),
                    nn.GELU(),
                    nn.Linear(embed_dim * mlp_ratio, embed_dim),
                ),
            ))
            for _ in range(depth)
        ])
        self.ln_final = nn.LayerNorm(embed_dim)   # (Eq. 4)

    def forward(self, x):
        B = x.size(0)                                     # x: (B, C, H, W)
        # ---- Equation (1) -------------------------------------------------
        x = self.patch_embed(x)                           # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)                  # (B, N, D)
        cls = self.cls_tok.expand(B, -1, -1)              # (B, 1, D)
        z = torch.cat((cls, x), dim=1) + self.pos_embed   # (B, N+1, D)
        # ---- Equations (2) & (3) repeated L times -------------------------
        for blk in self.layers:
            h = blk["ln1"](z)                             # normalize once, reuse as Q, K, V
            z = z + blk["attn"](h, h, h, need_weights=False)[0]
            z = z + blk["mlp"](blk["ln2"](z))
        # ---- Equation (4) -------------------------------------------------
        z = self.ln_final(z)                              # (B, N+1, D)
        return z[:, 0]                                    # CLS token ⇒ image rep, (B, D)
Prompt 9.2
Could you show a complete PyTorch data pipeline (dataset class, transforms, and dataloaders) plus a concise training loop for classifying the Oxford-IIIT Pet images, integrating ViTEncoder? Please keep the code runnable end-to-end and highlight any important hyper-parameters.
# ---------------------------------------------------
# 1. CONFIG
# ---------------------------------------------------
CFG = dict(
    img_size     = 224,    # resize-crop for ViT
    patch_size   = 16,     # must match encoder checkpoint
    batch_size   = 64,
    lr           = 3e-4,   # Adam default β1=0.9, β2=0.999
    weight_decay = 0.05,
    epochs       = 12,
    num_workers  = 4,      # >0 for speed if your OS allows
    model_name   = "vit_base_patch16_224",   # timm preset
)
# ---------------------------------------------------
# 2. DATA PIPELINE
# ---------------------------------------------------
import torch, timm
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
train_tfms = transforms.Compose([
    transforms.Resize(CFG["img_size"] + 32),
    transforms.RandomResizedCrop(CFG["img_size"]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
val_tfms = transforms.Compose([
    transforms.Resize(CFG["img_size"] + 32),
    transforms.CenterCrop(CFG["img_size"]),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
train_set = datasets.OxfordIIITPet(
    root="data", split="trainval", target_types="category",
    transform=train_tfms, download=True
)
val_set = datasets.OxfordIIITPet(
    root="data", split="test", target_types="category",
    transform=val_tfms, download=True
)
train_loader = DataLoader(
    train_set, batch_size=CFG["batch_size"], shuffle=True,
    num_workers=CFG["num_workers"], pin_memory=True
)
val_loader = DataLoader(
    val_set, batch_size=CFG["batch_size"], shuffle=False,
    num_workers=CFG["num_workers"], pin_memory=True
)
# ---------------------------------------------------
# 3. MODEL – ViT encoder + new classifier
# ---------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
vit = timm.create_model(
    CFG["model_name"],
    pretrained=True,
    num_classes=0          # drop original head
)
in_features = vit.num_features
model = nn.Sequential(
    vit,
    nn.Linear(in_features, 37)   # 37 pet breeds
).to(device)
# ---------------------------------------------------
# 4. OPTIMIZER & LR SCHEDULER
# ---------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(),
                        lr=CFG["lr"],
                        weight_decay=CFG["weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=CFG["epochs"]
)
# ---------------------------------------------------
# 5. TRAIN / VAL LOOP
# ---------------------------------------------------
for epoch in range(CFG["epochs"]):
    # --- training ---
    model.train()
    running_loss, correct, seen = 0, 0, 0
    for x, y in train_loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)   # sum of per-sample losses
        correct += logits.argmax(1).eq(y).sum().item()
        seen += x.size(0)
    train_acc = 100 * correct / seen
    train_loss = running_loss / seen
    # --- validation ---
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            logits = model(x)
            correct += logits.argmax(1).eq(y).sum().item()
            seen += x.size(0)
    val_acc = 100 * correct / seen
    scheduler.step()   # one cosine step per epoch
    print(f"Epoch {epoch+1:02}/{CFG['epochs']} "
          f"loss {train_loss:.3f} "
          f"train acc {train_acc:.1f}% "
          f"val acc {val_acc:.1f}%")
# ---------------------------------------------------
# 6. SAVE (optional)
# ---------------------------------------------------
torch.save(model.state_dict(), "vit_pet_classifier.pth")
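For readers wondering what the cosine scheduler in step 4 actually does to the learning rate, the sketch below evaluates the closed-form cosine-annealing formula for the same settings (base LR 3e-4, 12 epochs, minimum LR 0). It is a dependency-free illustration of the schedule, not a replacement for `CosineAnnealingLR`; the helper name `cosine_annealing` is made up here.

```python
import math


def cosine_annealing(base_lr, t_max, eta_min=0.0):
    """LR in effect during each epoch when the scheduler is stepped once per epoch."""
    return [eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * t / t_max))
            for t in range(t_max)]


lrs = cosine_annealing(base_lr=3e-4, t_max=12)
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch:02d}: lr = {lr:.2e}")
```

The rate starts at the configured value, passes half of it at the midpoint, and approaches zero in the last epoch, which is why the final epochs change the weights only slightly.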