In this article, I'll walk you through how I built a robust molecular similarity search system using ChemBERTa, RDKit, and Qdrant, and what I actually learned along the way.
Canonical URL: https://medium.com/towards-artificial-intelligence/i-tried-vector-search-on-molecules-heres-what-happened-7391b755efe4
GitHub: github.com/dvy246/qdrant
Stack: Python 3.10+, RDKit, ChemBERTa, Qdrant, FastAPI, Streamlit
Why I Built This
I had been spending a lot of time with vector databases and embedding-based search. Every example I came across was about text: document search, FAQ retrieval, chatbot memory.
At some point, I asked myself a question that seemed a bit strange at the time: Can I apply the same thing to molecules?
Molecules can be written as text strings. SMILES (Simplified Molecular Input Line Entry System) is a notation that encodes molecular structure as a string. CC(=O)Oc1ccccc1C(=O)O is aspirin. If you can write a molecule as a string, you can feed it to a transformer. And if you can get a dense vector out of a transformer, you can search it with a vector database.
I had been reading about ChemBERTa, a transformer trained specifically on SMILES strings. Qdrant had just shipped a clean Python client with native payload filtering. I thought: what happens if I connect these two things?
So I built a small experiment and ran it.
Here is what the system does:
- Downloads and caches the ZINC-250k dataset (indexing the first 2000 valid drug-like molecules by default)
- Validates and canonicalizes each SMILES string using RDKit, computing a heuristic toxicity proxy per molecule
- Embeds each molecule into a 768-dimensional vector using ChemBERTa
- Indexes those vectors in Qdrant with molecular property metadata as searchable payload
- Retrieves similar molecules using cosine similarity with optional filters on MW, LogP, and toxicity
- Serves results through a FastAPI endpoint or a Streamlit UI
The Problem with Fingerprints
Most molecule search pipelines use fingerprint similarity. The basic idea: take a molecule, generate a Morgan fingerprint (a fixed-length bit vector that records which molecular fragments are present), compute a Tanimoto score against a query molecule (overlap between the two bit vectors, 0 to 1), and rank by that score.
For obvious structural matches, this works well. Fast, simple, interpretable.
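To make the scoring step concrete, here is a minimal sketch of the Tanimoto coefficient on two fingerprints represented as sets of on-bit indices. The bit indices are made up for illustration; a real pipeline would get them from a Morgan fingerprint generator such as RDKit's.

```python
def tanimoto(bits_a: set[int], bits_b: set[int]) -> float:
    """Tanimoto coefficient: shared on-bits divided by total distinct on-bits."""
    if not bits_a and not bits_b:
        return 0.0  # convention chosen here for two empty fingerprints
    shared = len(bits_a & bits_b)
    return shared / (len(bits_a) + len(bits_b) - shared)


# Hypothetical on-bit positions, not computed from real molecules.
query = {3, 17, 42, 101, 256}
candidate = {3, 17, 42, 300}

score = tanimoto(query, candidate)  # 3 shared bits / 6 distinct bits
print(round(score, 3))  # 0.5
```

Ranking a library then amounts to computing this score against every stored fingerprint and sorting, which is where the linear-scan cost discussed later comes from.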
But when I was testing it on a small set of compounds, I kept seeing it miss molecules that looked chemically related when I inspected them manually. Not enough fragment overlap for the score to reflect that relationship. The system just did not surface them.
*That gap is what pushed me toward trying something different.*
The core issue is that compressing a molecule into a fixed-length bit vector loses information. Different substructures can hash to the same bit position. Once that collision happens, the distinction between them is gone before you ever run a query. You are searching over a lossy representation of the molecule, not the molecule itself.
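The collision effect is easy to reproduce with a toy example. The sketch below is not how Morgan fingerprints are actually computed (those hash atom environments, not SMILES substrings); it only assumes some hash function folding named fragments into a small number of bits, here `zlib.crc32` and 8 bits, so that collisions are guaranteed.

```python
import zlib


def fold_fragments(fragments: list[str], n_bits: int) -> dict[int, list[str]]:
    """Map each fragment to a bit position and record which fragments share it."""
    bits: dict[int, list[str]] = {}
    for frag in fragments:
        position = zlib.crc32(frag.encode("utf-8")) % n_bits
        bits.setdefault(position, []).append(frag)
    return bits


# Illustrative fragment strings; nine fragments into eight bits must collide.
fragments = ["c1ccccc1", "C(=O)O", "CC(=O)O", "CN", "CO", "CCl", "C=O", "CCN", "c1ccncc1"]
bits = fold_fragments(fragments, n_bits=8)

collisions = {pos: frags for pos, frags in bits.items() if len(frags) > 1}
for pos, frags in sorted(collisions.items()):
    print(f"bit {pos} is set by: {frags}")
```

Once two fragments share a bit, the fingerprint cannot tell you which of them was present. Real fingerprints use 1024 or 2048 bits, so collisions are rarer, but the same loss applies.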
This gets particularly visible in what chemists call scaffold hopping: two molecules that act on the same biological target through structurally different cores. Their fingerprints can look completely different even when they are functionally related. Tanimoto has no way to catch that.
There is also the activity cliff problem. Adding a single methyl group (one carbon with three hydrogens) can completely flip how a molecule behaves biologically. Tanimoto treats that as a difference of a few bits, with no chemical context.
And at scale, Tanimoto is a full linear scan. With millions of compounds and thousands of queries, that is billions of pairwise comparisons. It requires dedicated engineering to keep running at a reasonable speed.
None of this was a deal-breaker for a small experiment. But it was enough to make me curious whether an approach that learned something about molecular structure, rather than just counting fragments, could surface what fingerprints miss.
The Idea: Embeddings Over Fragments
Instead of hashing molecules into fragment bit vectors, I wanted to try converting each molecule into a dense vector that captured its structural and physicochemical character in continuous space. The bet: similar molecules would land close together in that space even when their fingerprints look different.
ChemBERTa handles that. It is a RoBERTa-based transformer trained on SMILES strings from the ZINC database. During training, it learns to predict masked tokens from surrounding context, the same pretraining approach BERT uses for text. Because SMILES encodes molecular structure in a consistent grammar, the model picks up structural patterns across millions of molecules. The result is a 768-number vector per molecule.
One thing to be clear about: these vectors are structural estimates, not measures of biological activity. The base model was trained without any assay data. You are measuring learned structural closeness, not functional equivalence. That limitation matters and comes up again in the limitations section.
Once molecules are represented as vectors, approximate nearest neighbor search replaces the linear scan. That is what Qdrant handles.
Pipeline Overview
Five stages, each doing one thing:
SMILES strings
│
▼
RDKit validation and canonicalization
│
▼
ChemBERTa embedding (Hugging Face Transformers)
│
▼
Qdrant vector indexing (upsert with metadata payloads)
│
▼
Similarity search API (FastAPI) or Web UI (Streamlit)
Raw SMILES come in first. RDKit validates and standardizes them. Anything that does not parse gets dropped before it reaches the model, and more comes in invalid than you might expect from a curated source.
Validated SMILES go to ChemBERTa (seyonec/ChemBERTa-zinc-base-v1). Mean pooling on the final hidden state gives a 768-dimensional vector per molecule. That vector, along with metadata like molecular weight and LogP, gets stored as a point in Qdrant. HNSW handles the indexing. At query time, the same model embeds the query, Qdrant searches by cosine similarity, and results can be filtered by molecular properties during retrieval, not after.
Why Qdrant
I had one core requirement: apply numeric filters on molecular properties during retrieval, not after it. That ruled out more options than I expected.
Pinecone was the easiest to get running. Basic vector search worked fine. The problem: combining similarity search with numeric range filters on molecular weight and LogP produced inconsistent behavior.
Milvus gave good raw search performance. The problem: local installation required spinning up etcd and MinIO just to run a dev experiment. Not acceptable for something I was iterating on quickly.
Weaviate had a clean schema system I liked. The problem: chaining multiple numeric range filters with a vector search query felt awkward. The semantics were not intuitive enough for this use case.
Qdrant solved the specific problems the others did not.
The key difference is how it handles filters. In this kind of search you rarely query by vector alone. You want to say: find molecules similar to this one, but only where molecular weight is under 500 and LogP is below 5. Qdrant applies those constraints natively during HNSW graph traversal, not after. Post-filtering sounds like the same thing but it is not. If your filter is restrictive, you silently end up with fewer than top-k results because candidates are removed after the search finishes. That was not acceptable for how I needed this to work.
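The shortfall from post-filtering can be shown with a brute-force toy example. This sketch does not reproduce Qdrant's filtered HNSW traversal; it only contrasts "take top-k, then filter" with "filter, then take top-k" on a handful of made-up scores and molecular weights.

```python
def top_k_post_filter(scored: list[dict], k: int, keep) -> list[dict]:
    """Take the k best by score first, then drop the ones failing the filter."""
    best = sorted(scored, key=lambda m: m["score"], reverse=True)[:k]
    return [m for m in best if keep(m)]


def top_k_pre_filter(scored: list[dict], k: int, keep) -> list[dict]:
    """Apply the filter first, then take the k best of what remains."""
    allowed = [m for m in scored if keep(m)]
    return sorted(allowed, key=lambda m: m["score"], reverse=True)[:k]


def under_500(m: dict) -> bool:
    return m["mw"] < 500.0


# Hypothetical similarity scores and molecular weights.
molecules = [
    {"id": "a", "score": 0.97, "mw": 612.0},
    {"id": "b", "score": 0.95, "mw": 540.0},
    {"id": "c", "score": 0.91, "mw": 310.0},
    {"id": "d", "score": 0.88, "mw": 455.0},
    {"id": "e", "score": 0.80, "mw": 290.0},
]

post = top_k_post_filter(molecules, k=3, keep=under_500)
pre = top_k_pre_filter(molecules, k=3, keep=under_500)
print([m["id"] for m in post])  # ['c']
print([m["id"] for m in pre])   # ['c', 'd', 'e']
```

With a request for three results, post-filtering returns one and filtering during retrieval returns three.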
The in-memory client (QdrantClient(":memory:")) was a big deal for development. I could build and test the entire pipeline on a laptop without Docker or any external service. Moving to a persistent Qdrant instance later was one line. That made the iteration loop fast.
When to Use Embeddings vs. Fingerprints
If you are working with a well-characterized scaffold series where structure-activity relationships are already mapped, ECFP4 and Tanimoto are probably still the right tool. Fast, interpretable, no model required. Adding an embedding pipeline for that use case adds complexity without giving much back.
Embeddings start to make sense when you are trying to surface candidates that do not share obvious substructure with your query, or when your library is large enough that a linear Tanimoto scan becomes a bottleneck.
ChemBERTa is not the only option. Mol2Vec produces lighter embeddings and runs faster on CPU. GNN-based models can encode 3D molecular geometry, which matters if shape drives binding for your target. I went with ChemBERTa because it takes SMILES as direct input and is well-supported on Hugging Face. For this kind of experiment, that simplicity was worth more than the extra capability of a more complex model.
Similarity Metric
Before indexing, every vector gets L2 normalized. Once vectors are unit length, cosine similarity and dot product give identical results.
I went with cosine similarity for one practical reason: the scores are easy to read. Cosine scores sit between -1 and 1, so a score of 0.95 versus 0.72 tells you something concrete when you are staring at retrieval results. When you are debugging why certain molecules surface and others do not, readable scores matter.
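The equivalence between cosine similarity and dot product on unit vectors can be checked in a few lines. This is a pure-Python sketch on two small made-up vectors, not the 768-dimensional embeddings used in the pipeline.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v: list[float]) -> list[float]:
    """Scale a vector to unit length; a zero vector has no direction to keep."""
    norm = math.sqrt(dot(v, v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def cosine(a: list[float], b: list[float]) -> float:
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]
ua, ub = l2_normalize(a), l2_normalize(b)

print(round(cosine(a, b), 4))  # 0.7333
print(round(dot(ua, ub), 4))   # 0.7333
```

Here the cosine is 11 / (5 × 3), and the dot product of the normalized vectors gives the same number up to floating-point error.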
Environment Setup
Python 3.10+ required. All dependencies from PyPI.
# Core cheminformatics
pip install rdkit
# Transformer framework, tokenizer, and PyTorch
pip install transformers torch
# Vector database client (requires Qdrant server v1.10+ for the Query API)
pip install "qdrant-client>=1.10.0"
# API or UI (pick one or both)
pip install fastapi uvicorn
pip install streamlit
Note: The PyPI package is rdkit, not rdkit-pypi (the older name). The codebase uses X | Y union type syntax, which requires Python 3.10+. ChemBERTa runs on CPU for smaller datasets, but GPU inference is significantly faster above 100k molecules.
Verify before starting:
import rdkit
from rdkit import Chem
from transformers import AutoTokenizer, AutoModel
from qdrant_client import QdrantClient
print(f"RDKit version: {rdkit.__version__}")
print("All imports successful.")
Building It: Step by Step
Step 1: Load the Dataset and Validate SMILES
SMILES strings from real databases are messy. The same molecule can have multiple valid representations. Some entries contain salts or stereoisomer annotations. Others are just invalid. Feeding any of that directly to a transformer produces garbage embeddings or silent failures.
I went with ZINC-250k, a widely used benchmark subset of drug-like molecules from the ZINC database. It is the same source ChemBERTa was pre-trained on, which makes it a natural fit.
# data_loader.py
from __future__ import annotations

import csv
import logging
from pathlib import Path

from rdkit import Chem

from molsearch.config import DATA_CACHE_DIR, DATASET_SIZE

logger = logging.getLogger(__name__)

ZINC_250K_URL = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
ZINC_FILENAME = "zinc_250k.csv"

_FALLBACK_SMILES: list[str] = [
    "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
    "CC(C)Cc1ccc(cc1)C(C)C(=O)O",  # Ibuprofen
    "CC(=O)Nc1ccc(O)cc1",  # Acetaminophen
    "O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl",  # Diclofenac
    "COc1ccc2cc(CC(C)C(=O)O)ccc2c1",  # Naproxen
]


def load_dataset(
    max_molecules: int = DATASET_SIZE,
) -> tuple[list[str], list[float | None]]:
    """
    Load molecules from the ZINC-250k dataset.
    Downloads on first call, caches locally for subsequent runs.
    Toxicity scores are None; they are computed dynamically in molecule_processor.py.
    Falls back to built-in list if download fails.
    """
    cache_dir = Path(DATA_CACHE_DIR)
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / ZINC_FILENAME

    if not cache_path.exists():
        try:
            import urllib.request

            logger.info("Downloading ZINC-250k from %s ...", ZINC_250K_URL)
            urllib.request.urlretrieve(ZINC_250K_URL, str(cache_path))
        except Exception:
            logger.warning("Download failed; using %d fallback molecules", len(_FALLBACK_SMILES))
            smiles = _FALLBACK_SMILES[:max_molecules]
            return smiles, [None] * len(smiles)

    seen: set[str] = set()
    valid_smiles: list[str] = []
    with open(cache_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            raw = row.get("smiles", "").strip()
            if not raw:
                continue
            mol = Chem.MolFromSmiles(raw)
            if mol is None or mol.GetNumAtoms() == 0:
                continue
            # Deduplicate on the canonical form, not the raw string
            canonical = Chem.MolToSmiles(mol)
            if canonical in seen:
                continue
            seen.add(canonical)
            valid_smiles.append(canonical)
            if len(valid_smiles) >= max_molecules:
                break

    toxicity_scores: list[float | None] = [None] * len(valid_smiles)
    logger.info("Dataset ready: %d molecules", len(valid_smiles))
    return valid_smiles, toxicity_scores
DATASET_SIZE defaults to 2000, controlled via MOLSEARCH_DATASET_SIZE env var. Toxicity scores are all None here because ZINC has no toxicity labels. They get computed per molecule in the validation step.
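The contents of molsearch.config are not shown in this article, so the following is only a guess at how the environment override could be read. The function name `read_dataset_size` and the validation rules are assumptions; only the variable name MOLSEARCH_DATASET_SIZE and the default of 2000 come from the text above.

```python
import os


def read_dataset_size(default: int = 2000) -> int:
    """Read MOLSEARCH_DATASET_SIZE, falling back to the default when unset or blank."""
    raw = os.environ.get("MOLSEARCH_DATASET_SIZE")
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"MOLSEARCH_DATASET_SIZE must be an integer, got {raw!r}") from None
    if value <= 0:
        raise ValueError("MOLSEARCH_DATASET_SIZE must be positive")
    return value


# Hypothetical usage: override the default for a larger run.
os.environ["MOLSEARCH_DATASET_SIZE"] = "5000"
print(read_dataset_size())  # 5000
```

Failing loudly on a malformed value is a design choice here; silently falling back to 2000 would hide a typo in the environment.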
Validation and canonicalization runs over every SMILES. RDKit checks that the string parses to a valid molecule and converts it to canonical form so there is exactly one text representation per structure regardless of how it was originally written.
Heads up on stereochemistry: RDKit canonicalization can only keep stereocenters that the input SMILES defines explicitly. Two enantiomers written without chirality tags end up with identical canonical SMILES and identical vectors. If chirality matters for your use case, handle it before this step.
from __future__ import annotations

import logging
import math

from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors

logger = logging.getLogger(__name__)


def compute_toxicity_proxy(mol: Chem.Mol) -> float:
    """
    Heuristic toxicity estimate from RDKit descriptors.
    NOT a real toxicity prediction. Rule-based proxy using properties
    commonly associated with toxicity risk. Returns float 0.0–1.0.
    """
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    hbd = Descriptors.NumHDonors(mol)
    hba = Descriptors.NumHAcceptors(mol)
    tpsa = Descriptors.TPSA(mol)
    n_arom = rdMolDescriptors.CalcNumAromaticRings(mol)

    # Each sub-score is clipped to [0, 1] before weighting
    mw_score = max(0.0, min((mw - 350) / 450, 1.0))
    logp_score = max(0.0, min((logp - 3) / 5, 1.0))
    hbd_score = max(0.0, min((hbd - 2) / 5, 1.0))
    hba_score = max(0.0, min((hba - 5) / 10, 1.0))
    arom_score = min(n_arom / 6, 1.0)
    tpsa_score = max(0.0, min((75 - tpsa) / 75, 1.0))

    raw = (
        0.25 * mw_score
        + 0.25 * logp_score
        + 0.15 * hbd_score
        + 0.15 * hba_score
        + 0.10 * arom_score
        + 0.10 * tpsa_score
    )
    return round(max(0.0, min(raw, 1.0)), 3)


def validate_and_canonicalize(
    smiles: str,
    toxicity_score: float | None = None,
) -> dict | None:
    """
    Validate a SMILES string and return canonical form with descriptors.
    Returns None if the SMILES is invalid.
    """
    normalized = smiles.strip()
    if not normalized:
        return None

    mol = Chem.MolFromSmiles(normalized)
    if mol is None:
        return None

    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0 or mol.GetNumBonds() == 0:
        return None

    canonical_smiles = Chem.MolToSmiles(mol)

    # Round-trip check: re-parse canonical SMILES and verify atom count
    verify_mol = Chem.MolFromSmiles(canonical_smiles)
    if verify_mol is None or verify_mol.GetNumAtoms() != num_atoms:
        return None

    payload = {
        "smiles": canonical_smiles,
        "molecular_weight": round(Descriptors.MolWt(mol), 2),
        "logp": round(Descriptors.MolLogP(mol), 2),
        "num_h_donors": Descriptors.NumHDonors(mol),
        "num_h_acceptors": Descriptors.NumHAcceptors(mol),
        "tpsa": round(Descriptors.TPSA(mol), 2),
    }

    if toxicity_score is not None:
        if not math.isfinite(toxicity_score):
            raise ValueError("toxicity_score must be a finite float")
        payload["toxicity_score"] = float(toxicity_score)
    else:
        payload["toxicity_score"] = compute_toxicity_proxy(mol)

    return payload
Two things worth calling out. When toxicity_score is None (always the case for ZINC data), the function falls through to compute_toxicity_proxy. Every molecule in the index ends up with a toxicity_score in its payload, so filtered searches on that field always have something to work with. The proxy is a weighted combination of six RDKit descriptors. It is a deterministic heuristic, not a toxicity model.
The round-trip check after canonicalization catches edge cases where RDKit generates a canonical form it then cannot parse back. Rare, but it happens.
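To see how the proxy's weights combine, here is the same arithmetic with the RDKit calls replaced by plain arguments so it runs without any dependencies. The helper names and the descriptor values are made up for illustration; the thresholds and weights are copied from compute_toxicity_proxy above.

```python
def clamp01(x: float) -> float:
    """Clip a value into the [0, 1] range."""
    return max(0.0, min(x, 1.0))


def toxicity_proxy_from_descriptors(
    mw: float, logp: float, hbd: int, hba: int, n_arom: int, tpsa: float
) -> float:
    """Same weights and thresholds as compute_toxicity_proxy, minus the RDKit calls."""
    raw = (
        0.25 * clamp01((mw - 350) / 450)
        + 0.25 * clamp01((logp - 3) / 5)
        + 0.15 * clamp01((hbd - 2) / 5)
        + 0.15 * clamp01((hba - 5) / 10)
        + 0.10 * clamp01(n_arom / 6)
        + 0.10 * clamp01((75 - tpsa) / 75)
    )
    return round(clamp01(raw), 3)


# Hypothetical descriptor values, chosen only to make the arithmetic easy to follow.
score = toxicity_proxy_from_descriptors(
    mw=500.0, logp=4.0, hbd=1, hba=6, n_arom=3, tpsa=60.0
)
print(score)  # 0.218
```

The sub-scores here are 0.333, 0.2, 0.0, 0.1, 0.5, and 0.2, and the weighted sum is about 0.2183. Because the weights add up to 1.0, the final clamp only matters as a safety net.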
Step 2: Generate Molecular Embeddings with ChemBERTa
ChemBERTa is a RoBERTa model pre-trained on SMILES strings from ZINC. Checkpoint: seyonec/ChemBERTa-zinc-base-v1 on Hugging Face.
The model does not come with a dedicated pooling layer. To get a single vector per molecule, you pool the token-level outputs yourself. I use mean pooling, averaging the output vectors across all non-padding tokens. This is a common approach for RoBERTa-based models because the special first token (<s> in RoBERTa, the counterpart of BERT's [CLS]) was not trained with a sentence-level objective. Taking just that token gives unstable representations. Averaging across all tokens works better.
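Before the full PyTorch version, a toy illustration of masked mean pooling may help. This sketch uses plain lists and two-dimensional made-up token vectors; it is only meant to show what "average over non-padding tokens" means, not to replace the tensor code.

```python
def mean_pool(token_vectors: list[list[float]], attention_mask: list[int]) -> list[float]:
    """Average token vectors where the mask is 1, ignoring padding positions."""
    dim = len(token_vectors[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            count += 1
            for j, x in enumerate(vec):
                total[j] += x
    count = max(count, 1)  # guard against an all-padding row
    return [x / count for x in total]


# Two real tokens followed by one padding token with arbitrary values.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

pooled = mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0]
```

Without the mask, the padding row would drag the average toward [34.67, 35.33], which is why the attention mask has to be part of the pooling step.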
from __future__ import annotations

import logging

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"
VECTOR_DIM = 768
BATCH_SIZE = 32


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class MoleculeEmbedder:
    """Generates L2-normalized ChemBERTa embeddings from SMILES."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.device = _get_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        self.vector_dim = self.model.config.hidden_size
        if self.vector_dim != VECTOR_DIM:
            raise ValueError(
                f"Model hidden size ({self.vector_dim}) does not match "
                f"VECTOR_DIM ({VECTOR_DIM})"
            )
        logger.info("Loaded %s on %s", model_name, self.device)

    def _mean_pool(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Zero out padding positions, then divide by the real token count
        mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        return sum_hidden / sum_mask

    def embed(
        self,
        smiles_list: list[str],
        batch_size: int = BATCH_SIZE,
    ) -> np.ndarray:
        """
        Embed SMILES into dense vectors.
        Returns (n, vector_dim) float32 array, L2-normalized.
        Oversized SMILES (>400 chars) are skipped per-molecule and filled with np.nan.
        """
        n = len(smiles_list)
        if n == 0:
            return np.empty((0, self.vector_dim), dtype=np.float32)

        result = np.empty((n, self.vector_dim), dtype=np.float32)

        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            batch = smiles_list[start:end]

            # Handle oversized SMILES per-molecule, not per-batch
            oversized = [i for i, s in enumerate(batch) if len(s) > 400]
            if oversized:
                for idx in oversized:
                    logger.warning(
                        "SMILES at index %d too long (%d chars), skipping",
                        start + idx, len(batch[idx]),
                    )
                    result[start + idx] = np.nan
                safe_indices = [i for i in range(len(batch)) if i not in oversized]
                if not safe_indices:
                    continue
                batch = [batch[i] for i in safe_indices]
            else:
                safe_indices = list(range(len(batch)))

            encoded = self.tokenizer(
                batch, padding=True, truncation=False, return_tensors="pt"
            )
            if encoded["input_ids"].shape[1] > 512:
                logger.error("Batch %d–%d: context limit exceeded", start, end)
                result[start:end] = np.nan
                continue

            encoded = {k: v.to(self.device) for k, v in encoded.items()}
            with torch.no_grad():
                outputs = self.model(**encoded)

            embeddings = self._mean_pool(
                outputs.last_hidden_state, encoded["attention_mask"]
            )
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            batch_result = embeddings.cpu().numpy()

            # Write each embedding back to its original position in the batch
            for out_i, safe_i in enumerate(safe_indices):
                result[start + safe_i] = batch_result[out_i]

        return result
A few implementation details worth knowing.
truncation=False is deliberate. Silently truncating a SMILES string produces an embedding that only represents part of the molecule — that is worse than no embedding at all. Instead, any SMILES over 400 characters is flagged and skipped per-molecule. If the tokenized batch still exceeds 512 tokens after that, the whole batch gets rejected. Failed slots are filled with np.nan so downstream code can detect the problem rather than silently work with bad data.
Result array is pre-allocated with np.empty rather than appending batch results to a list. For large datasets, np.vstack temporarily doubles memory because it copies everything into a new contiguous block. Pre-allocation avoids that.
CPU performance: expect somewhere around 100–300 molecules per second depending on SMILES length. Above 100k molecules, GPU inference saves a meaningful amount of time.
Step 3: Index Embeddings in Qdrant
With valid molecules and their embeddings ready, the next step is storing them in Qdrant. Each point holds a vector and a JSON payload with canonical SMILES and all the descriptors from Step 1.
from __future__ import annotations

import uuid

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PayloadSchemaType,
    PointStruct,
    VectorParams,
)

COLLECTION_NAME = "molecules"
VECTOR_DIM = 768
UPSERT_BATCH_SIZE = 1000
USE_PERSISTENT_QDRANT = False
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333


def get_qdrant_client() -> QdrantClient:
    """In-memory for development, persistent for production."""
    if USE_PERSISTENT_QDRANT:
        client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=10)
        client.get_collections()  # verify connection
        return client
    return QdrantClient(":memory:")


def create_collection(client: QdrantClient) -> None:
    """Non-destructive: reuse existing collection if vector size matches."""
    if client.collection_exists(collection_name=COLLECTION_NAME):
        info = client.get_collection(collection_name=COLLECTION_NAME)
        size = getattr(info.config.params.vectors, "size", None)
        if size is not None and size != VECTOR_DIM:
            raise ValueError(
                f"Existing collection has vector size {size}, expected {VECTOR_DIM}"
            )
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
    )


def _smiles_to_uuid(smiles: str) -> str:
    """Deterministic UUID from canonical SMILES. Safe for incremental upserts."""
    namespace = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
    return str(uuid.uuid5(namespace, smiles))


def upsert_molecules(
    client: QdrantClient,
    molecules: list[dict],
    embeddings: np.ndarray,
    batch_size: int = UPSERT_BATCH_SIZE,
) -> None:
    """Upsert molecule vectors and payloads in batches. Skips nan/inf/zero vectors."""
    n = len(molecules)
    valid_mask = (
        ~np.isnan(embeddings).any(axis=1)
        & ~np.isinf(embeddings).any(axis=1)
    )
    valid_mask &= np.linalg.norm(embeddings, axis=1) > 0.0

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        points = [
            PointStruct(
                id=_smiles_to_uuid(mol["smiles"]),
                vector=emb.tolist(),
                payload=mol,
            )
            for i, (mol, emb) in enumerate(
                zip(molecules[start:end], embeddings[start:end]), start=start
            )
            if valid_mask[i]
        ]
        if points:
            client.upsert(collection_name=COLLECTION_NAME, points=points)


def create_payload_indexes(client: QdrantClient) -> None:
    """Create payload indexes for filtered search (called separately in Steps 5 and 6)."""
    for field_name in ["molecular_weight", "logp", "toxicity_score"]:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name=field_name,
            field_schema=PayloadSchemaType.FLOAT,
        )
Point IDs are generated using uuid.uuid5 derived from canonical SMILES. Re-indexing the same molecule just overwrites the existing entry. Re-running the indexing script without wiping the collection is safe.
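The idempotency argument rests on uuid5 being a pure function of its inputs, which is quick to verify with the standard library. The helper name `point_id` is made up for this sketch; the namespace value is the one used in the indexer, which happens to be the standard DNS namespace.

```python
import uuid

# Same namespace as in the indexer; identical to uuid.NAMESPACE_DNS.
NAMESPACE = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")


def point_id(canonical_smiles: str) -> str:
    """Deterministic point ID: the same string always maps to the same UUID."""
    return str(uuid.uuid5(NAMESPACE, canonical_smiles))


first = point_id("CCO")
second = point_id("CCO")
other = point_id("OCC")  # same molecule (ethanol), different string

print(first == second)  # True
print(first == other)   # False
```

The second comparison is why canonicalization has to happen before ID generation: uuid5 hashes the string, not the molecule, so two spellings of ethanol would otherwise become two separate points.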
The payload index step matters at scale. Without it, filtered searches slow down noticeably once the collection grows, because Qdrant has to scan payload fields without an index.
Step 4: Similarity Search with Hybrid Ranking
Given a query SMILES, embed it with the same ChemBERTa model and search Qdrant for nearest neighbors.
from __future__ import annotations

import logging

from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
from qdrant_client.models import Filter, FieldCondition, Range

# np, QdrantClient, and COLLECTION_NAME come from the Step 3 module,
# which this function lives in; MoleculeEmbedder is the class from Step 2.
logger = logging.getLogger(__name__)


def search_similar_molecules(
    query_smiles: str,
    embedder: MoleculeEmbedder,
    client: QdrantClient,
    collection_name: str = COLLECTION_NAME,
    top_k: int = 5,
    mw_max: float | None = None,
    logp_max: float | None = None,
    toxicity_max: float | None = None,
) -> list[dict]:
    try:
        embeddings = embedder.embed([query_smiles])
        if np.any(np.isnan(embeddings)):
            raise ValueError("query vector generation failed (nan)")
        query_vector = embeddings[0].tolist()
    except Exception as exc:
        logger.error("Failed to embed query SMILES: %s", exc)
        raise RuntimeError(f"Query embedding failed: {exc}") from exc

    # Build the payload filter from whichever limits were supplied
    conditions = []
    if mw_max is not None:
        conditions.append(FieldCondition(key="molecular_weight", range=Range(lte=mw_max)))
    if logp_max is not None:
        conditions.append(FieldCondition(key="logp", range=Range(lte=logp_max)))
    if toxicity_max is not None:
        conditions.append(FieldCondition(key="toxicity_score", range=Range(lte=toxicity_max)))
    query_filter = Filter(must=conditions) if conditions else None

    fetch_limit = top_k * 5  # fetch wider pool before reranking

    results = None
    for attempt in range(3):
        try:
            results = client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=query_filter,
                limit=fetch_limit,
                with_payload=True,
                timeout=30,
            )
            break
        except Exception as exc:
            if attempt == 2:
                logger.error("Qdrant query failed after 3 attempts: %s", exc)
                return []
    if results is None:
        return []

    query_mol = Chem.MolFromSmiles(query_smiles)
    if query_mol is not None:
        generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
        query_fp = generator.GetFingerprint(query_mol)
    else:
        query_fp = None

    hits = []
    for point in results.points:
        payload = point.payload or {}
        hit_smiles = payload.get("smiles", "")
        tanimoto = 0.0
        if query_fp is not None and hit_smiles:
            hit_mol = Chem.MolFromSmiles(hit_smiles)
            if hit_mol is not None:
                hit_fp = generator.GetFingerprint(hit_mol)
                tanimoto = DataStructs.TanimotoSimilarity(query_fp, hit_fp)

        # TODO: make fusion weights configurable
        fused_score = 0.5 * point.score + 0.5 * tanimoto

        hits.append({
            "smiles": hit_smiles,
            "score": round(point.score, 4),
            "tanimoto_score": round(tanimoto, 4),
            "fused_score": round(fused_score, 4),
            "molecular_weight": payload.get("molecular_weight", 0.0),
            "logp": payload.get("logp", 0.0),
            "toxicity_score": (
                payload.get("toxicity_score")
                if isinstance(payload.get("toxicity_score"), (int, float)) else None
            ),
        })

    hits.sort(key=lambda x: x["fused_score"], reverse=True)
    return hits[:top_k]
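The search retrieves top_k * 5 candidates from Qdrant before reranking. HNSW is approximate, and the fused score can promote a molecule that ranked outside the top k on cosine similarity alone, so fetching a wider pool gives a better shot at returning the actual best matches after fusion.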
The final ranking combines two signals: ChemBERTa cosine similarity (learned structural patterns) and Tanimoto fingerprint score (explicit substructure overlap). The equal 50/50 weight split is a starting point. The TODO in the code is intentional — for any real use case, tune those weights against your own data.
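One way to make the fusion weights configurable is a small standalone reranker that takes the embedding weight as a parameter. This is a sketch, not the repository's code: the function name `fuse_and_rank`, the `w_embedding` parameter, and the candidate scores are all made up for illustration.

```python
def fuse_and_rank(candidates: list[dict], top_k: int, w_embedding: float = 0.5) -> list[dict]:
    """Rerank by a weighted sum of cosine score and Tanimoto score."""
    if not 0.0 <= w_embedding <= 1.0:
        raise ValueError("w_embedding must be between 0 and 1")
    w_tanimoto = 1.0 - w_embedding
    ranked = sorted(
        candidates,
        key=lambda c: w_embedding * c["score"] + w_tanimoto * c["tanimoto_score"],
        reverse=True,
    )
    return ranked[:top_k]


# Hypothetical candidates: "a" wins on embeddings, "b" on fingerprint overlap.
candidates = [
    {"id": "a", "score": 0.95, "tanimoto_score": 0.20},
    {"id": "b", "score": 0.90, "tanimoto_score": 0.60},
    {"id": "c", "score": 0.85, "tanimoto_score": 0.40},
]

balanced = [c["id"] for c in fuse_and_rank(candidates, top_k=3, w_embedding=0.5)]
embedding_only = [c["id"] for c in fuse_and_rank(candidates, top_k=3, w_embedding=1.0)]
print(balanced)        # ['b', 'c', 'a']
print(embedding_only)  # ['a', 'b', 'c']
```

The order flips between the two settings, which is the point: the weight is a real modeling decision, and the only way to pick it is to evaluate against molecules whose relationship you already know.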
Step 5: FastAPI Service
# api_server.py
from __future__ import annotations

import asyncio
import logging
import math
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel, Field
from rdkit import Chem

from molsearch.config import MAX_SMILES_LENGTH
from molsearch.data_loader import load_dataset
from molsearch.embedder import MoleculeEmbedder
from molsearch.molecule_processor import process_smiles_batch
from molsearch.qdrant_indexer import (
    check_system_health, collection_exists_and_populated,
    create_collection, create_payload_indexes,
    get_qdrant_client, search_similar_molecules, upsert_molecules,
)

logger = logging.getLogger(__name__)

_embedder: MoleculeEmbedder | None = None
_client = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global _embedder, _client
    try:
        _embedder = MoleculeEmbedder()
    except Exception as exc:
        logger.error("Failed to load embedder: %s", exc)
        _embedder = None
    try:
        _client = get_qdrant_client()
        create_collection(_client)
        create_payload_indexes(_client)
        if not collection_exists_and_populated(_client):
            if _embedder is not None:
                smiles_list, toxicity_scores = load_dataset()
                molecules = process_smiles_batch(smiles_list, toxicity_scores=toxicity_scores)
                embeddings = _embedder.embed([m["smiles"] for m in molecules])
                upsert_molecules(_client, molecules, embeddings)
                logger.info("Indexed %d molecules.", len(molecules))
        else:
            logger.info("Collection already populated, skipping indexing.")
    except Exception as exc:
        logger.error("Failed to initialize Qdrant: %s", exc)
        _client = None
    yield
    _client = None
    _embedder = None


app = FastAPI(title="Molecule Similarity Search API", lifespan=lifespan)


class SearchRequest(BaseModel):
    smiles: str = Field(..., min_length=1, max_length=MAX_SMILES_LENGTH)
    top_k: int = Field(default=5, ge=1, le=100)
    mw_max: float | None = Field(default=None, ge=0, allow_inf_nan=False)
    logp_max: float | None = Field(default=None, allow_inf_nan=False)
    toxicity_max: float | None = Field(default=None, ge=0, allow_inf_nan=False)


class MoleculeHit(BaseModel):
    smiles: str
    score: float
    molecular_weight: float
    logp: float
    toxicity_score: float | None = None
    tanimoto_score: float | None = None
    fused_score: float | None = None


class SearchResponse(BaseModel):
    query_smiles: str
    canonical_smiles: str
    results: list[MoleculeHit]


@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    if _embedder is None or _client is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    query_smiles = request.smiles.strip()
    if not query_smiles:
        raise HTTPException(status_code=400, detail="empty smiles")

    mol = Chem.MolFromSmiles(query_smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        raise HTTPException(status_code=400, detail="invalid smiles")

    canonical = Chem.MolToSmiles(mol)
    if len(canonical) > 400:
        raise HTTPException(status_code=400, detail={
            "error": "too_long",
            "message": "smiles exceeds length limit",
            "length": len(canonical),
        })

    # Run the blocking embed + search in a worker thread
    loop = asyncio.get_running_loop()
    try:
        hits = await loop.run_in_executor(
            None,
            partial(
                search_similar_molecules,
                query_smiles=canonical,
                embedder=_embedder,
                client=_client,
                top_k=request.top_k,
                mw_max=request.mw_max,
                logp_max=request.logp_max,
                toxicity_max=request.toxicity_max,
            ),
        )
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=f"Search failed: {exc}") from exc

    return SearchResponse(
        query_smiles=request.smiles,
        canonical_smiles=canonical,
        results=[MoleculeHit(**h) for h in hits],
    )


@app.get("/health")
def health(response: Response):
    health_status = check_system_health(_embedder, _client)
    response.status_code = 200 if health_status["status"] == "ok" else 503
    return health_status
ChemBERTa inference is CPU-bound. Running it directly inside the FastAPI endpoint would block the async event loop. run_in_executor hands the transformer inference and Qdrant search to a separate thread, keeping the event loop free for other requests.
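The effect of run_in_executor can be seen in isolation with the standard library alone. In this sketch, `blocking_search` is a stand-in that just sleeps; it is not the real search function, and the tick counter exists only to show that the event loop keeps running while the worker thread is busy.

```python
import asyncio
import time
from functools import partial


def blocking_search(query: str, delay: float = 0.2) -> str:
    """Stand-in for a slow, synchronous call such as model inference."""
    time.sleep(delay)
    return f"results for {query}"


async def main() -> tuple[str, int]:
    loop = asyncio.get_running_loop()
    # Hand the blocking call to the default thread pool
    future = loop.run_in_executor(None, partial(blocking_search, "CCO"))
    ticks = 0
    while not future.done():
        await asyncio.sleep(0.01)  # the loop is free to do other work here
        ticks += 1
    return await future, ticks


result, ticks = asyncio.run(main())
print(result)               # results for CCO
print("loop ticks:", ticks)
```

If `blocking_search` were called directly inside `main`, the tick counter would stay at zero because nothing else could run until it returned. One caveat: threads only help when the blocking code releases the GIL, which a sleep does and which native tensor operations generally do as well, but pure-Python CPU work would not benefit the same way.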
uvicorn molsearch.api_server:app --host 0.0.0.0 --port 8000 --reload
Test it:
curl -X POST http://localhost:8000/search \
-H "Content-Type: application/json" \
-d '{"smiles": "CC(=O)Oc1ccccc1C(=O)O", "top_k": 3}'
Step 6: Streamlit UI
# streamlit_app.py
from __future__ import annotations

import streamlit as st
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw

from molsearch.data_loader import load_dataset
from molsearch.embedder import MoleculeEmbedder
from molsearch.molecule_processor import process_smiles_batch
from molsearch.qdrant_indexer import (
    collection_exists_and_populated, create_collection,
    create_payload_indexes, get_qdrant_client,
    search_similar_molecules, upsert_molecules,
)

st.set_page_config(page_title="Molecule Similarity Search", layout="wide")
st.title("Molecule Similarity Search")
st.caption("Powered by ChemBERTa embeddings and Qdrant vector search")


@st.cache_resource
def load_resources():
    embedder = MoleculeEmbedder()
    client = get_qdrant_client()
    create_collection(client)
    create_payload_indexes(client)
    if not collection_exists_and_populated(client):
        smiles_list, toxicity_scores = load_dataset()
        molecules = process_smiles_batch(smiles_list, toxicity_scores=toxicity_scores)
        embeddings = embedder.embed([m["smiles"] for m in molecules])
        upsert_molecules(client, molecules, embeddings)
    return embedder, client


embedder, client = load_resources()
st.sidebar.header("Search Parameters")
query_smiles = st.sidebar.text_input(
    "SMILES string",
    value="CC(=O)Oc1ccccc1C(=O)O",
    help="Enter a valid SMILES string",
)
top_k = st.sidebar.slider("Number of results", min_value=1, max_value=20, value=5)

# Each filter is None unless its checkbox is ticked.
use_mw_filter = st.sidebar.checkbox("Filter by molecular weight")
mw_filter = (
    st.sidebar.number_input("Max molecular weight", value=500.0, step=50.0, min_value=0.0)
    if use_mw_filter else None
)
use_logp_filter = st.sidebar.checkbox("Filter by LogP")
logp_filter = (
    st.sidebar.number_input("Max LogP", value=5.0, step=0.5)
    if use_logp_filter else None
)
use_tox_filter = st.sidebar.checkbox("Filter by toxicity")
toxicity_filter = (
    st.sidebar.number_input("Max toxicity score", value=0.5, step=0.1, min_value=0.0)
    if use_tox_filter else None
)
search_clicked = st.sidebar.button("Search", type="primary", use_container_width=True)
if search_clicked:
    mol = Chem.MolFromSmiles(query_smiles.strip())
    if mol is None or mol.GetNumAtoms() == 0:
        st.error(f"Invalid SMILES: {query_smiles}")
    else:
        canonical_query = Chem.MolToSmiles(mol)
        col_query, col_info = st.columns([1, 2])
        with col_query:
            st.subheader("Query Molecule")
            st.image(Draw.MolToImage(mol, size=(300, 300)), caption=canonical_query)
        with col_info:
            st.subheader("Query Info")
            st.metric("Molecular Weight", f"{Descriptors.MolWt(mol):.2f}")
            st.metric("LogP", f"{Descriptors.MolLogP(mol):.2f}")
        st.divider()
        results = search_similar_molecules(
            query_smiles=canonical_query, embedder=embedder, client=client,
            top_k=top_k, mw_max=mw_filter, logp_max=logp_filter, toxicity_max=toxicity_filter,
        )
        st.subheader(f"Top {len(results)} Similar Molecules")
        if not results:
            st.info("No results found. Try relaxing the filters.")
        else:
            for i, hit in enumerate(results):
                with st.container():
                    c1, c2 = st.columns([1, 2])
                    with c1:
                        hit_mol = Chem.MolFromSmiles(hit["smiles"])
                        if hit_mol:
                            st.image(Draw.MolToImage(hit_mol, size=(250, 250)))
                    with c2:
                        st.markdown(
                            f"**Rank {i + 1}** | "
                            f"Fused: **{hit['fused_score']}** "
                            f"(Tanimoto: {hit['tanimoto_score']}, Latent: {hit['score']})"
                        )
                        st.code(hit["smiles"], language=None)
                        toxicity_text = (
                            f"{hit['toxicity_score']:.3f}"
                            if isinstance(hit.get("toxicity_score"), (int, float)) else "n/a"
                        )
                        st.write(f"MW: {hit['molecular_weight']} | LogP: {hit['logp']} | Toxicity: {toxicity_text}")
                st.divider()
else:
    st.info("Enter a SMILES string in the sidebar and click Search.")
streamlit run src/molsearch/streamlit_app.py
@st.cache_resource keeps the 400MB ChemBERTa model from reloading on every interaction. Without it, adjusting any slider would reload the entire model. The load_dataset() call inside load_resources() also only runs once per server session.
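Each result card above shows three numbers: a fused score, a Tanimoto score, and the latent (embedding) score. How the project combines them is not shown in this section, so the following is only a sketch of one plausible scheme: a linear blend with a hypothetical weight alpha. Fingerprints are represented as plain sets of on-bit indices rather than RDKit bit vectors, to keep the example dependency-free.

```python
def tanimoto(bits_a: set[int], bits_b: set[int]) -> float:
    """Tanimoto (Jaccard) overlap between two sets of on-bit indices."""
    if not bits_a and not bits_b:
        return 0.0
    return len(bits_a & bits_b) / len(bits_a | bits_b)


def fuse(latent: float, fingerprint: float, alpha: float = 0.5) -> float:
    """Linear blend of the two scores; alpha is an assumed weight, not the project's."""
    return alpha * latent + (1.0 - alpha) * fingerprint


query_bits = {1, 4, 9, 16}
hit_bits = {1, 4, 9, 25, 36}

t = tanimoto(query_bits, hit_bits)        # 3 shared bits out of 6 distinct bits
fused = fuse(latent=0.9, fingerprint=t)   # equal weighting of both signals
print(round(t, 3), round(fused, 3))
```

A blend like this lets a hit rank well if either signal is strong, which is the behavior you want when embeddings and fingerprints disagree.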
End-to-End Standalone Script
If you want to run the full pipeline as a single script, without the modular project structure (it still imports load_dataset from the molsearch package to fetch ZINC-250k):
"""
molecule_search_pipeline.py
End-to-end molecular similarity search.
Usage: python molecule_search_pipeline.py
"""
from __future__ import annotations
import uuid
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"
COLLECTION_NAME = "molecules"
BATCH_SIZE = 32
def process_smiles(smiles_list: list[str]) -> list[dict]:
    """Canonicalize valid SMILES and compute basic RDKit descriptors."""
    results = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # skip strings RDKit cannot parse
        results.append({
            "smiles": Chem.MolToSmiles(mol),
            "molecular_weight": round(Descriptors.MolWt(mol), 2),
            "logp": round(Descriptors.MolLogP(mol), 2),
            "num_h_donors": Descriptors.NumHDonors(mol),
            "num_h_acceptors": Descriptors.NumHAcceptors(mol),
            "tpsa": round(Descriptors.TPSA(mol), 2),
        })
    return results


def load_model(model_name: str = MODEL_NAME):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model, model.config.hidden_size
def embed_smiles(smiles_list, tokenizer, model, vector_dim=768) -> np.ndarray:
    if not smiles_list:
        return np.empty((0, vector_dim), dtype=np.float32)
    all_embeddings = []
    for i in range(0, len(smiles_list), BATCH_SIZE):
        batch = smiles_list[i:i + BATCH_SIZE]
        encoded = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoded)
        # Mean-pool token embeddings, ignoring padding positions.
        mask = encoded["attention_mask"].unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
        summed = torch.sum(outputs.last_hidden_state * mask, dim=1)
        counted = torch.clamp(mask.sum(dim=1), min=1e-9)
        embeddings = summed / counted
        # L2-normalize so every stored vector has unit length.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        all_embeddings.append(embeddings.cpu().numpy())
    return np.vstack(all_embeddings)
def index_molecules(client, molecules, embeddings, vector_dim):
    # Start from a clean collection on every run.
    if client.collection_exists(collection_name=COLLECTION_NAME):
        client.delete_collection(collection_name=COLLECTION_NAME)
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),
    )
    points = [
        PointStruct(
            # Deterministic ID: the same canonical SMILES always maps to the same point.
            id=str(uuid.uuid5(uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8"), mol["smiles"])),
            vector=emb.tolist(),
            payload=mol,
        )
        for mol, emb in zip(molecules, embeddings)
    ]
    client.upsert(collection_name=COLLECTION_NAME, points=points)
    print(f"  Indexed {len(points)} molecules.")
def search(client, query_smiles, tokenizer, model, vector_dim=768, top_k=5):
    # Unfiltered search: the standalone script applies no payload filters.
    mol = Chem.MolFromSmiles(query_smiles)
    if mol is None:
        raise ValueError(f"Invalid query SMILES: {query_smiles}")
    canonical = Chem.MolToSmiles(mol)
    query_vec = embed_smiles([canonical], tokenizer, model, vector_dim)[0].tolist()
    results = client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vec,
        limit=top_k,
        with_payload=True,
    )
    return results.points
def main():
    # The dataset loader is the one piece still imported from the project package.
    from molsearch.data_loader import load_dataset

    print("Loading ZINC-250k dataset...")
    raw_smiles, _ = load_dataset()
    print(f"  {len(raw_smiles)} molecules loaded.\n")

    print("Step 1: Processing SMILES...")
    molecules = process_smiles(raw_smiles)
    print(f"  {len(molecules)} valid molecules.\n")

    print("Step 2: Loading model and generating embeddings...")
    tokenizer, model, vector_dim = load_model()
    embeddings = embed_smiles([m["smiles"] for m in molecules], tokenizer, model, vector_dim)
    print(f"  Embedding matrix shape: {embeddings.shape}\n")

    print("Step 3: Indexing in Qdrant...")
    client = QdrantClient(":memory:")
    index_molecules(client, molecules, embeddings, vector_dim)

    print("\nStep 4: Searching...")
    query = "CC(=O)Oc1ccccc1C(=O)O"  # Aspirin
    print(f"  Query: Aspirin ({query})")
    hits = search(client, query, tokenizer, model, vector_dim, top_k=5)
    print(f"\n  {'Rank':<6}{'SMILES':<50}{'Score':<10}")
    print(f"  {'-' * 66}")
    for i, hit in enumerate(hits, 1):
        print(f"  {i:<6}{hit.payload['smiles']:<50}{hit.score:<10.4f}")


if __name__ == "__main__":
    main()
Expected output:
Step 1: Processing SMILES...
  2000 valid molecules.

Step 2: Loading model and generating embeddings...
  Embedding matrix shape: (2000, 768)

Step 3: Indexing in Qdrant...
  Indexed 2000 molecules.

Step 4: Searching...
  Query: Aspirin (CC(=O)Oc1ccccc1C(=O)O)

  Rank  SMILES                                            Score
  ------------------------------------------------------------------
  1     CC(=O)Oc1ccccc1C(=O)O                             1.0000
  2     OC(=O)c1ccccc1O                                   0.95xx
  3     OC(=O)c1ccccc1                                    0.93xx
  ...
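The 1.0000 score for the query against itself follows from how the vectors are built: embed_smiles L2-normalizes every embedding, and for unit-length vectors cosine similarity is just the dot product. A small standard-library sketch with toy 3-dimensional vectors (the real ones are 768-dimensional) makes the equivalence concrete:

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def cosine(a: list[float], b: list[float]) -> float:
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]
ua, ub = l2_normalize(a), l2_normalize(b)

print(round(cosine(a, b), 4))   # cosine on the raw vectors
print(round(dot(ua, ub), 4))    # plain dot product on the normalized vectors
print(round(dot(ua, ua), 4))    # a vector compared with itself
```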
Scaling Considerations
The example above indexes 2000 molecules. Once the pipeline was working end to end, the obvious question was: what changes at millions?
HNSW Tuning
from qdrant_client.models import (
    Distance, HnswConfigDiff, OptimizersConfigDiff, VectorParams,
)

client.create_collection(
    collection_name="molecules_prod",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    hnsw_config=HnswConfigDiff(
        m=16,              # edges per node (default 16)
        ef_construct=128,  # search depth during construction (default 100)
    ),
    optimizers_config=OptimizersConfigDiff(
        indexing_threshold=20000,  # in KB; triggers around ~6.7k vectors at 768-dim float32
    ),
)
| Parameter | Low (fast, lower recall) | High (slower, higher recall) |
|---|---|---|
| m | 8 | 32 |
| ef_construct | 64 | 256 |
| Query ef | 64 | 256 |
For similarity search, recall matters more than shaving milliseconds. A missed candidate is a missed result.
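The "~6.7k vectors" figure in the indexing_threshold comment is simple arithmetic: the threshold is a size in kilobytes, and a 768-dimensional float32 vector takes 768 × 4 = 3072 bytes. The sketch below assumes 1 KB = 1024 bytes and counts only raw vector data, ignoring payload and index overhead; the helper name is made up for illustration.

```python
def vectors_at_threshold(threshold_kb: int, dim: int, bytes_per_value: int = 4) -> int:
    """Whole vectors that fit in threshold_kb kilobytes (assuming 1 KB = 1024 bytes)."""
    bytes_per_vector = dim * bytes_per_value
    return (threshold_kb * 1024) // bytes_per_vector


n_768 = vectors_at_threshold(20_000, 768)   # the configuration used above
n_384 = vectors_at_threshold(20_000, 384)   # a smaller embedding, for comparison
print(n_768, n_384)
```

Halving the dimension doubles the number of vectors you can hold before indexing starts, so the right threshold depends on the embedding model as much as on the dataset size.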
Set ef at query time:
from qdrant_client import models

results = client.query_points(
    collection_name=COLLECTION_NAME,
    query=query_vec,
    search_params=models.SearchParams(hnsw_ef=128),
    limit=top_k,
    with_payload=True,
)
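To judge whether a given ef is high enough, compare the approximate results against an exact baseline and measure recall@k. The sketch below only does the bookkeeping. The ID lists are invented for illustration; in practice the exact list would come from a brute-force search over the same collection (for example a query run with exact search enabled, if your Qdrant version supports it).

```python
def recall_at_k(exact_ids: list[str], approx_ids: list[str], k: int) -> float:
    """Fraction of the true top-k neighbours that the approximate search returned."""
    truth = set(exact_ids[:k])
    if not truth:
        return 0.0
    return len(truth & set(approx_ids[:k])) / len(truth)


# Hypothetical point IDs, ordered by rank.
exact = ["m1", "m2", "m3", "m4", "m5"]
approx_low_ef = ["m1", "m3", "m7", "m2", "m9"]    # misses m4 and m5
approx_high_ef = ["m1", "m2", "m3", "m5", "m4"]   # same set, different order

low = recall_at_k(exact, approx_low_ef, k=5)
high = recall_at_k(exact, approx_high_ef, k=5)
print(low, high)
```

Recall ignores ordering within the top k, which is usually what you want here because the final ranking is recomputed from the scores anyway.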
Embedding at Scale
For datasets over 100k molecules:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
encoded = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
encoded = {k: v.to(device) for k, v in encoded.items()}
Other optimizations that help: increase the batch size to 128 or 256, use torch.amp.autocast("cuda") for mixed-precision inference, and pre-compute embeddings and store them in Parquet files so you do not rerun the model on every experiment iteration.
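The pre-compute idea boils down to "embed only what is not already on disk". Here is a minimal sketch of that logic. To stay dependency-free it swaps Parquet for a JSON file, and fake_embed is a hypothetical stand-in for the real model; at real scale you would keep the same structure but use a columnar format.

```python
import json
import tempfile
from pathlib import Path


def embed_with_cache(smiles_list, embed_fn, cache_path: Path) -> dict[str, list[float]]:
    """Embed only SMILES missing from the on-disk cache, then persist the cache."""
    cache = json.loads(cache_path.read_text()) if cache_path.exists() else {}
    # dict.fromkeys de-duplicates while preserving order.
    missing = [s for s in dict.fromkeys(smiles_list) if s not in cache]
    if missing:
        for smi, vec in zip(missing, embed_fn(missing)):
            cache[smi] = list(vec)
        cache_path.write_text(json.dumps(cache))
    return {s: cache[s] for s in smiles_list}


calls: list[list[str]] = []


def fake_embed(batch):
    """Hypothetical embedder: records each batch and returns toy 2-d vectors."""
    calls.append(list(batch))
    return [[float(len(s)), 1.0] for s in batch]


cache_file = Path(tempfile.mkdtemp()) / "embeddings.json"
first = embed_with_cache(["CCO", "c1ccccc1"], fake_embed, cache_file)
second = embed_with_cache(["CCO", "CCN"], fake_embed, cache_file)
print(calls)  # the second call only embeds the new molecule
```

Keying the cache on canonical SMILES only works if canonicalization is stable, which is one more reason to pin the RDKit version.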
Persistent Deployment
docker run -p 6333:6333 -v $(pwd)/qdrant_storage:/qdrant/storage qdrant/qdrant
client = QdrantClient(host="localhost", port=6333)
Moving from dev to production is a one-line change: swap QdrantClient(":memory:") for a client pointed at the server. The rest of the code does not change.
Honest Limitations
No experiment writeup is complete without the part where you explain where it breaks.
The toxicity score is not a toxicity model. compute_toxicity_proxy produces a number between 0 and 1 based on six RDKit descriptors. It is useful for demonstrating filtered retrieval. Do not make safety decisions with it.
ChemBERTa operates in 1D and 2D. It reads SMILES strings and learns topological patterns. It has no concept of 3D geometry. If shape complementarity matters for your target, this pipeline alone is not enough. You would need to layer in 3D-aware models or shape-based screening.
Similarity scores are structural estimates. The base model was trained on ZINC using masked language modeling with no bioactivity labels. Two molecules that look close in vector space can behave completely differently in an assay. Vector retrieval is a starting point, not a replacement for experimental validation.
Implementation-level gotchas:
- Canonical SMILES can differ across RDKit versions. Pin your version and keep it consistent between indexing and query environments.
- HNSW is approximate. If you need high recall, increase ef at query time.
- Salt-containing SMILES like [Na+].[Cl-] will be parsed as a whole fragment including the counterion. Add rdkit.Chem.SaltRemover.SaltRemover to preprocessing if that matters for your data.
- The token limit for ChemBERTa is 512. The production implementation rejects SMILES over 400 characters before tokenization. The standalone script uses truncation=True, which silently cuts long inputs. Know which behavior you are running.
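The difference between the two behaviors in the last point is easy to make explicit. This is a sketch of a reject-before-tokenizing guard, assuming the 400-character limit quoted above; the exception class and function name are hypothetical and not taken from the project.

```python
MAX_SMILES_CHARS = 400  # limit quoted above; a character count, not a token count


class SmilesTooLongError(ValueError):
    """Hypothetical error type for over-long inputs."""


def check_length(smiles: str, limit: int = MAX_SMILES_CHARS) -> str:
    """Return the stripped SMILES, or raise instead of letting the tokenizer truncate."""
    cleaned = smiles.strip()
    if len(cleaned) > limit:
        raise SmilesTooLongError(
            f"SMILES has {len(cleaned)} characters (limit {limit})"
        )
    return cleaned


print(check_length(" CC(=O)Oc1ccccc1C(=O)O "))
try:
    check_length("C" * 401)
except SmilesTooLongError as exc:
    print(f"rejected: {exc}")
```

Failing loudly costs one extra check per request, whereas silent truncation means the vector you search with describes a different, shorter string than the one the user submitted.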
Where to Take This Next
Full code: github.com/dvy246/qdrant
Three directions worth thinking about.
Fine-tuning. The base model has no activity labels, so the similarity it measures is purely structural. Fine-tuning ChemBERTa on specific assay data with a contrastive objective — so active molecules pull together in vector space — is the most direct way to make retrieval useful for a specific target. Without that, you are doing structural exploration, not activity prediction.
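The article does not commit to a particular contrastive objective, so take the following as one common option rather than the plan: a triplet margin loss, which penalizes an anchor molecule for sitting closer to an inactive compound than to an active one. The sketch uses plain Python lists and toy 2-d vectors; a real fine-tuning run would compute this on ChemBERTa embeddings in PyTorch.

```python
import math


def euclidean(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin: float = 1.0) -> float:
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


anchor = [0.0, 0.0]
far_point = [3.0, 4.0]    # distance 5 from the anchor
near_point = [0.0, 1.0]   # distance 1 from the anchor

# Active molecule is far away, inactive one is close: large loss, gradients would pull them apart.
bad_layout = triplet_loss(anchor, positive=far_point, negative=near_point)
# Active molecule is already close and the inactive one is beyond the margin: zero loss.
good_layout = triplet_loss(anchor, positive=near_point, negative=far_point)
print(bad_layout, good_layout)
```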
Automated ingestion. Connect Qdrant updates into your compound registration workflow so new molecules get indexed as they come in rather than in periodic batch jobs. The deterministic UUID approach (a uuid5 derived from the canonical SMILES) makes this safe: re-indexing the same molecule just overwrites the existing point.
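The idempotency argument can be checked without a database. In the sketch below a plain dict stands in for the Qdrant collection, and the upsert helper is a toy written for this example. The namespace UUID is the one used in the indexing code, which is the standard DNS namespace from Python's uuid module.

```python
import uuid

# Same namespace as in the indexing code; it equals uuid.NAMESPACE_DNS.
NAMESPACE = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")


def point_id(canonical_smiles: str) -> str:
    """The same canonical SMILES always maps to the same point ID."""
    return str(uuid.uuid5(NAMESPACE, canonical_smiles))


store: dict[str, dict] = {}  # toy stand-in for a Qdrant collection


def upsert(smiles: str, payload: dict) -> None:
    """Insert or overwrite, keyed by the deterministic ID."""
    store[point_id(smiles)] = {"smiles": smiles, **payload}


upsert("CC(=O)Oc1ccccc1C(=O)O", {"molecular_weight": 180.16})
upsert("CC(=O)Oc1ccccc1C(=O)O", {"molecular_weight": 180.16, "source": "re-registration"})
upsert("OC(=O)c1ccccc1O", {"molecular_weight": 138.12})
print(len(store))  # registering aspirin twice still leaves one aspirin entry
```

This only holds if both registrations produce the same canonical SMILES, so the RDKit version pinning mentioned earlier matters here too.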
Target-aware search. Add protein pocket embeddings alongside molecular embeddings so retrieval accounts for both the molecule and the target. That starts to look more like a real virtual screening tool.
Conclusion
The experiment worked, and honestly it worked better than I expected.
The pipeline loads 2000 molecules from ZINC-250k, validates and canonicalizes each SMILES with RDKit (computing a heuristic toxicity score per molecule while doing it), converts each structure into a 768-dimensional vector using ChemBERTa, stores that in Qdrant with molecular property metadata, and retrieves structurally similar candidates using cosine similarity. Filters on molecular weight, LogP, and toxicity score apply during retrieval rather than after it. Final ranking combines embedding similarity with a Tanimoto fingerprint score so you get both the learned structural space and explicit substructure overlap in one result.
This does not replace fingerprint search. For exact scaffold matching, Tanimoto is faster, simpler, and more interpretable. What this gives you is a way to find candidates that share structural and physicochemical patterns even when their fingerprints look different. That is the case where fingerprint search fails.
The hard parts were not the transformer or the vector database. They were the data cleaning, the SMILES canonicalization edge cases, and deciding when to trust the similarity scores. Those are the things worth understanding before scaling this up.
If you are a developer exploring vector search beyond text, molecules turn out to be a surprisingly good playground. The data is structured, the search requirements are concrete, and the gap between what fingerprints can do and what embeddings can do is real enough to make the experiment worthwhile.
