DEV Community

Mansi Somayajula
Nobody Tells You This About Slow Transformer Models — I Fixed Mine in 3 Steps

Hot take: most "my model is slow" problems are not model problems.
They're inference problems. And the ML community almost never talks about that gap.
Everyone's obsessed with architecture choices, parameter counts, quantization-aware training, distillation strategies... while the actual bottleneck is sitting right there in the inference code, completely ignored.
I know because I was doing it wrong for longer than I'd like to admit.
I had a DistilBERT classifier running at ~750ms per request in production. My first instinct was "I need a better machine." Turns out, I needed to stop processing one input at a time like it was 2015.
Here's exactly what I did — three steps, same CPU, same model — and I got it down to 280ms.

What I was building
A support ticket classifier I'm calling SupportBot. Fine-tuned DistilBERT, three classes: billing, technical, general. Great accuracy. Terrible latency.
Here's the embarrassing baseline:
```python
# baseline.py
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import time

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()

def predict(text: str):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.argmax(dim=-1).item()

texts = ["My payment didn't go through"] * 10
start = time.time()
for text in texts:
    predict(text)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
```

750ms per request. I wish I was joking. Let's fix this.

Step 1: I was running a dishwasher for a single fork
Processing one text at a time means one full forward pass per request. All the overhead — loading weights into cache, spinning up computation — for one input. Over and over.
The fix is almost insulting in how simple it is: just send multiple texts through at once.

```python
# step1_batching.py
def predict_batch(texts: list, batch_size: int = 16) -> list:
    all_predictions = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]

        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True  # pads all texts to the same length within the batch
        )

        with torch.no_grad():
            outputs = model(**inputs)

        all_predictions.extend(outputs.logits.argmax(dim=-1).tolist())

    return all_predictions

texts = ["My payment didn't go through"] * 10
start = time.time()
predict_batch(texts)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
```

750ms → 480ms. 36% faster. Same hardware, one change. I genuinely stared at the screen for a moment.

💡 What batch size? I'd start with 16. Drop to 8 if you're hitting memory limits, try 32 if you have room. Dynamic batching (queuing requests for ~30–50ms before processing) is the next level if you're building a real API.
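The dynamic batching idea deserves a sketch: collect incoming requests in a queue for a short window, then run whatever arrived as one batch. Here's a minimal synchronous version using only the stdlib — the function name, window size, and queue wiring are my illustration, not something from a library:

```python
# micro_batcher_sketch.py -- illustrative; a real API would run this loop
# on a background thread and hand each result back to its request.
import queue
import time

def collect_batch(q: queue.Queue, max_size: int = 16, max_wait_ms: float = 30.0) -> list:
    """Pull up to max_size items from q, waiting at most max_wait_ms total."""
    batch = []
    deadline = time.monotonic() + max_wait_ms / 1000
    while len(batch) < max_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # the collection window closed
        try:
            batch.append(q.get(timeout=remaining))
        except queue.Empty:
            break  # window expired with nothing new arriving
    return batch
```

Request handlers push texts into `q`; a worker loop calls `collect_batch(q)` and feeds the result to `predict_batch`. Under load you get full batches; when traffic is quiet, nobody waits more than ~30ms.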

Step 2: PyTorch was carrying bags it didn't need to
Here's something that surprised me: PyTorch is not built for inference. It's built for training.
Every inference call drags along autograd, gradient tracking, training hooks... overhead I was paying for every single request but never using. ONNX Runtime strips all of that out and applies graph-level optimizations automatically. For CPU inference, the difference is real.
First, I exported the model:

```python
# export_onnx.py
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()

dummy_input = tokenizer(
    "sample text for export",
    return_tensors="pt",
    max_length=128,
    padding="max_length",
    truncation=True
)

torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "supportbot.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        # Critical — without this, ONNX bakes in fixed shapes from the dummy input
        # and rejects any request that doesn't match exactly
        "input_ids": {0: "batch_size", 1: "seq_length"},
        "attention_mask": {0: "batch_size", 1: "seq_length"},
        "logits": {0: "batch_size"}
    },
    opset_version=13
)
print("Exported.")
```

⚠️ Don't skip dynamic_axes. I learned this the hard way. Without it you'll get cryptic shape errors in production and spend an hour wondering why requests work in your test script but fail in your API.

Then switched to running inference with ONNX Runtime:

```python
# step2_onnx.py
import onnxruntime as ort
import numpy as np
from transformers import DistilBertTokenizer
import time

MODEL_PATH = "distilbert-base-uncased-finetuned-sst-2-english"

class SupportBotClassifier:
    def __init__(self, model_path: str, tokenizer_path: str):
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4  # match your CPU core count

        self.session = ort.InferenceSession(
            model_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"]
        )
        self.tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_path)

    def predict(self, texts: list) -> list:
        inputs = self.tokenizer(
            texts,
            return_tensors="np",  # numpy directly — no PyTorch tensors needed anymore
            max_length=128,
            padding=True,
            truncation=True
        )

        logits = self.session.run(
            ["logits"],
            {
                "input_ids":      inputs["input_ids"].astype(np.int64),
                "attention_mask": inputs["attention_mask"].astype(np.int64)
            }
        )[0]

        return np.argmax(logits, axis=-1).tolist()

classifier = SupportBotClassifier("supportbot.onnx", MODEL_PATH)
texts = ["My payment didn't go through"] * 10
start = time.time()
classifier.predict(texts)
elapsed = (time.time() - start) * 1000

print(f"Avg: {elapsed/10:.0f}ms per request")
```

480ms → 350ms. Another 27% faster, zero model changes.

Step 3: One function call. I'm not exaggerating.
FP32 weights for a ticket classifier is overkill. Dynamic INT8 quantization compresses those weights to 8-bit integers — smaller memory, faster CPU math, almost no accuracy loss.
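To see why this is nearly lossless, here's a minimal numpy sketch of what symmetric per-tensor INT8 quantization does to a weight matrix. The helper names are mine, not ONNX Runtime internals, and real dynamic quantization works per-operator, but the arithmetic is the same idea:

```python
import numpy as np

def quantize_int8(w: np.ndarray):
    """Symmetric per-tensor quantization: map floats into [-127, 127]."""
    scale = np.abs(w).max() / 127.0
    q = np.round(w / scale).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

w = np.random.randn(768, 768).astype(np.float32) * 0.02  # typical weight magnitude
q, scale = quantize_int8(w)
err = np.abs(dequantize(q, scale) - w).max()
print(f"4x smaller, max round-trip error: {err:.6f}")
```

The worst-case round-trip error is half a quantization step, which is tiny next to the weights themselves — for a 3-class classifier the argmax almost never flips.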
```python
# step3_quantize.py
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="supportbot.onnx",
    model_output="supportbot_quantized.onnx",
    weight_type=QuantType.QInt8
)
print("Done.")
```

That's it. That's the whole step.
Always verify predictions didn't shift though:
```python
original = SupportBotClassifier("supportbot.onnx", MODEL_PATH)
quantized = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)

tests = [
    "My payment failed twice",
    "App keeps crashing on iOS",
    "Question about my plan",
    "Can't log in",
    "Where's my invoice?",
]

matches = sum(
    o == q for o, q in
    zip(original.predict(tests), quantized.predict(tests))
)
print(f"Match rate: {matches/len(tests):.0%}")
```

Match rate: 100%

Final benchmark:
```python
classifier_v2 = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)
start = time.time()
classifier_v2.predict(["My payment didn't go through"] * 10)
elapsed = (time.time() - start) * 1000
print(f"Avg: {elapsed/10:.0f}ms per request")
```

350ms → 280ms 🎉

The full picture
| Change | Latency | vs. start |
| --- | --- | --- |
| Baseline (PyTorch, single) | ~750ms | — |
| + Batch processing | ~480ms | 36% faster |
| + ONNX Runtime | ~350ms | 53% faster |
| + INT8 quantization | ~280ms | 63% faster |
Same model. Same CPU. Same hardware.

The FastAPI wrapper I'm using
```python
# api.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import time

from step2_onnx import SupportBotClassifier, MODEL_PATH  # defined in step2_onnx.py above

app = FastAPI(title="SupportBot API")
classifier = SupportBotClassifier("supportbot_quantized.onnx", MODEL_PATH)
LABELS = ["billing", "technical", "general"]

class ClassifyRequest(BaseModel):
    texts: List[str]

class ClassifyResponse(BaseModel):
    predictions: List[str]
    latency_ms: float

@app.post("/classify", response_model=ClassifyResponse)
async def classify(req: ClassifyRequest):
    if not req.texts:
        raise HTTPException(400, "texts can't be empty")
    if len(req.texts) > 32:
        raise HTTPException(400, "max 32 per request")

    start = time.time()
    raw = classifier.predict(req.texts)
    ms = round((time.time() - start) * 1000, 2)

    return ClassifyResponse(
        predictions=[LABELS[p] for p in raw],
        latency_ms=ms
    )

@app.get("/health")
async def health():
    return {"status": "ok"}
```

```bash
uvicorn api:app --host 0.0.0.0 --port 8000

curl -X POST http://localhost:8000/classify \
  -H "Content-Type: application/json" \
  -d '{"texts": ["payment failed", "app crashed"]}'
```

{"predictions":["billing","technical"],"latency_ms":42.8}

What I'd do differently
Export to ONNX from day one. I burned time micro-optimizing PyTorch before I made the switch. Should've started there.
Check actual input length distribution early. I defaulted to max_length=128. Turns out 80% of my inputs were under 64 tokens. Dropping max_length for short inputs gave me another ~15% I left on the table.
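Checking that distribution takes only a few lines. A sketch — the helper name, percentile cutoff, and bucket list are my choices, and in practice `token_lengths` would come from running your tokenizer over a sample of real tickets:

```python
import numpy as np

def suggest_max_length(token_lengths, pct=95, buckets=(32, 64, 128, 256)):
    """Pick the smallest bucket that covers the pct-th percentile of lengths."""
    p = np.percentile(token_lengths, pct)
    for b in buckets:
        if b >= p:
            return b
    return buckets[-1]

# In practice:
# lengths = [len(tokenizer(t)["input_ids"]) for t in sample_texts]
lengths = [12, 18, 25, 31, 40, 44, 52, 58, 61, 63]  # illustrative
print(suggest_max_length(lengths))  # 64 covers the 95th percentile here
```

Padding every request to 128 tokens means the model does the work for 128 tokens regardless of input, so shrinking `max_length` to match reality is nearly free latency.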
Add confidence logging from the start, not after. Track average confidence scores over time. When they start drifting downward, your input distribution is shifting and it's time to retrain. I bolted this on late and missed some early signals I shouldn't have.
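Confidence logging is a softmax over logits you already have. A minimal sketch (function names are mine):

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # subtract the row max for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def batch_confidence(logits: np.ndarray) -> float:
    """Mean top-class probability for a batch, the number worth logging over time."""
    return float(softmax(logits).max(axis=-1).mean())

# A confident batch vs. an uncertain one; the drop is what signals drift.
sure = np.array([[6.0, 0.5, 0.2], [0.1, 7.0, 0.3]])
unsure = np.array([[1.1, 1.0, 0.9], [0.9, 1.0, 1.1]])
print(batch_confidence(sure), batch_confidence(unsure))
```

Log `batch_confidence` for every request alongside latency; a slow downward trend in its daily average is an early retraining signal.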

The real takeaway
Three things — in order of impact:

Batch your inputs. Never process one at a time. Biggest win, easiest fix, most commonly ignored.
Use ONNX Runtime on CPU. PyTorch is brilliant for training. For CPU serving? It's carrying too much.
Quantize before you deploy. One function call. Almost never hurts accuracy on classification tasks.

None of this is magic. It's just doing inference the right way.
The model gets the credit. The inference pipeline does the actual work.

What's next
Now that the model is fast, the next thing I ran into was keeping it honest. Models degrade silently in production — input distributions shift, confidence drops, and you usually don't notice until something's already broken.
Next up: building a 3-level drift detection system to catch model degradation before it hits production. Follow along if you don't want to miss it 🚀

Tried this? Got different numbers? I'd genuinely love to know — drop it in the comments.
