Introduction
The frontier of AI research has seen remarkable intersections. At the junction of computer vision and natural language processing, a question arises: can an AI discern and understand language directly from its visual representation, i.e., from raw pixels? In this blog post, I attempt to find out how well AI can understand natural language directly from images.
The link to the repository is here. Please keep in mind that the project is a work in progress (WIP) and that the README and code still need some tweaking.
Problem Statement
Interpreting language from visual cues isn’t merely an OCR task; it’s about understanding context, semantics, and even masked information. In layman’s terms, I am trying to figure out how well AI models can understand text from raw image pixels, the same way we do. To that end, I used the self-supervised masked language modeling (MLM) paradigm.
Fig 1 - Examples of small, medium and large images used for training. Tokens were masked similarly to the BERT MLM task. Texts with fewer than 50 tokens are considered small, those with fewer than 100 medium, and the rest large. This bucketing made data generation easier and was also motivated by curriculum learning. The dataset contains many font sizes and font types.
To provide a bit of context: initially, I structured the task without the MLM part. The goal was simply to reconstruct the text from an image, as OCR does, but end-to-end. This worked fine, but the model failed to understand the text.
To alleviate that issue, I masked some words in the image while leaving the target text unmasked; several examples can be seen in Fig 1. This led to a significant boost in text understanding, but there is still a lot of room for improvement and more ways to formulate the task. One of them could be to task the model with reconstructing removed pixels from an image.
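Fig 1 and the paragraph above describe two data-generation steps: masking words in the text that gets rendered onto the image while keeping the target text intact, and bucketing texts by length. The sketch below illustrates both in plain Python. The function names, the word-level masking, and the 15% masking rate (BERT's default) are assumptions for illustration, not the repository's code; rendering the text onto an image is out of scope here.

```python
import random

MASK = "[MASK]"


# Hypothetical helper, not taken from the project repository.
def mask_words(words, mask_prob=0.15, rng=None):
    """Return (rendered_words, target_words).

    rendered_words is what gets drawn on the image, with some words replaced by
    [MASK]; target_words keeps the original, unmasked text as the training target.
    """
    rng = rng or random.Random()
    rendered = [MASK if rng.random() < mask_prob else word for word in words]
    return rendered, list(words)


# Hypothetical helper using the thresholds stated in Fig 1.
def size_bucket(num_tokens):
    """Length buckets: fewer than 50 tokens is small, fewer than 100 medium, else large."""
    if num_tokens < 50:
        return "small"
    if num_tokens < 100:
        return "medium"
    return "large"


rendered, target = mask_words("the cat sat on the mat".split(), rng=random.Random(0))
print(" ".join(rendered))        # some words may be replaced by [MASK]
print(" ".join(target))          # the cat sat on the mat
print(size_bucket(len(target)))  # small
```

Keeping the target unmasked is what turns plain text reconstruction into an MLM-style task: the model has to infer the hidden words from the visible context.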
At a high level, the architecture is quite simple:
Fig 2 — An overview of the architecture used for training. For downstream tasks, such as classification, the transformer decoder and linear layers are removed.
The CNN is used to capture textual features and is important for convergence. The encoder should learn to understand the context, which the decoder then uses to output either the reconstructed token or the predicted masked token. The linear layer maps the decoder’s output to a token from our vocabulary.
It may be possible to remove the CNN layer (using patches similar to ViT instead), but so far I haven’t managed to get good results without it.
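Removing the CNN would mean feeding the transformer ViT-style patch embeddings instead. A minimal sketch of such a patch embedding is shown below; the class name and the 16x16 patch size (ViT-Base's default) are assumptions, and this is not a configuration reported as working in this project.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical ViT-style patchifier: non-overlapping patches, each projected to model_dim."""

    def __init__(self, in_channels=3, patch_size=16, model_dim=768):
        super().__init__()
        # A Conv2d with kernel_size == stride == patch_size is equivalent to
        # flattening every patch and applying one shared linear layer.
        self.proj = nn.Conv2d(in_channels, model_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # (B, model_dim, H / p, W / p)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, model_dim)


emb = PatchEmbedding()
out = emb(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 196, 768])
```

Unlike a pretrained ResNet, this projection starts from scratch and has no built-in local feature hierarchy, which may be one reason the CNN helped convergence.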
Architecture
This section is slightly more technical, so feel free to skip it if you are only interested in the results!
The project’s backbone is a hybrid model combining Convolutional Neural Networks (CNNs) and transformers. Let’s dive a bit deeper into the architecture!
Feature Extraction with CNNs - This layer was important for convergence, but merely increasing its complexity didn't improve the results. This suggests potential saturation or the need for more intricate architectural adjustments.
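- ResNetFeatureExtractor - Our model employs a ResNet-50-based CNN, slightly adjusted for the task at hand: its average-pooling and fully connected layers are removed, and a 1x1 convolution projects the 2048-channel feature maps to the model dimension. The module converts the raw image into a flattened sequence of feature vectors, one per spatial location of the final feature map, each capturing patterns in the corresponding image region and ready for further processing.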
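```python
import torch.nn as nn
from torchvision.models import resnet50


class ResNetFeatureExtractor(nn.Module):
    def __init__(self, feature_map_size, out_features_size):
        super(ResNetFeatureExtractor, self).__init__()
        # Pretrained ResNet-50 without its avgpool and fc layers.
        # (`pretrained=True` is deprecated since torchvision 0.13; the equivalent
        # is `weights=ResNet50_Weights.IMAGENET1K_V1`.)
        self.feature_extractor = nn.Sequential(
            *list(resnet50(pretrained=True).children())[:-2])
        # 1x1 convolution projecting feature_map_size channels to out_features_size.
        self.input_proj = nn.Conv2d(feature_map_size,
                                    out_features_size,
                                    kernel_size=1)

    def forward(self, x):
        x = self.feature_extractor(x)  # (B, feature_map_size, H', W'), H' = ceil(H / 32)
        x = self.input_proj(x)         # (B, out_features_size, H', W')
        # Flatten the spatial grid into a sequence: (B, H' * W', out_features_size).
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        return x
```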
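Because ResNet-50 downsamples by a total stride of 32, the length of the sequence produced by ResNetFeatureExtractor depends on the input resolution. The hypothetical helper below (not part of the repository) computes it; every stride-2 layer in torchvision's ResNet-50 rounds up, so the grid is ceil(H / 32) x ceil(W / 32). This matters because SinePositionalEncoding only covers max_len=1000 positions, so a larger grid would fail with a shape mismatch; the image sizes actually used in the project are not stated here.

```python
import math


def resnet50_grid(height, width):
    """Spatial size of ResNet-50's last feature map (total stride 32, rounding up)."""
    return math.ceil(height / 32), math.ceil(width / 32)


def feature_sequence_length(height, width):
    """Number of feature vectors after flattening the grid into a sequence."""
    grid_h, grid_w = resnet50_grid(height, width)
    return grid_h * grid_w


print(feature_sequence_length(512, 512))    # 256
print(feature_sequence_length(1024, 1024))  # 1024, more than max_len=1000
```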
SinePositionalEncoding: Following the ResNet module, the SinePositionalEncoding layer injects positional information into our feature maps. This is crucial: unlike token sequences in pure NLP tasks, the flattened image features carry no inherent order, and the positional encoding gives the model spatial information. Below is the code used for the positional encoding. Because it is one-dimensional, it only encodes each feature's index in the flattened sequence; a potential improvement could be to use a 2D positional encoding.
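```python
import math

import torch
import torch.nn as nn


class SinePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super(SinePositionalEncoding, self).__init__()
        # Precompute the standard sinusoidal table of shape (max_len, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even channels
        pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
        pe = pe.unsqueeze(0)                          # (1, max_len, d_model)
        # A buffer moves with the module (.to(device)) but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, seq_len, d_model) with seq_len <= max_len.
        return x + self.pe[:, :x.size(1)].requires_grad_(False)
```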
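The 2D positional encoding suggested above could look like the sketch below: half of the channels encode the row index and the other half the column index, each using the same sinusoidal formula as SinePositionalEncoding. This is a common variant (similar in spirit to DETR's), but it is an assumption here; the class name, max_h/max_w, and passing the grid size to forward are illustrative choices, not the project's code.

```python
import math

import torch
import torch.nn as nn


def sine_table(length, dim):
    """Standard 1D sinusoidal table of shape (length, dim); dim must be even."""
    pe = torch.zeros(length, dim)
    position = torch.arange(length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


class Sine2DPositionalEncoding(nn.Module):
    """Hypothetical 2D encoding: first half of channels = row, second half = column."""

    def __init__(self, d_model, max_h=64, max_w=64):
        super().__init__()
        assert d_model % 4 == 0, "each half must itself be even for sin/cos pairs"
        half = d_model // 2
        self.register_buffer("row_pe", sine_table(max_h, half))
        self.register_buffer("col_pe", sine_table(max_w, half))

    def forward(self, x, height, width):
        # x: (B, height * width, d_model), flattened row-major as in
        # ResNetFeatureExtractor, so sequence index = row * width + col.
        rows = self.row_pe[:height].unsqueeze(1).expand(height, width, -1)
        cols = self.col_pe[:width].unsqueeze(0).expand(height, width, -1)
        pe = torch.cat([rows, cols], dim=-1).reshape(height * width, -1)
        return x + pe.unsqueeze(0)


pos = Sine2DPositionalEncoding(d_model=768)
x = torch.zeros(2, 16 * 16, 768)
print(pos(x, 16, 16).shape)  # torch.Size([2, 256, 768])
```

Unlike the 1D version, two features in the same image column but adjacent rows get similar column encodings, even though they are 16 positions apart in the flattened sequence.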
The Transformer and VLP model - The choice of a transformer was motivated by its generality and scalability. The features from the CNN are forwarded to a transformer encoder, which passes its output to a transformer decoder. In this encoder-decoder setup, the decoder generates the text token by token, including predictions for the masked tokens where the image contains masks.
- Transformer configuration: The model_config_factory function keeps the architecture highly configurable. It allows easy specification of model parameters like model_dim, num_heads, and more, so I can quickly adjust and scale the model based on the data and the insights from previous experiments. I follow the traditional transformer implementation; the full code can be found here. Below is the config used for the best-performing model:
```python
...,
"encoder_decoder_lg": {
    "model_dim": 768,
    "ff_dim": 4096,
    "num_heads": 16,
    "num_layers": 12,
    "feature_map_size": 2048,  # from the resnet model
    "dec_div": 2
}, ...
```
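The repository's model_config_factory is not shown in this post. As a hypothetical sketch, a named-config lookup with overrides and a basic sanity check could look like this; get_model_config and MODEL_CONFIGS are illustrative names, and only the values come from the config above.

```python
import copy

# Hypothetical registry; only the entry shown above is reproduced.
MODEL_CONFIGS = {
    "encoder_decoder_lg": {
        "model_dim": 768,
        "ff_dim": 4096,
        "num_heads": 16,
        "num_layers": 12,
        "feature_map_size": 2048,  # channels of ResNet-50's last stage
        "dec_div": 2,
    },
}


def get_model_config(name, **overrides):
    """Return a copy of a named config with optional overrides applied."""
    if name not in MODEL_CONFIGS:
        raise KeyError(f"unknown model config: {name!r}")
    config = copy.deepcopy(MODEL_CONFIGS[name])  # never mutate the registry
    config.update(overrides)
    # Multi-head attention splits model_dim evenly across the heads.
    if config["model_dim"] % config["num_heads"] != 0:
        raise ValueError("model_dim must be divisible by num_heads")
    return config


cfg = get_model_config("encoder_decoder_lg", num_layers=6)
print(cfg["num_layers"], cfg["model_dim"] // cfg["num_heads"])  # 6 48
```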
VLP Class: This is the centerpiece. Once the image features are extracted by the feature_extractor, they are passed into our transformer for processing. The structure of the transformer depends on transformer_type: an encoder, a decoder, or a combination of both.
```python
import torch.nn as nn

# ResNetFeatureExtractor and SinePositionalEncoding are defined above;
# get_transformer comes from the project repository and is not shown here.


class VLP(nn.Module):
    def __init__(self, model_dim, num_layers, ff_dim, num_heads,
                 feature_map_size, vocab_size, dropout, transformer_type,
                 dec_div):
        super(VLP, self).__init__()
        # Image -> sequence of model_dim features with positional information.
        self.feature_extractor = nn.Sequential(
            ResNetFeatureExtractor(feature_map_size=feature_map_size,
                                   out_features_size=model_dim),
            SinePositionalEncoding(model_dim))
        self.transformer = get_transformer(num_layers=num_layers,
                                           model_dim=model_dim,
                                           ff_dim=ff_dim,
                                           num_heads=num_heads,
                                           vocab_size=vocab_size,
                                           dropout=dropout,
                                           transformer_type=transformer_type,
                                           dec_div=dec_div)

    def get_image_features(self, images):
        return self.feature_extractor(images)

    def forward(self, images, tgt=None, tgt_mask=None):
        image_features = self.get_image_features(images)
        # tgt / tgt_mask are the decoder inputs (None for encoder-only use).
        return self.transformer(image_features, tgt, tgt_mask=tgt_mask)
```
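get_transformer is not shown in this post, so VLP cannot run from the snippets alone. For experimenting, a hypothetical stand-in covering only the encoder-only case (the one VLPForTextMLM uses) can be built from PyTorch's nn.TransformerEncoder. This is not the author's implementation, which follows the traditional transformer and may differ, for example in how dec_div and the decoder are handled.

```python
import torch
import torch.nn as nn


class EncoderOnly(nn.Module):
    """Wraps nn.TransformerEncoder so it accepts the (src, tgt, tgt_mask) call used by VLP."""

    def __init__(self, num_layers, model_dim, ff_dim, num_heads, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(model_dim, num_heads,
                                           dim_feedforward=ff_dim,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, src, tgt=None, tgt_mask=None):
        # Encoder-only: tgt and tgt_mask are accepted for interface compatibility and ignored.
        return self.encoder(src)  # (B, seq_len, model_dim)


# Hypothetical stand-in for the repository's get_transformer (encoder-only case only).
def get_transformer(num_layers, model_dim, ff_dim, num_heads, vocab_size,
                    dropout, transformer_type, dec_div):
    if transformer_type == "encoder":
        return EncoderOnly(num_layers, model_dim, ff_dim, num_heads, dropout)
    raise NotImplementedError(
        f"transformer_type={transformer_type!r} is not covered by this sketch")


enc = get_transformer(num_layers=2, model_dim=64, ff_dim=128, num_heads=4,
                      vocab_size=None, dropout=0.0,
                      transformer_type="encoder", dec_div=2)
print(enc(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```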
- Handling Text with VLPForTextMLM: Built on top of the VLP model, the VLPForTextMLM class is tailored to our masked language modeling (MLM) task. It adds a linear layer that maps the transformer's outputs to the desired number of classes (vocabulary tokens).
```python
import torch.nn as nn


class VLPForTextMLM(nn.Module):
    def __init__(self,
                 model_dim,
                 num_layers,
                 num_heads,
                 ff_dim,
                 feature_map_size,
                 num_classes,
                 dec_div=2,
                 dropout=0.0):
        super(VLPForTextMLM, self).__init__()
        # Encoder-only VLP backbone; the token projection is done by self.out below.
        self.vlp = VLP(num_layers=num_layers,
                       model_dim=model_dim,
                       ff_dim=ff_dim,
                       num_heads=num_heads,
                       feature_map_size=feature_map_size,
                       vocab_size=None,
                       dropout=dropout,
                       transformer_type="encoder",
                       dec_div=dec_div)
        # Maps each model_dim feature to logits over num_classes tokens.
        self.out = nn.Linear(model_dim, num_classes)

    def forward(self, images):
        out = self.vlp(images)  # (B, seq_len, model_dim)
        return self.out(out)    # (B, seq_len, num_classes)
```
Asymmetry by Design - The encoder and decoder weren't symmetrical in size. As an intentional design choice, the encoder was given significantly more parameters than the decoder, with the aim of storing as much linguistic knowledge as possible in the encoder. This choice had a positive effect on the MNLI results.
Inference is performed autoregressively, using greedy decoding.
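Greedy autoregressive decoding can be sketched as follows, assuming the model is called as model(images, tgt, tgt_mask=...) and returns logits of shape (batch, tgt_len, vocab_size). The BOS/EOS token ids, this logits interface, and the default max_len=144 (matching the training length filter) are assumptions, not details confirmed by the post.

```python
import torch


def causal_mask(size):
    """Additive mask: -inf above the diagonal, so position i attends only to positions <= i."""
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


@torch.no_grad()
def greedy_decode(model, images, bos_id, eos_id, max_len=144):
    """Greedy decoding: at every step, append the single most likely next token."""
    batch = images.size(0)
    tgt = torch.full((batch, 1), bos_id, dtype=torch.long, device=images.device)
    finished = torch.zeros(batch, dtype=torch.bool, device=images.device)
    for _ in range(max_len - 1):
        mask = causal_mask(tgt.size(1)).to(images.device)
        logits = model(images, tgt, tgt_mask=mask)  # (B, T, vocab_size)
        next_token = logits[:, -1].argmax(dim=-1)   # best token at the last position
        # Sequences that already emitted EOS keep emitting EOS as padding.
        next_token = torch.where(finished, torch.full_like(next_token, eos_id), next_token)
        tgt = torch.cat([tgt, next_token.unsqueeze(1)], dim=1)
        finished |= next_token == eos_id
        if finished.all():
            break
    return tgt  # (B, generated_len), starting with BOS
```

Greedy decoding is the simplest choice; beam search could recover better sequences when the most likely token at one step leads to a poor continuation, at the cost of extra compute.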
Data and Training
Our training data for the MLM task consisted of subsets (around 30–50%, due to hardware and time constraints) of the Wikipedia and BookCorpus datasets, filtered to texts shorter than 144 tokens. Most of my time went into creating the dataset, and there is still a lot of room for improvement on this front. The model was trained for around 1.5M iterations with a batch size of 16. The encoder_decoder_lg variant of the model has 129M parameters.
This token limit was a balance between computational efficiency and sufficient contextual information. It is also more difficult to render images with long texts.
The training loop is quite simple:
- Forward pass the image through VLPForTextMLM.
- Calculate the loss between the predicted tokens and the actual tokens.
- Backpropagate the errors to adjust the model weights.
- Repeat this process until convergence or until the model's performance plateaus.
I used the cross-entropy loss, the AdamW optimizer, a OneCycle learning rate policy with cosine annealing, and mixed-precision training.
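Putting these ingredients together, a training loop could look like the sketch below. It assumes the model returns per-token logits of shape (B, T, num_classes) and that targets have shape (B, T), with -100 marking positions to ignore; the hyperparameters are placeholders rather than the values used in the project. Mixed precision is only enabled on CUDA, so the same code also runs on a CPU.

```python
import torch
import torch.nn as nn


def train(model, loader, epochs=1, max_lr=1e-4, weight_decay=0.01, device="cpu"):
    """Cross-entropy + AdamW + cosine OneCycle schedule + mixed precision (CUDA only)."""
    model.to(device).train()
    criterion = nn.CrossEntropyLoss(ignore_index=-100)  # -100 marks positions to skip
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, total_steps=epochs * len(loader), anneal_strategy="cos")
    use_amp = device.startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op when disabled
    for _ in range(epochs):
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
                logits = model(images)  # (B, T, num_classes)
                loss = criterion(logits.flatten(0, 1), targets.flatten())
            scaler.scale(loss).backward()  # scaled backward pass avoids fp16 underflow
            scaler.step(optimizer)         # unscales gradients, then optimizer.step()
            scaler.update()
            scheduler.step()               # OneCycle is stepped once per batch
    return loss.item()  # loss of the last batch
```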
Results and Observations
To test the quality of the learned representations, I trained the VLP model on the MNLI textual entailment downstream task. Given a premise sentence and a hypothesis sentence, the task is to predict whether the premise entails the hypothesis (entailment), contradicts it (contradiction), or neither (neutral). This is a classification task, so the decoder is removed.
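For MNLI, the decoder is dropped and a classification head sits on top of the encoder outputs. The post does not say how the encoder outputs are pooled, so the sketch below assumes mean pooling over image positions; the class name is hypothetical, and the mapping of label ids to entailment/neutral/contradiction depends on the dataset version used.

```python
import torch
import torch.nn as nn


class VLPForSequenceClassification(nn.Module):
    """Hypothetical head: mean-pool the encoder outputs, then apply a linear classifier."""

    def __init__(self, encoder, model_dim, num_labels=3):
        super().__init__()
        self.encoder = encoder  # e.g. feature extractor + transformer encoder
        self.classifier = nn.Linear(model_dim, num_labels)

    def forward(self, images):
        hidden = self.encoder(images)     # (B, seq_len, model_dim)
        pooled = hidden.mean(dim=1)       # average over image positions
        return self.classifier(pooled)    # (B, num_labels) logits for the MNLI classes
```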
Fig 3 - On the left is an image fed to the model, and on the right is the attribution map obtained when inferring on the MNLI dataset. The map was extracted using the Integrated Gradients method.
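Integrated Gradients attributes a prediction to input pixels by averaging the gradients along a straight path from a baseline image to the actual image and multiplying by the input difference. A minimal from-scratch sketch is below; the all-black baseline and the 32 steps are assumptions, and libraries such as Captum provide a ready-made IntegratedGradients implementation. One way to get a 2D heatmap is to sum absolute attributions over channels, e.g. attributions.abs().sum(dim=0).

```python
import torch


def integrated_gradients(model, image, target, baseline=None, steps=32):
    """Right-endpoint Riemann approximation of Integrated Gradients for one image.

    image: (C, H, W); target: class index. Returns attributions shaped like image.
    """
    if baseline is None:
        baseline = torch.zeros_like(image)  # black image as the reference input
    alphas = torch.linspace(1.0 / steps, 1.0, steps).view(-1, 1, 1, 1)
    # All interpolated images between baseline and input: (steps, C, H, W).
    path = (baseline.unsqueeze(0) + alphas * (image - baseline).unsqueeze(0)).detach()
    path.requires_grad_(True)
    scores = model(path)[:, target]  # target logit at every interpolation step
    grads, = torch.autograd.grad(scores.sum(), path)
    avg_grads = grads.mean(dim=0)    # average gradient along the path
    return (image - baseline) * avg_grads
```

A useful sanity check is completeness: the attributions should sum approximately to the difference between the target logit at the input and at the baseline.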
The model achieved an F1 score of 0.73 on the MNLI dataset, which is promising, though not without challenges: its heavy reliance on the CNN layer for convergence, and the lack of improvement after enlarging the CNN, indicate potential bottlenecks.
Fig 4 - Additional examples on validation images from the MNLI dataset. Both examples are correctly classified.
To analyze the quality of the learned embeddings, I performed linear probing on the IMDb sentiment classification task (positive / negative) in two settings:
- Not limiting the number of tokens: in this setting, the VLP model achieved an F1 score of 0.7, compared to BERT’s 0.8. This makes sense, since the VLP model wasn’t trained on texts with more than 144 tokens.
- Limiting the number of tokens to 144: in this setting, VLP fared much better, with an F1 score of 0.78 compared to BERT’s 0.82.
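Linear probing keeps the pretrained encoder frozen and trains only a linear classifier on its features. The sketch below assumes the encoder returns (B, S, D) features that are mean-pooled into one vector per image; the pooling choice, optimizer, and hyperparameters are assumptions, not the setup behind the numbers above.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(encoder, images, batch_size=32):
    """Run the frozen encoder and mean-pool each image's features into one vector."""
    encoder.eval()
    feats = [encoder(images[i:i + batch_size]).mean(dim=1)
             for i in range(0, len(images), batch_size)]
    return torch.cat(feats)  # (N, D)


def linear_probe(train_feats, train_labels, num_classes=2, epochs=100, lr=1e-2):
    """Train only a linear layer on fixed features (full-batch logistic regression)."""
    probe = nn.Linear(train_feats.size(1), num_classes)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(probe(train_feats), train_labels)
        loss.backward()
        optimizer.step()
    return probe
```

Because the encoder never receives gradients, the probe's score reflects how linearly separable the sentiment classes already are in the learned representation.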
Overall, linear probing showed promising results, but there is still a lot of room for potential refinements in the representation learning approach.
Also, it would be interesting to explore how well the VLP model would scale with larger context sizes.
Fig 5 — The model can recognize text decently even from random images from the internet. The output for this image is Pray Pathole and’’ Tech enthusiasts! My entire house is smart. Tech workers : The only piece of technology in my house is a printer and I keep a gun next to it so I can shoot it if it makes a noise I don’t recognize. 851 AM. April 12, 2019, Twitter for iphone 7, 387 feltees — 656 Oak Tweets 238K Likes _ Q — 11.
Current Status and Future Directions
The project, in its current state, works well for text recognition. It handles linguistic constructs such as verb usage, punctuation, and co-references well. However, it struggles with tasks requiring memorization: for example, "The capital of France is [MASK]" does not yield "Paris". The project is ripe for further exploration, with plans to train on datasets like SQuAD and OntoNotes.
Some of the most obvious ways to address the performance plateau could be:
- expanding and improving the training datasets,
- revisiting the architecture and introducing alternative techniques.
I am very eager to hear from fellow researchers and enthusiasts; your insights and contributions are very welcome! Please feel free to contact me if you have any ideas or questions! And thank you for taking the time to read this blog post; I hope you had fun! 🙏😁
Special thanks to Igor Tica, Nikola Tomic and Filip Baturan for their many notes and a very detailed review of the blog post! 🙏
Email: basarafilip@gmail.com
LinkedIn: https://www.linkedin.com/in/filip-basara-84a694195/
Twitter: https://twitter.com/basarafilip
Github: https://github.com/filipbasara0