<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Mansi Somayajula</title>
    <description>The latest articles on DEV Community by Mansi Somayajula (@mansisomayajula03).</description>
    <link>https://dev.to/mansisomayajula03</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3872789%2Ff0a2e0f6-96ad-42c4-b17a-b8ecd160ce9b.png</url>
      <title>DEV Community: Mansi Somayajula</title>
      <link>https://dev.to/mansisomayajula03</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/mansisomayajula03"/>
    <language>en</language>
    <item>
      <title>Nobody Tells You This About Slow Transformer Models — I Fixed Mine in 3 Steps</title>
      <dc:creator>Mansi Somayajula</dc:creator>
      <pubDate>Sat, 11 Apr 2026 06:23:34 +0000</pubDate>
      <link>https://dev.to/mansisomayajula03/nobody-tells-you-this-about-slow-transformer-models-i-fixed-mine-in-3-steps-518c</link>
      <guid>https://dev.to/mansisomayajula03/nobody-tells-you-this-about-slow-transformer-models-i-fixed-mine-in-3-steps-518c</guid>
      <description>&lt;p&gt;Hot take: most "&lt;em&gt;my model is slow&lt;/em&gt;" problems are not model problems.&lt;br&gt;
They're inference problems. And the ML community almost never talks about that gap.&lt;br&gt;
Everyone's obsessed with architecture choices, parameter counts, quantization-aware training, distillation strategies... while the actual bottleneck is sitting right there in the inference code, completely ignored.&lt;br&gt;
I know because I was doing it wrong for longer than I'd like to admit.&lt;br&gt;
I had a DistilBERT classifier running at ~750ms per request in production. My first instinct was "I need a better machine." Turns out, I needed to stop processing one input at a time like it was 2015.&lt;br&gt;
Here's exactly what I did — three steps, same CPU, same model — and I got it down to 280ms.&lt;/p&gt;

&lt;p&gt;What I was building&lt;br&gt;
A support ticket classifier I'm calling SupportBot. Fine-tuned DistilBERT, three classes: billing, technical, general. Great accuracy. Terrible latency.&lt;br&gt;
Here's the embarrassing baseline:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# baseline.py
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import time

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()

def predict(text: str):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.argmax(dim=-1).item()

texts = ["My payment didn't go through"] * 10
start = time.time()
for text in texts:
    predict(text)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;750ms. I wish I was joking. Let's fix this.&lt;/p&gt;

&lt;p&gt;Step 1: I was running a dishwasher for a single fork&lt;br&gt;
Processing one text at a time means one full forward pass per request. All the overhead — loading weights into cache, spinning up computation — for one input. Over and over.&lt;br&gt;
The fix is almost insulting in how simple it is: just send multiple texts through at once.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# step1_batching.py
def predict_batch(texts: list, batch_size: int = 16) -&amp;gt; list:
    all_predictions = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]

        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True  # pads all texts to the same length within the batch
        )

        with torch.no_grad():
            outputs = model(**inputs)

        all_predictions.extend(outputs.logits.argmax(dim=-1).tolist())

    return all_predictions

texts = ["My payment didn't go through"] * 10
start = time.time()
predict_batch(texts)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;480ms. Same hardware, one change.&lt;/p&gt;

&lt;p&gt;750ms → 480ms. 36% faster. I genuinely stared at the screen for a moment.&lt;/p&gt;

&lt;p&gt;💡 What batch size? I'd start with 16. Drop to 8 if you're hitting memory limits, try 32 if you have room. Dynamic batching (queuing requests for ~30–50ms before processing) is the next level if you're building a real API.&lt;/p&gt;
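&lt;p&gt;The dynamic batching mentioned above can be sketched with a plain queue and a worker thread: hold each incoming request until the batch fills or ~40ms passes, then run one forward pass for everyone. A minimal sketch, assuming a &lt;code&gt;predict_batch&lt;/code&gt;-style function like the one above; &lt;code&gt;MicroBatcher&lt;/code&gt; and its parameters are illustrative names, not a real library:&lt;/p&gt;

```python
# micro_batcher.py: illustrative sketch of dynamic batching (not production code)
import queue
import threading
import time

class MicroBatcher:
    """Collects requests until `max_batch` items arrive or `max_wait_s`
    elapses, then runs them through `predict_fn` in a single call."""

    def __init__(self, predict_fn, max_batch=16, max_wait_s=0.04):
        self.predict_fn = predict_fn
        self.max_batch = max_batch
        self.max_wait_s = max_wait_s
        self._q = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def submit(self, text):
        """Blocks until the batched prediction for `text` is ready."""
        done = threading.Event()
        slot = {"text": text, "done": done, "result": None}
        self._q.put(slot)
        done.wait()
        return slot["result"]

    def _worker(self):
        while True:
            batch = [self._q.get()]  # block until the first request arrives
            deadline = time.monotonic() + self.max_wait_s
            while len(batch) < self.max_batch:
                timeout = deadline - time.monotonic()
                if timeout <= 0:
                    break
                try:
                    batch.append(self._q.get(timeout=timeout))
                except queue.Empty:
                    break
            preds = self.predict_fn([s["text"] for s in batch])
            for slot, pred in zip(batch, preds):
                slot["result"] = pred
                slot["done"].set()

if __name__ == "__main__":
    batcher = MicroBatcher(lambda texts: [len(t) for t in texts])
    print(batcher.submit("hello"))  # prints 5
```

&lt;p&gt;Each API handler just calls &lt;code&gt;submit&lt;/code&gt;; concurrent requests that land within the window share one forward pass.&lt;/p&gt;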

&lt;p&gt;Step 2: PyTorch was carrying bags it didn't need to&lt;br&gt;
Here's something that surprised me: PyTorch is not built for inference. It's built for training.&lt;br&gt;
Every inference call drags along autograd, gradient tracking, training hooks... overhead I was paying for every single request but never using. ONNX Runtime strips all of that out and applies graph-level optimizations automatically. For CPU inference, the difference is real.&lt;br&gt;
First, I exported the model:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# export_onnx.py
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()

dummy_input = tokenizer(
    "sample text for export",
    return_tensors="pt",
    max_length=128,
    padding="max_length",
    truncation=True
)

torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "supportbot.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        # Critical — without this, ONNX bakes in fixed shapes from the dummy input
        # and rejects any request that doesn't match exactly
        "input_ids":      {0: "batch_size", 1: "seq_length"},
        "attention_mask": {0: "batch_size", 1: "seq_length"},
        "logits":         {0: "batch_size"}
    },
    opset_version=13
)
print("Exported.")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;⚠️ Don't skip dynamic_axes. I learned this the hard way. Without it you'll get cryptic shape errors in production and spend an hour wondering why requests work in your test script but fail in your API.&lt;/p&gt;

&lt;p&gt;Then I switched to running inference with ONNX Runtime:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# step2_onnx.py
import onnxruntime as ort
import numpy as np
from transformers import DistilBertTokenizer
import time

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"

class SupportBotClassifier:
    def __init__(self, model_path: str, tokenizer_path: str):
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4  # match your CPU core count

        self.session = ort.InferenceSession(
            model_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"]
        )
        self.tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_path)

    def predict(self, texts: list) -&amp;gt; list:
        inputs = self.tokenizer(
            texts,
            return_tensors="np",  # numpy directly — no PyTorch tensors needed anymore
            max_length=128,
            padding=True,
            truncation=True
        )

        logits = self.session.run(
            ["logits"],
            {
                "input_ids":      inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64)
            }
        )[0]

        return np.argmax(logits, axis=-1).tolist()

classifier = SupportBotClassifier("supportbot.onnx", MODEL_PATH)
texts = ["My payment didn't go through"] * 10
start = time.time()
classifier.predict(texts)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;480ms → 350ms. Another 27% faster, zero model changes.&lt;/p&gt;

&lt;p&gt;Step 3: One function call. I'm not exaggerating.&lt;br&gt;
FP32 weights for a ticket classifier are overkill. Dynamic INT8 quantization compresses those weights to 8-bit integers — smaller memory, faster CPU math, almost no accuracy loss.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# step3_quantize.py
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="supportbot.onnx",
    model_output="supportbot_quantized.onnx",
    weight_type=QuantType.QInt8
)
print("Done.")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;That's it. That's the whole step.&lt;br&gt;
Always verify predictions didn't shift, though:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;original  = SupportBotClassifier("supportbot.onnx", MODEL_PATH)
quantized = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)

tests = [
    "My payment failed twice",
    "App keeps crashing on iOS",
    "Question about my plan",
    "Can't log in",
    "Where's my invoice?",
]

matches = sum(
    o == q for o, q in
    zip(original.predict(tests), quantized.predict(tests))
)
print(f"Match rate: {matches/len(tests):.0%}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Match rate: 100%.&lt;/p&gt;
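&lt;p&gt;Why does INT8 barely move the predictions? Dynamic quantization stores each weight tensor as int8 values plus a floating-point scale, so the rounding error per weight is at most half a scale step. Here's a toy round trip in plain Python; this shows the idea only, and ONNX Runtime's actual scheme has more detail (per-channel scales, zero points):&lt;/p&gt;

```python
# int8_roundtrip.py: toy illustration of symmetric INT8 weight quantization.
# The real ONNX Runtime implementation differs in detail; this shows the idea.

def quantize_int8(weights):
    """Map FP32 values onto [-127, 127] integers with a per-tensor scale."""
    scale = max(abs(w) for w in weights) / 127.0
    q = [round(w / scale) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate FP32 values: one multiply per weight."""
    return [qi * scale for qi in q]

weights = [0.81, -0.32, 0.05, -1.27, 0.44]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

# Rounding error is bounded by half a scale step per weight
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(q)
print(f"max reconstruction error: {max_err:.6f}")
```

&lt;p&gt;Each weight drops from 4 bytes to 1, and for a classifier the argmax over logits usually survives that tiny perturbation, which is why the match rate above came out at 100%.&lt;/p&gt;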

&lt;p&gt;Final benchmark:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;classifier_v2 = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)
start = time.time()
classifier_v2.predict(["My payment didn't go through"] * 10)
elapsed = (time.time() - start) * 1000
print(f"Avg: {elapsed/10:.0f}ms per request")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;350ms → 280ms 🎉&lt;/p&gt;

&lt;p&gt;The full picture&lt;/p&gt;

&lt;table&gt;
  &lt;thead&gt;
    &lt;tr&gt;&lt;th&gt;Change&lt;/th&gt;&lt;th&gt;Latency&lt;/th&gt;&lt;th&gt;vs. start&lt;/th&gt;&lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
    &lt;tr&gt;&lt;td&gt;Baseline (PyTorch, single)&lt;/td&gt;&lt;td&gt;~750ms&lt;/td&gt;&lt;td&gt;—&lt;/td&gt;&lt;/tr&gt;
    &lt;tr&gt;&lt;td&gt;+ Batch processing&lt;/td&gt;&lt;td&gt;~480ms&lt;/td&gt;&lt;td&gt;36% faster&lt;/td&gt;&lt;/tr&gt;
    &lt;tr&gt;&lt;td&gt;+ ONNX Runtime&lt;/td&gt;&lt;td&gt;~350ms&lt;/td&gt;&lt;td&gt;53% faster&lt;/td&gt;&lt;/tr&gt;
    &lt;tr&gt;&lt;td&gt;+ INT8 quantization&lt;/td&gt;&lt;td&gt;~280ms&lt;/td&gt;&lt;td&gt;63% faster&lt;/td&gt;&lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;

&lt;p&gt;Same model. Same CPU. Same hardware.&lt;/p&gt;
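&lt;p&gt;One caveat on numbers like these: a mean hides tail latency. When I benchmark now I report p50/p95 rather than a bare average. A small helper using only the standard library (&lt;code&gt;benchmark&lt;/code&gt; is an illustrative name, and the lambda is a stand-in workload):&lt;/p&gt;

```python
# latency_percentiles.py: report p50/p95 latency instead of a bare average
import statistics
import time

def benchmark(fn, runs=50):
    """Time fn() `runs` times; return (p50_ms, p95_ms)."""
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    cuts = statistics.quantiles(samples, n=100)  # 99 percentile cut points
    return cuts[49], cuts[94]  # 50th and 95th percentiles

# Stand-in workload; swap in classifier.predict(batch) in practice
p50, p95 = benchmark(lambda: sum(range(10_000)))
print(f"p50={p50:.3f}ms  p95={p95:.3f}ms")
```

&lt;p&gt;If p95 is several times p50, batching jitter or padding on long inputs is usually the culprit.&lt;/p&gt;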

&lt;p&gt;The FastAPI wrapper I'm using:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# api.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import time

app = FastAPI(title="SupportBot API")
classifier = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)
LABELS = ["billing", "technical", "general"]

class ClassifyRequest(BaseModel):
    texts: List[str]

class ClassifyResponse(BaseModel):
    predictions: List[str]
    latency_ms: float

@app.post("/classify", response_model=ClassifyResponse)
async def classify(req: ClassifyRequest):
    if not req.texts:
        raise HTTPException(400, "texts can't be empty")
    if len(req.texts) &amp;gt; 32:
        raise HTTPException(400, "max 32 per request")

    start = time.time()
    raw = classifier.predict(req.texts)
    ms = round((time.time() - start) * 1000, 2)

    return ClassifyResponse(
        predictions=[LABELS[p] for p in raw],
        latency_ms=ms
    )

@app.get("/health")
async def health():
    return {"status": "ok"}
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;uvicorn api:app --host 0.0.0.0 --port 8000

curl -X POST http://localhost:8000/classify \
  -H "Content-Type: application/json" \
  -d '{"texts": ["payment failed", "app crashed"]}'
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;code&gt;{"predictions":["billing","technical"],"latency_ms":42.8}&lt;/code&gt;&lt;/p&gt;

&lt;p&gt;What I'd do differently&lt;br&gt;
Export to ONNX from day one. I burned time micro-optimizing PyTorch before I made the switch. Should've started there.&lt;br&gt;
Check actual input length distribution early. I defaulted to max_length=128. Turns out 80% of my inputs were under 64 tokens. Dropping max_length for short inputs gave me another ~15% I left on the table.&lt;br&gt;
Add confidence logging from the start, not after. Track average confidence scores over time. When they start drifting downward — your input distribution is shifting and it's time to retrain. I bolted this on late and missed some early signals I shouldn't have.&lt;/p&gt;
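&lt;p&gt;That confidence logging is only a few lines if you already have logits: softmax each row, take the max, average over the batch, and chart that number over time. A sketch in plain Python (in the ONNX path you would feed it the &lt;code&gt;logits&lt;/code&gt; array from &lt;code&gt;session.run&lt;/code&gt;; the function names here are mine):&lt;/p&gt;

```python
# confidence_tracking.py: mean top-class confidence, a cheap drift signal
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def batch_confidence(batch_logits):
    """Mean probability of the predicted class across a batch."""
    return sum(max(softmax(row)) for row in batch_logits) / len(batch_logits)

# One confident row, one uncertain row
logits = [[4.0, 0.1, 0.2], [1.0, 0.9, 1.1]]
print(f"mean confidence: {batch_confidence(logits):.2f}")
```

&lt;p&gt;Log that per-batch number alongside latency; a sustained downward trend is the drift signal I wish I'd been watching from day one.&lt;/p&gt;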

&lt;p&gt;The real takeaway&lt;br&gt;
Three things — in order of impact:&lt;/p&gt;

&lt;p&gt;Batch your inputs. Never process one at a time. Biggest win, easiest fix, most commonly ignored.&lt;br&gt;
Use ONNX Runtime on CPU. PyTorch is brilliant for training. For CPU serving? It's carrying too much.&lt;br&gt;
Quantize before you deploy. One function call. Almost never hurts accuracy on classification tasks.&lt;/p&gt;

&lt;p&gt;None of this is magic. It's just doing inference the right way.&lt;br&gt;
The model gets the credit. The inference pipeline does the actual work.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;What's next&lt;br&gt;
Now that the model is fast, the next thing I ran into was keeping it honest. Models degrade silently in production — input distributions shift, confidence drops, and you usually don't notice until something's already broken.&lt;br&gt;
Next up: building a 3-level drift detection system to catch model degradation before it hits production. Follow along if you don't want to miss it 🚀&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Tried this? Got different numbers? I'd genuinely love to know — drop it in the comments.&lt;/p&gt;

</description>
      <category>python</category>
      <category>machinelearning</category>
      <category>nlp</category>
      <category>onnx</category>
    </item>
  </channel>
</rss>
