TL;DR
Build a two-stage logo pipeline:
Retrieval - generate image embeddings for small crops and match against a logo dictionary with FAISS cosine search. Use SigLIP-2 (NaFlex) so logos are not distorted and small marks still pop.
Verification - for top matches, ask LLaVA-OneVision-1.5 a strict JSON question ("Is this the X logo?") and accept only high-confidence "yes". It's a good model sir.
A bit longer post this time around.
Intro
Brand tagging in real-world video is hard: logos are tiny, partly occluded, moving, and often appear on textured backgrounds. A practical approach is a two-stage pipeline:
- First retrieve likely logo crops with a fast contrastive image encoder.
- Then verify each candidate with a vision-language model (VLM) that can read text and reason about shapes and context. Each candidate is a (frame, bbox, brand, retrieval_score) record that we pass to the VLM.
This tutorial combines a modern image-text encoder (SigLIP-2, NaFlex variant) for high-recall retrieval with a VLM from the LLaVA family for precise, structured yes/no verification.
You could train a YOLO-style detector for specific brands. This post, however, focuses on a more flexible "embedding + VLM" approach that adapts quickly to new logos without retraining. Some might find this fun.
Research
Contrastive encoders such as CLIP/SigLIP produce embeddings where similar visuals are close. FAISS makes nearest-neighbor search over many references instantaneous. You can treat logo search as nearest-neighbor lookup instead of training a custom detector.
Verification reduces false positives. A VLM can explicitly answer "Is this the Red Bull logo?" and justify the decision, improving precision on lookalikes, partial views, and blur.
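To make the retrieval idea concrete, here is a tiny self-contained sketch (random vectors stand in for real embeddings; with L2-normalized vectors, inner product equals cosine similarity). The dimension 768 is an assumption for illustration:

import numpy as np
import faiss

d = 768  # embedding dimension (assumed for illustration)
rng = np.random.default_rng(0)

# Five "reference logo" embeddings, L2-normalized so inner product == cosine
refs = rng.normal(size=(5, d)).astype("float32")
refs /= np.linalg.norm(refs, axis=1, keepdims=True)

index = faiss.IndexFlatIP(d)  # exact inner-product search
index.add(refs)

# A query that is a slightly perturbed copy of reference #2
q = refs[2] + 0.05 * rng.normal(size=d).astype("float32")
q = (q / np.linalg.norm(q)).reshape(1, -1).astype("float32")

scores, ids = index.search(q, 3)
print(ids[0], scores[0])  # expect reference 2 first, with cosine close to 1.0

The real pipeline below does exactly this, only with SigLIP-2 crop embeddings as queries and logo embeddings as references.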
Pointers to the underlying research for the initiated:
- CLIP: Learning Transferable Visual Models From Natural Language Supervision (arXiv)
- SigLIP: Sigmoid Loss for Language-Image Pre-Training (arXiv)
- LLaVA: Visual Instruction Tuning (arXiv)
- FAISS: Billion-scale similarity search with GPUs (arXiv)
Setup
You’ll need to be comfortable with basic Python and PyTorch. No prior experience with FAISS or VLMs required.
Hardware/software requirements:
- At least one NVIDIA GPU (at least 32 GB VRAM recommended)
- CUDA 12.x, Python 3.10–3.12
- Disk: 5–10 GB (models + caches)
- Tools: ffmpeg for frames
This tutorial was built and tested on a single NVIDIA H200.
Prepare the Python environment as follows:
python3 -m venv .venv && source .venv/bin/activate
pip install -U torch torchvision torchaudio \
"transformers>=4.45" accelerate pillow opencv-python faiss-cpu \
numpy pydantic polars
For simplicity we'll use faiss-cpu, which is fine at this scale (it's a single logo after all); the GPU is reserved for the models.
Example video (Creative Commons):
- Title: Red Bull Racing Pit Stop Practice (2015), 44 s, 1920x1080.
- License: CC BY-SA 4.0 - attribute ProtoplasmaKid / Wikimedia Commons.
- It has tiny, moving logos, occlusion, uniforms, car bodywork, pit rig and acts as a great stress-test.
Architecture
frames (2–4 FPS) ──▶ crops (multi-scale grid)
│
▼
[Stage 1] Retrieval (SigLIP-2 image features + FAISS cosine)
│ └─ top-K per frame/brand with heuristics
▼
[Stage 2] VLM verification (LLaVA-OneVision-1.5 JSON verdict)
│
▼
JSONL evidence
Models used in this post:
SigLIP-2 reports better zero-shot and retrieval performance than prior SigLIP/CLIP models on public benchmarks. NaFlex means the encoder resizes each crop to a grid of flexible patches instead of forcing a fixed square, so long thin logos don't get squashed.
LLaVA-OneVision-1.5 is an open VLM family. Its model card reports strong benchmark results versus other open models.
Both are Apache-2.0 licensed.
So let's do it!
Quickstart
1) Grab the video:
wget https://upload.wikimedia.org/wikipedia/commons/b/ba/Red_Bull_Racing_Pit_Stop_Practice.webm
Then extract frames. Two FPS, scaled down to 1280x720:
mkdir -p frames
ffmpeg -i "Red_Bull_Racing_Pit_Stop_Practice.webm" -vf "fps=2,scale=1280:-1:flags=lanczos" -q:v 3 frames/f_%06d.jpg
This yields 88 JPG files.
2) Prepare a logo dictionary
In this example we're only interested in the Red Bull logo. Create a logos directory and download the reference mark into it:
mkdir -p logos
wget -P logos https://upload.wikimedia.org/wikipedia/fi/a/a5/Red_Bull_logo.png
Disclaimer: This tutorial is for educational purposes only and is not affiliated with or endorsed by Red Bull. Red Bull is a registered trademark of Red Bull GmbH. It's a decent energy drink though.
3) Build the logo index (FAISS, SigLIP-2)
Save to build_logo_index.py:
import json, faiss, torch
from pathlib import Path
from PIL import Image
import numpy as np
from transformers import AutoModel, AutoProcessor
MODEL_ID = "google/siglip2-base-patch16-naflex" # NaFlex = native aspect ratio
OUT_DIR = Path("artifacts"); OUT_DIR.mkdir(exist_ok=True, parents=True)
def embed_images(paths, model, proc, batch=16):
imgs = [Image.open(p).convert("RGB") for p in paths]
feats = []
for i in range(0, len(imgs), batch):
chunk = imgs[i:i+batch]
inputs = proc(images=chunk, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
f = model.get_image_features(**inputs) # (B, d)
f = torch.nn.functional.normalize(f, dim=-1)
feats.append(f.cpu())
return torch.cat(feats, dim=0).numpy()
def main():
model = AutoModel.from_pretrained(MODEL_ID, dtype=torch.float32,
device_map="auto")
proc = AutoProcessor.from_pretrained(MODEL_ID)
logo_paths = sorted(list(Path("logos").glob("*.*")))
if not logo_paths:
raise SystemExit("No logo files found in 'logos' directory.")
brands = [p.stem for p in logo_paths]
vecs = embed_images(logo_paths, model, proc)
d = vecs.shape[1]
index = faiss.IndexFlatIP(d); index.add(vecs)
faiss.write_index(index, str(OUT_DIR / "logos.faiss"))
(OUT_DIR / "logos_meta.json").write_text(
json.dumps({"brands": brands, "files": [str(p) for p in logo_paths]},
indent=2)
)
print(f"Indexed {len(brands)} brands into {OUT_DIR/'logos.faiss'}")
if __name__ == "__main__":
main()
SigLIP-2 provides get_image_features and NaFlex dynamic resizing to minimize distortion on non-square inputs — useful for narrow/wide logos.
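As an aside, if you're curious about the aspect-ratio point, you can embed the logo twice (once as-is, once force-resized to a square) and compare the two embeddings. A throwaway sketch reusing the same model and processor; the exact similarity depends on the image:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

MODEL_ID = "google/siglip2-base-patch16-naflex"
model = AutoModel.from_pretrained(MODEL_ID, dtype=torch.float32, device_map="auto")
proc = AutoProcessor.from_pretrained(MODEL_ID)

orig = Image.open("logos/Red_Bull_logo.png").convert("RGB")
squashed = orig.resize((256, 256))  # deliberately distort the aspect ratio

def emb(img):
    inputs = {k: v.to(model.device) for k, v in proc(images=[img], return_tensors="pt").items()}
    with torch.no_grad():
        f = model.get_image_features(**inputs)
    return torch.nn.functional.normalize(f, dim=-1)

print("cosine(original, squashed) =", float(emb(orig) @ emb(squashed).T))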
Run the tool:
$ python build_logo_index.py
...
Indexed 1 brands into artifacts/logos.faiss
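Optionally, a quick sanity check: reload the index and confirm the reference logo matches itself with a cosine score of about 1.0. A small sketch reusing embed_images from build_logo_index.py:

# sanity_check_index.py (optional)
import json
from pathlib import Path

import faiss
import torch
from transformers import AutoModel, AutoProcessor

from build_logo_index import MODEL_ID, embed_images

index = faiss.read_index("artifacts/logos.faiss")
meta = json.loads(Path("artifacts/logos_meta.json").read_text())

model = AutoModel.from_pretrained(MODEL_ID, dtype=torch.float32, device_map="auto")
proc = AutoProcessor.from_pretrained(MODEL_ID)

# Embed the first reference image again and search the index with it
vecs = embed_images([Path(meta["files"][0])], model, proc)
D, I = index.search(vecs, 1)
print(meta["brands"][I[0][0]], f"cosine={D[0][0]:.3f}")  # expect the same brand, score ~1.0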

4) Generate crops + Stage-1 retrieval
Save the following to retrieve.py:
import json, math, os, faiss, torch
from pathlib import Path
from PIL import Image, ImageDraw
import numpy as np
from transformers import AutoModel, AutoProcessor
EMB_ID = "google/siglip2-base-patch16-naflex"
ART = Path("artifacts"); ART.mkdir(exist_ok=True)
SIZES = [192, 256, 320] # square windows
STRIDE = 0.5 # 50% overlap
TOPK = int(os.environ.get("TOPK", "3"))
COSINE_TH = float(os.environ.get("COSINE_TH", "0.7")) # keep candidates above this
def _metric_type(index):
try:
return index.metric_type
except Exception:
return None
def _l2_to_cosine(d: np.ndarray) -> np.ndarray:
# If vectors are L2-normalized, cos = 1 - 0.5 * ||a - b||^2
return 1.0 - 0.5 * d
def _as_float32(x: np.ndarray) -> np.ndarray:
if x.dtype != np.float32:
x = x.astype(np.float32, copy=False)
return np.ascontiguousarray(x)
def grid_crops(im: Image.Image):
W, H = im.size
for s in SIZES:
step = max(1, int(s * STRIDE))
for y in range(0, max(1, H - s + 1), step):
for x in range(0, max(1, W - s + 1), step):
yield (x, y, s, s)
def embed(model, proc, pil_list, bs=24):
out = []
for i in range(0, len(pil_list), bs):
chunk = pil_list[i:i+bs]
inp = proc(images=chunk, return_tensors="pt")
inp = {k: v.to(model.device) for k, v in inp.items()}
with torch.no_grad():
f = model.get_image_features(**inp)
f = torch.nn.functional.normalize(f, dim=-1)
out.append(f.detach().to(torch.float32).cpu())
vecs = torch.cat(out, dim=0).numpy()
return _as_float32(vecs)
def main():
# load logo index
index = faiss.read_index(str(ART / "logos.faiss"))
meta = json.loads((ART / "logos_meta.json").read_text())
brands = meta["brands"]
try:
print(f"[retrieve] index ntotal={index.ntotal}, brands={len(brands)}")
except Exception:
pass
# load embedder
model = AutoModel.from_pretrained(EMB_ID, dtype=torch.float32, device_map="auto")
proc = AutoProcessor.from_pretrained(EMB_ID)
frames = sorted(Path("frames").glob("f_*.jpg"))
print(f"[retrieve] frames found={len(frames)}")
out = []
dump_all = bool(int(os.environ.get("DEBUG_DUMP_ALL", "0")))
dbg_all = [] if dump_all else None
debug_draw = dump_all and bool(int(os.environ.get("DEBUG_DRAW", "0")))
debug_draw_dir = ART / "debug_vis"
if debug_draw:
debug_draw_dir.mkdir(parents=True, exist_ok=True)
debug_draw_th = float(os.environ.get("DEBUG_DRAW_TH", "-1.0"))
debug_draw_max = int(os.environ.get("DEBUG_DRAW_MAX", "0")) # 0 = unlimited
best = {"score": -1.0, "frame": None, "brand": None, "bbox": None}
for fpath in frames:
im = Image.open(fpath).convert("RGB")
boxes, crops = [], []
for (x, y, w, h) in grid_crops(im):
boxes.append((x, y, w, h))
crops.append(im.crop((x, y, x+w, y+h)))
if not crops: continue
vecs = embed(model, proc, crops)
D, I = index.search(_as_float32(vecs), TOPK)
mt = _metric_type(index)
if mt == getattr(faiss, "METRIC_L2", 1):
scores_mat = _l2_to_cosine(D)
else:
scores_mat = D
# collect frame-local debug matches as well
frame_dbg = [] if dump_all else None
for i, (scores, ids) in enumerate(zip(scores_mat, I)):
if dump_all:
for r, (score, idx) in enumerate(zip(scores.tolist(), ids.tolist())):
rec = {
"frame": fpath.name,
"bbox": boxes[i],
"rank": int(r),
"score": float(score),
"brand": brands[idx]
}
dbg_all.append(rec)
if frame_dbg is not None:
frame_dbg.append(rec)
for score, idx in zip(scores, ids):
if score > best["score"]:
best.update({
"score": float(score),
"frame": fpath.name,
"brand": brands[idx],
"bbox": boxes[i]
})
if score < COSINE_TH:
continue
bx = boxes[i]
out.append({
"frame": fpath.name,
"bbox": bx,
"score_retr": float(score),
"brand": brands[idx]
})
# draw annotations for this frame if requested
if debug_draw and frame_dbg:
# sort by score desc
frame_dbg.sort(key=lambda r: r["score"], reverse=True)
if debug_draw_max > 0:
frame_dbg = frame_dbg[:debug_draw_max]
canvas = im.copy()
draw = ImageDraw.Draw(canvas)
for rec in frame_dbg:
if rec["score"] < debug_draw_th:
continue
x, y, w, h = rec["bbox"]
x2, y2 = x + w, y + h
color = (255, 0, 0)
draw.rectangle([x, y, x2, y2], outline=color, width=2)
label = f"{rec['brand']} {rec['score']:.3f}#{rec['rank']}"
# simple text; if background needed, draw a small filled box then text
draw.text((x + 3, y + 3), label, fill=color)
out_path = debug_draw_dir / f"{fpath.stem}_debug.jpg"
canvas.save(out_path, quality=90)
print(f"[retrieve] wrote debug visualization → {out_path}")
Path("candidates.jsonl").write_text("\n".join(json.dumps(x) for x in out))
print(f"wrote {len(out)} retrieval candidates → candidates.jsonl")
if not out and best["frame"] is not None:
print(f"[retrieve] no candidates above threshold {COSINE_TH}. "
f"Best observed: score={best['score']:.3f}, frame={best['frame']}, "
f"brand={best['brand']}, bbox={best['bbox']} — consider lowering COSINE_TH.")
if dump_all and dbg_all is not None:
Path("debug_matches.jsonl").write_text("\n".join(json.dumps(x) for x in dbg_all))
print(f"[retrieve] wrote {len(dbg_all)} raw matches → debug_matches.jsonl (DEBUG_DUMP_ALL=1)")
if __name__ == "__main__":
main()
This is quite a few things, so let's break it down:
- Crops: multi-scale grid over each frame using SIZES and STRIDE, yielding square patches and their bbox tuples.
- Embeddings: SigLIP-2 get_image_features on each crop batch, then L2-normalize the features.
- Search: FAISS over the logo index; uses inner product on normalized vectors (cosine). If the index is L2, we convert to cosine (1 − 0.5·L2²).
- Types: FAISS expects contiguous float32; embeddings are cast and made contiguous before index.search.
- Thresholds: keep the top-TOPK matches per crop, then filter by COSINE_TH. You can override these at runtime: TOPK=10 COSINE_TH=0.65 python retrieve.py
- Debugging:
  - DEBUG_DUMP_ALL=1 writes every raw match to debug_matches.jsonl (ranked, with scores).
  - DEBUG_DRAW=1 (together with DEBUG_DUMP_ALL=1) also saves artifacts/debug_vis/*_debug.jpg with [brand score#rank] boxes.
  - Optional: DEBUG_DRAW_TH=0.3 (only draw boxes at or above that threshold), DEBUG_DRAW_MAX=200 (cap the number of boxes drawn).
- Outputs: filtered candidates land in candidates.jsonl and are fed to the verifier stage.
- Knobs to tune recall/precision: crop sizes, stride, TOPK, COSINE_TH. For higher recall, increase the sizes or TOPK; for precision, raise the threshold and later add a margin filter or temporal smoothing.
Output:
$ python retrieve.py
[retrieve] index ntotal=1, brands=1
[retrieve] frames found=88
wrote 25 retrieval candidates → candidates.jsonl
Depending on the source material this probably needs tuning for recall/precision, and the debugging knobs are quite nice for that. Here's an example showing what it looks like on frame #7 of the video:
Frame from Red Bull Racing Pit Stop Practice (2015), ProtoplasmaKid / Wikimedia Commons / CC BY-SA 4.0.
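If you ran with DEBUG_DUMP_ALL=1, looking at the raw score distribution is a quick way to pick a sensible COSINE_TH. A small sketch using polars (it's in the pip install above, otherwise unused here), assuming debug_matches.jsonl exists:

import polars as pl

df = pl.read_ndjson("debug_matches.jsonl")  # one raw match per line
print(df["score"].describe())               # distribution of cosine scores across all crops

# Frames with the highest best-match score: good candidates to inspect with DEBUG_DRAW=1
top = (
    df.group_by("frame")
      .agg(pl.col("score").max().alias("best_score"))
      .sort("best_score", descending=True)
      .head(10)
)
print(top)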
5) VLM verification
Save the following to verify.py:
import json, torch, os
from pathlib import Path
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
VLM_ID = "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct"
SYSTEM = (
"You are a logo verification API. "
"Given an image crop and a target brand, answer in strict JSON with no extra text."
)
def build_user_prompt(brand: str):
return (
"Task: Verify whether the crop contains the specified brand's logo.\n"
f"Brand: {brand}\n\n"
"Return JSON only:\n"
"{\n"
' "verdict": "yes" | "no",\n'
' "confidence": 0.0-1.0,\n'
' "visual_cues": "short, literal cues proving the verdict (colors/shapes/text)"\n'
"}\n"
"Rules: output only the JSON object; no prose before or after. "
"Be literal; do not speculate; base confidence on how clearly the logo is visible."
)
def _normalize_quotes(s: str) -> str:
# Replace smart quotes with ASCII equivalents
return (
s.replace("\u201c", '"').replace("\u201d", '"')
.replace("\u2018", "'").replace("\u2019", "'")
)
def _extract_json_object(text: str):
text = _normalize_quotes(text)
# Find first balanced {...} block; handle braces within strings
in_str = False
escape = False
depth = 0
start = None
for i, ch in enumerate(text):
if ch == "\\" and not escape:
escape = True
continue
if ch == '"' and not escape:
in_str = not in_str
escape = False
if in_str:
continue
if ch == "{":
if depth == 0:
start = i
depth += 1
elif ch == "}":
if depth > 0:
depth -= 1
if depth == 0 and start is not None:
candidate = text[start:i+1]
try:
j = json.loads(candidate)
if isinstance(j, dict) and "verdict" in j and "confidence" in j:
return j
except Exception:
pass
return None
def run_once(proc, model, crop: Image.Image, brand: str, max_new=128):
msgs = [
{"role":"system", "content": SYSTEM},
{"role":"user", "content":[{"type":"image"}, {"type":"text", "text": build_user_prompt(brand)}]}
]
text = proc.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
inputs = proc(text=[text], images=[crop], padding=True, return_tensors="pt")
inputs = {k: (v.to(model.device) if hasattr(v,"to") else v) for k,v in inputs.items()}
with torch.inference_mode():
ids = model.generate(**inputs, max_new_tokens=max_new, temperature=0.0, do_sample=False)
out = proc.batch_decode(ids, skip_special_tokens=True)[0]
j = _extract_json_object(out)
if j is None:
if os.environ.get("DEBUG_VLM", "0") == "1":
Path("vlm_raw.txt").write_text(out)
return {"verdict":"no", "confidence":0.0, "visual_cues":"parse_error"}
# Coerce fields
verdict = str(j.get("verdict", "no")).strip().lower()
if verdict not in ("yes", "no"):
verdict = "no"
try:
conf = float(j.get("confidence", 0.0))
except Exception:
conf = 0.0
cues = str(j.get("visual_cues", ""))
return {"verdict": verdict, "confidence": conf, "visual_cues": cues}
def main():
proc = AutoProcessor.from_pretrained(VLM_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
VLM_ID, dtype=torch.float32, device_map="auto", trust_remote_code=True
)
results = []
for line in Path("candidates.jsonl").read_text().splitlines():
c = json.loads(line)
im = Image.open(Path("frames")/c["frame"]).convert("RGB")
x,y,w,h = c["bbox"]
crop = im.crop((x,y,x+w,y+h))
j = run_once(proc, model, crop, c["brand"])
results.append({
**c,
"verdict": j.get("verdict","no"),
"confidence_vlm": float(j.get("confidence",0.0)),
"rationale": j.get("visual_cues","")
})
Path("detections.jsonl").write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in results))
print("wrote detections.jsonl")
if __name__ == "__main__":
main()
Again, let's break it down:
- Inputs: candidates.jsonl entries with frame, bbox, score_retr, brand.
- Crop: open the frame and im.crop(bbox) per candidate.
- Prompting: one system + one user message; the user message contains the target brand and an explicit "JSON only" schema.
- Generation: temperature=0.0, do_sample=False, max_new_tokens=128 for deterministic outputs.
- Parsing: balanced-brace JSON extraction with smart-quote normalization; on failure it returns parse_error. Set DEBUG_VLM=1 to dump the raw text to vlm_raw.txt. (See the optional schema-validation sketch after this list.)
- Output: writes detections.jsonl with verdict, confidence_vlm, rationale merged onto each candidate.
- Cost control: keep retrieval strict (lower TOPK, higher COSINE_TH) to limit VLM calls, as they are the main runtime driver.
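The pip install earlier also pulls in pydantic, which verify.py doesn't actually use. If you prefer schema validation over the manual field coercion in run_once, a small optional sketch using pydantic v2 (a hypothetical helper, not part of verify.py):

from typing import Literal

from pydantic import BaseModel, Field, ValidationError

class LogoVerdict(BaseModel):
    verdict: Literal["yes", "no"]
    confidence: float = Field(ge=0.0, le=1.0)
    visual_cues: str = ""

def coerce_verdict(j: dict) -> dict:
    """Validate the parsed VLM JSON; fall back to a conservative 'no' on any error."""
    try:
        return LogoVerdict.model_validate(j).model_dump()
    except ValidationError:
        return {"verdict": "no", "confidence": 0.0, "visual_cues": "schema_error"}

You would call coerce_verdict(j) in place of the manual verdict/confidence coercion inside run_once.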
Output:
$ python verify.py
wrote detections.jsonl
And what does detections.jsonl look like? It has items like this (one per line):
{
"frame": "f_000007.jpg",
"bbox": [
384,
288,
192,
192
],
"score_retr": 0.7020304203033447,
"brand": "Red_Bull_logo",
"verdict": "yes",
"confidence_vlm": 0.95,
"rationale": "red bull charging bull silhouette, red and yellow colors"
}
Here 0.702 is the cosine similarity between the crop and the logo embedding.
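The TL;DR mentioned accepting only a high-confidence "yes", and verify.py deliberately leaves that final filter to you. A minimal sketch with a hypothetical cutoff of 0.8 (call it TH_VLM, tune to taste):

import json
from pathlib import Path

TH_VLM = 0.8  # hypothetical confidence cutoff, not a tested default

accepted = []
for line in Path("detections.jsonl").read_text().splitlines():
    d = json.loads(line)
    if d["verdict"] == "yes" and d["confidence_vlm"] >= TH_VLM:
        accepted.append(d)

print(f"{len(accepted)} accepted detections")
for d in accepted:
    print(d["frame"], d["brand"], f"retr={d['score_retr']:.3f}", f"vlm={d['confidence_vlm']:.2f}")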
Conclusion
I hope this showcases how a simple two-stage recipe - SigLIP-2 retrieval + VLM verification - can turn a semi-noisy video into reviewable brand evidence. I do want to stress that a pipeline like this is meant as a powerful filter, not an oracle: a human in the loop is still needed.
On this specific 44-second clip, with the thresholds above, I get N true positives, M false positives, and 0 missed clear logos (subjective visual check).
That being said, this is probably not a production-ready detector. We trade some accuracy and runtime for simplicity and transparency. Some improvement ideas I had in mind:
- The float32 dtype could be replaced with lower precision, like bfloat16.
- Add a margin filter (top1 − top2 ≥ 0.15) and simple temporal smoothing: confirm across 2–3 adjacent frames (see the sketch after this list).
- Add a lightweight OCR gate for texty marks to backstop retrieval.
- Calibrate thresholds per brand, as some logos need a higher COSINE_TH or TH_VLM.
- Expand the logo dictionary and try multi-resolution templates (flat vs curved surfaces).
- Speed/scale: use FAISS IVF/PQ for larger dictionaries.
- Quantize the VLM or batch crops.
- Maybe consider a second encoder for consensus (SigLIP-2 + CLIP) to reduce lookalikes.
- For 4K frames, consider either down-scaling first or increasing STRIDE (e.g. 0.75) to avoid generating tens of thousands of crops per frame.
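To illustrate the temporal-smoothing idea (the margin filter only becomes meaningful once the dictionary holds more than one brand), here is a rough sketch over detections.jsonl; the one-frame window is an assumption to tune, not a tested default:

import json
import re
from collections import defaultdict
from pathlib import Path

def frame_idx(name: str) -> int:
    # frames are named f_000001.jpg, f_000002.jpg, ...
    return int(re.search(r"f_(\d+)", name).group(1))

dets = [json.loads(l) for l in Path("detections.jsonl").read_text().splitlines()]

# Frames in which each brand got a "yes" verdict
yes_frames = defaultdict(set)
for d in dets:
    if d["verdict"] == "yes":
        yes_frames[d["brand"]].add(frame_idx(d["frame"]))

# Keep a "yes" only if the same brand is also verified in an adjacent frame
smoothed = [
    d for d in dets
    if d["verdict"] == "yes"
    and any(frame_idx(d["frame"]) + off in yes_frames[d["brand"]] for off in (-1, 1))
]
print(f"{len(smoothed)} detections confirmed by an adjacent frame")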
As always, credits: thanks to the SigLIP-2 authors & maintainers and the LLaVA-OneVision team for open releases. And attribution for the example video: ProtoplasmaKid / Wikimedia Commons / CC BY-SA 4.0. 
