DEV Community

Divy Yadav

I Tried Vector Search on Molecules. Here Is What Actually Happened.

In this article, I'll walk you through how I built a molecular similarity search system with ChemBERTa, RDKit, and Qdrant, and what I actually learned along the way.

Canonical URL: https://medium.com/towards-artificial-intelligence/i-tried-vector-search-on-molecules-heres-what-happened-7391b755efe4


TL;DR
I wanted to see if vector search could work on molecules the same way it works on text. It can. I used ChemBERTa to embed SMILES strings into 768-dim vectors, indexed them in Qdrant with molecular property metadata, and ran similarity search with payload filters applied during retrieval. The system surfaced structurally similar candidates that fingerprint-based search missed. This post walks through every step, including where it breaks.

GitHub: github.com/dvy246/qdrant
Stack: Python 3.10+, RDKit, ChemBERTa, Qdrant, FastAPI, Streamlit


Why I Built This

I had been spending a lot of time with vector databases and embedding-based search. Every example I came across was about text: document search, FAQ retrieval, chatbot memory.

At some point, I asked myself a question that seemed a bit strange at the time: Can I apply the same thing to molecules?

Molecules can be written as text strings. SMILES (Simplified Molecular Input Line Entry System) is a notation that encodes molecular structure as a string. CC(=O)Oc1ccccc1C(=O)O is aspirin. If you can write a molecule as a string, you can feed it to a transformer. And if you can get a dense vector out of a transformer, you can search it with a vector database.

I had been reading about ChemBERTa, a transformer trained specifically on SMILES strings. Qdrant had just shipped a clean Python client with native payload filtering. I thought: what happens if I connect these two things?

So I built a small experiment and ran it.

Here is what the system does:

  • Downloads and caches the ZINC-250k dataset (using the first 2,000 valid drug-like molecules by default)
  • Validates and canonicalizes each SMILES string using RDKit, computing a heuristic toxicity proxy per molecule
  • Embeds each molecule into a 768-dimensional vector using ChemBERTa
  • Indexes those vectors in Qdrant with molecular property metadata as searchable payload
  • Retrieves similar molecules using cosine similarity with optional filters on MW, LogP, and toxicity
  • Serves results through a FastAPI endpoint or a Streamlit UI

The Problem with Fingerprints

Most molecule search pipelines use fingerprint similarity. The basic idea: take a molecule, generate a Morgan fingerprint (a fixed-length bit vector that records which molecular fragments are present), compute a Tanimoto score against a query molecule (overlap between the two bit vectors, 0 to 1), and rank by that score.
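The fingerprint baseline takes only a few lines in RDKit. Below is a minimal sketch that ranks a tiny library against a query by brute force. It assumes Morgan radius 2 with 2,048 bits (roughly ECFP4), which matches the reranker later in this post. `fingerprint` and `rank_by_tanimoto` are illustrative helpers and are not part of the repo.

```python
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator

# Morgan fingerprint, radius 2, folded to 2048 bits (roughly ECFP4)
generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)


def fingerprint(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return generator.GetFingerprint(mol)


def rank_by_tanimoto(query: str, library: list[str]) -> list[tuple[str, float]]:
    """Linear scan: score every library molecule against the query, best first."""
    query_fp = fingerprint(query)
    library_fps = [fingerprint(s) for s in library]
    scores = DataStructs.BulkTanimotoSimilarity(query_fp, library_fps)
    return sorted(zip(library, scores), key=lambda pair: pair[1], reverse=True)


library = [
    "CC(C)Cc1ccc(cc1)C(C)C(=O)O",  # Ibuprofen
    "CC(=O)Nc1ccc(O)cc1",          # Acetaminophen
    "CC(=O)Oc1ccccc1C(=O)O",       # Aspirin
]
for smiles, score in rank_by_tanimoto("CC(=O)Oc1ccccc1C(=O)O", library):
    print(f"{score:.3f}  {smiles}")
```

Every query touches every library fingerprint. That is exactly the linear scan that becomes expensive at scale.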

For obvious structural matches, this works well. Fast, simple, interpretable.

But when I tested it on a small set of compounds, it kept missing molecules that looked chemically related when I inspected them manually. They did not share enough fragments for the score to reflect that relationship, so the system never surfaced them.

*That gap is what pushed me toward trying something different.*

The core issue is that compressing a molecule into a fixed-length bit vector loses information. Different substructures can hash to the same bit position. Once they collide, the distinction between them is gone before you ever run a query. You are searching a lossy representation of the molecule, not the molecule itself.
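To see the folding loss concretely, compare two counts for aspirin: the distinct Morgan environments in the unfolded (sparse) fingerprint, and the bits that remain set after folding into a fixed size. This is an illustrative sketch. The tiny `fpSize` values are chosen to force collisions and are not realistic settings. At 2,048 bits, collisions inside one small molecule are uncommon, but across a large library every bit position ends up shared by many unrelated fragments.

```python
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin

# Unfolded: the sparse count fingerprint keeps each distinct environment identifier
unfolded_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2)
n_environments = len(unfolded_gen.GetSparseCountFingerprint(mol).GetNonzeroElements())


def bits_set(fp_size: int) -> int:
    """Bits switched on after folding the same environments into fp_size bits."""
    folded_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=fp_size)
    return folded_gen.GetFingerprint(mol).GetNumOnBits()


for fp_size in (2048, 64, 8):
    print(f"fpSize={fp_size:>4}: {bits_set(fp_size)} bits set for {n_environments} environments")
```

Whenever the number of bits set is smaller than the number of environments, distinct substructures are sharing a bit and can no longer be told apart.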

This gets particularly visible in what chemists call scaffold hopping: two molecules that act on the same biological target through structurally different cores. Their fingerprints can look completely different even when they are functionally related. Tanimoto has no way to catch that.

There is also the activity cliff problem. Adding a single methyl group (one carbon with three hydrogens) can completely flip how a molecule behaves biologically. To Tanimoto, that change is just a few flipped bits, with no chemical context to say that those particular bits matter.
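You can check the fingerprint side of this directly. In a Morgan fingerprint, adding a methyl usually changes several bits: the new atom's environments plus the shifted environments of its neighbours. None of those bits says anything about biology. This is a minimal sketch, and phenol versus o-cresol is an arbitrary illustrative pair.

```python
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator

generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)


def on_bits(smiles: str) -> set[int]:
    """Indices of the bits set in the folded Morgan fingerprint."""
    return set(generator.GetFingerprint(Chem.MolFromSmiles(smiles)).GetOnBits())


parent = "Oc1ccccc1"       # phenol
methylated = "Cc1ccccc1O"  # o-cresol: one extra methyl group

changed = on_bits(parent) ^ on_bits(methylated)  # bits present in only one of the two
similarity = DataStructs.TanimotoSimilarity(
    generator.GetFingerprint(Chem.MolFromSmiles(parent)),
    generator.GetFingerprint(Chem.MolFromSmiles(methylated)),
)
print(f"{len(changed)} bits differ, Tanimoto = {similarity:.3f}")
```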

And at scale, fingerprint search is typically a brute-force linear scan: every query is compared with every molecule in the library. With millions of compounds and thousands of queries, that is billions of pairwise Tanimoto comparisons, and keeping it fast takes dedicated engineering.

None of this was a deal-breaker for a small experiment. But it was enough to make me curious whether an approach that learned something about molecular structure, rather than just counting fragments, could surface what fingerprints miss.


The Idea: Embeddings Over Fragments

Instead of hashing molecules into fragment bit vectors, I wanted to try converting each molecule into a dense vector that captured its structural and physicochemical character in continuous space. The bet: similar molecules would land close together in that space even when their fingerprints look different.

ChemBERTa handles that. It is a RoBERTa-based transformer pre-trained on SMILES strings from the ZINC database. During training, it learns to predict masked tokens from their surrounding context, the same masked-language-modeling approach BERT uses for text. Because SMILES encodes molecular structure in a consistent grammar, the model picks up recurring structural patterns across its training molecules. It outputs one 768-dimensional vector per token. Pooling those gives one 768-dimensional vector per molecule.

One thing to be clear about: these vectors are structural estimates, not measures of biological activity. The base model was trained without any assay data. You are measuring learned structural closeness, not functional equivalence. That limitation matters and comes up again in the limitations section.

Once molecules are represented as vectors, approximate nearest neighbor search replaces the linear scan. That is what Qdrant handles.


Pipeline Overview


Five stages, each doing one thing:

SMILES strings
    │
    ▼
RDKit validation and canonicalization
    │
    ▼
ChemBERTa embedding (Hugging Face Transformers)
    │
    ▼
Qdrant vector indexing (upsert with metadata payloads)
    │
    ▼
Similarity search API (FastAPI) or Web UI (Streamlit)

Raw SMILES come in first. RDKit validates and standardizes them. Anything that does not parse gets dropped before it reaches the model, and even a curated source contains more invalid entries than you might expect.

Validated SMILES go to ChemBERTa (seyonec/ChemBERTa-zinc-base-v1). Mean pooling on the final hidden state gives a 768-dimensional vector per molecule. That vector, along with metadata like molecular weight and LogP, gets stored as a point in Qdrant. HNSW handles the indexing. At query time, the same model embeds the query, Qdrant searches by cosine similarity, and results can be filtered by molecular properties during retrieval, not after.


Why Qdrant

I had one core requirement: apply numeric filters on molecular properties during retrieval, not after it. That ruled out more options than I expected.

Pinecone was the easiest to get running. Basic vector search worked fine. The problem: combining similarity search with numeric range filters on molecular weight and LogP produced inconsistent behavior.

Milvus gave good raw search performance. The problem: local installation required spinning up etcd and MinIO just to run a dev experiment. Not acceptable for something I was iterating on quickly.

Weaviate had a clean schema system I liked. The problem: chaining multiple numeric range filters with a vector search query felt awkward. The semantics were not intuitive enough for this use case.

Qdrant solved the specific problems the others did not.

The key difference is how it handles filters. In this kind of search you rarely query by vector alone. You want to say: find molecules similar to this one, but only where molecular weight is under 500 and LogP is below 5. Qdrant applies those constraints natively during HNSW graph traversal, not afterwards. Post-filtering sounds like the same thing, but it is not. If your filter is restrictive, you silently end up with fewer than top-k results, because candidates are removed after the search finishes. That was not acceptable for how I needed this to work.
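The failure mode of post-filtering shows up even in a toy. The sketch below is plain Python: exhaustive sorting stands in for HNSW, and the collection is synthetic, with only every tenth point passing the molecular weight filter. It does not model how Qdrant traverses its graph. It only shows why filtering after the top-k cut loses results.

```python
# Synthetic collection of (similarity to query, molecular weight) pairs.
# Similarity falls with the index; only every 10th point has MW <= 200.
collection = [(1.0 - i / 1000, 180.0 if i % 10 == 0 else 450.0) for i in range(1000)]
TOP_K = 10
MW_MAX = 200.0


def post_filter(points, k, mw_max):
    """Take the top-k by similarity first, then drop points that fail the filter."""
    top = sorted(points, key=lambda p: p[0], reverse=True)[:k]
    return [p for p in top if p[1] <= mw_max]


def filter_during_search(points, k, mw_max):
    """Only consider points that pass the filter, then take the top-k."""
    allowed = [p for p in points if p[1] <= mw_max]
    return sorted(allowed, key=lambda p: p[0], reverse=True)[:k]


post = post_filter(collection, TOP_K, MW_MAX)
during = filter_during_search(collection, TOP_K, MW_MAX)
print(f"post-filtering returned {len(post)} of {TOP_K}; filtering during search returned {len(during)}")
```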

The in-memory client (QdrantClient(":memory:")) was a big deal for development. I could build and test the entire pipeline on a laptop without Docker or any external service. Moving to a persistent Qdrant instance later was one line. That made the iteration loop fast.


When to Use Embeddings vs. Fingerprints

If you are working with a well-characterized scaffold series where structure-activity relationships are already mapped, ECFP4 and Tanimoto are probably still the right tool. Fast, interpretable, no model required. Adding an embedding pipeline for that use case adds complexity without giving much back.

Embeddings start to make sense when you are trying to surface candidates that do not share obvious substructure with your query, or when your library is large enough that a linear Tanimoto scan becomes a bottleneck.

ChemBERTa is not the only option. Mol2Vec produces lighter embeddings and runs faster on CPU. GNN-based models can encode 3D molecular geometry, which matters if shape drives binding for your target. I went with ChemBERTa because it takes SMILES as direct input and is well-supported on Hugging Face. For this kind of experiment, that simplicity was worth more than the extra capability of a more complex model.


Similarity Metric

Before indexing, every vector gets L2 normalized. Once vectors are unit length, cosine similarity and dot product give identical results.
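Here is a quick numerical check of that equivalence. The sketch uses random 768-dimensional vectors as stand-ins for ChemBERTa embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=768), rng.normal(size=768)  # stand-ins for two embeddings

# Cosine similarity on the raw vectors
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Plain dot product after L2 normalization
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)
dot_of_unit = a_unit @ b_unit

print(f"cosine={cosine:.6f}  dot(unit vectors)={dot_of_unit:.6f}")
```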

I went with cosine similarity for one practical reason: the scores are easy to read. Cosine scores sit between -1 and 1, so a score of 0.95 versus 0.72 tells you something concrete when you are staring at retrieval results. When you are debugging why certain molecules surface and others do not, readable scores matter.


Environment Setup

Python 3.10+ required. All dependencies from PyPI.

# Core cheminformatics
pip install rdkit

# Transformer framework, tokenizer, and PyTorch
pip install transformers torch

# Vector database client (requires Qdrant server v1.10+ for the Query API)
pip install "qdrant-client>=1.10.0"

# API or UI (pick one or both)
pip install fastapi uvicorn
pip install streamlit

Note: The PyPI package is rdkit, not rdkit-pypi (the older name). The codebase uses X | Y union type syntax, which requires Python 3.10+. ChemBERTa runs fine on CPU for smaller datasets, but GPU inference starts to pay off noticeably above roughly 100k molecules.

Verify before starting:

import rdkit
from rdkit import Chem
from transformers import AutoTokenizer, AutoModel
from qdrant_client import QdrantClient

print(f"RDKit version: {rdkit.__version__}")
print("All imports successful.")

Building It: Step by Step

Step 1: Load the Dataset and Validate SMILES

SMILES strings from real databases are messy. The same molecule can have multiple valid representations. Some entries contain salts or stereoisomer annotations. Others are just invalid. Feeding any of that directly to a transformer produces garbage embeddings or silent failures.

I went with ZINC-250k, a widely used benchmark subset of drug-like molecules from the ZINC database. It comes from the same database ChemBERTa was pre-trained on, which makes it a natural fit.

# data_loader.py
from __future__ import annotations

import csv
import logging
from pathlib import Path

from rdkit import Chem
from molsearch.config import DATA_CACHE_DIR, DATASET_SIZE

logger = logging.getLogger(__name__)

ZINC_250K_URL = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
ZINC_FILENAME = "zinc_250k.csv"

_FALLBACK_SMILES: list[str] = [
    "CC(=O)Oc1ccccc1C(=O)O",          # Aspirin
    "CC(C)Cc1ccc(cc1)C(C)C(=O)O",     # Ibuprofen
    "CC(=O)Nc1ccc(O)cc1",             # Acetaminophen
    "O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl", # Diclofenac
    "COc1ccc2cc(CC(C)C(=O)O)ccc2c1",  # Naproxen
]


def load_dataset(
    max_molecules: int = DATASET_SIZE,
) -> tuple[list[str], list[float | None]]:
    """
    Load molecules from the ZINC-250k dataset.
    Downloads on first call, caches locally for subsequent runs.
    Toxicity scores are None — computed dynamically in molecule_processor.py.
    Falls back to built-in list if download fails.
    """
    cache_dir = Path(DATA_CACHE_DIR)
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / ZINC_FILENAME

    if not cache_path.exists():
        try:
            import urllib.request
            logger.info("Downloading ZINC-250k from %s ...", ZINC_250K_URL)
            urllib.request.urlretrieve(ZINC_250K_URL, str(cache_path))
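        except Exception:
            # Remove any partially written file so the next run retries the download
            cache_path.unlink(missing_ok=True)
            logger.warning("Download failed; using %d fallback molecules", len(_FALLBACK_SMILES))
            smiles = _FALLBACK_SMILES[:max_molecules]
            return smiles, [None] * len(smiles)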

    seen: set[str] = set()
    valid_smiles: list[str] = []

    with open(cache_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            raw = row.get("smiles", "").strip()
            if not raw:
                continue
            mol = Chem.MolFromSmiles(raw)
            if mol is None or mol.GetNumAtoms() == 0:
                continue
            canonical = Chem.MolToSmiles(mol)
            if canonical in seen:
                continue
            seen.add(canonical)
            valid_smiles.append(canonical)
            if len(valid_smiles) >= max_molecules:
                break

    toxicity_scores: list[float | None] = [None] * len(valid_smiles)
    logger.info("Dataset ready: %d molecules", len(valid_smiles))
    return valid_smiles, toxicity_scores

DATASET_SIZE defaults to 2000, controlled via MOLSEARCH_DATASET_SIZE env var. Toxicity scores are all None here because ZINC has no toxicity labels. They get computed per molecule in the validation step.

Validation and canonicalization runs over every SMILES. RDKit checks that the string parses to a valid molecule and converts it to canonical form so there is exactly one text representation per structure regardless of how it was originally written.
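Here is a small illustration of why this step matters. The three strings below are different spellings of aspirin, and all of them collapse to a single canonical string. The exact canonical text can differ between RDKit versions, so compare canonical forms produced by the same version.

```python
from rdkit import Chem

# Three spellings of aspirin: different atom order and ring-closure placement
spellings = [
    "CC(=O)Oc1ccccc1C(=O)O",
    "OC(=O)c1ccccc1OC(C)=O",
    "O=C(O)c1ccccc1OC(=O)C",
]
canonical = {Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in spellings}
print(canonical)
```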

Heads up on stereochemistry: canonicalization only preserves stereochemistry that is explicitly encoded in the input. If a source omits the @/@@ annotations, RDKit has nothing to preserve, so two enantiomers end up with identical canonical SMILES and identical vectors. If chirality matters for your use case, handle it before this step.
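One way to handle it is to flag molecules whose stereocenters are present but unassigned before they reach the embedder. This is a minimal sketch using `Chem.FindMolChiralCenters`. The helper name `has_unassigned_stereocenters` is hypothetical, and alanine is just a convenient example.

```python
from rdkit import Chem


def has_unassigned_stereocenters(smiles: str) -> bool:
    """True if the molecule has a potential stereocenter with no @/@@ given."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    centers = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
    return any(label == "?" for _, label in centers)


print(has_unassigned_stereocenters("CC(N)C(=O)O"))       # alanine, stereo not specified
print(has_unassigned_stereocenters("C[C@@H](N)C(=O)O"))  # alanine, stereo specified

# When annotated, the two enantiomers keep distinct canonical SMILES
form_a = Chem.MolToSmiles(Chem.MolFromSmiles("C[C@@H](N)C(=O)O"))
form_b = Chem.MolToSmiles(Chem.MolFromSmiles("C[C@H](N)C(=O)O"))
print(form_a, form_b)
```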

# molecule_processor.py
from __future__ import annotations

import logging
import math

from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors

logger = logging.getLogger(__name__)


def compute_toxicity_proxy(mol: Chem.Mol) -> float:
    """
    Heuristic toxicity estimate from RDKit descriptors.

    NOT a real toxicity prediction. Rule-based proxy using properties
    commonly associated with toxicity risk. Returns float 0.0–1.0.
    """
    mw     = Descriptors.MolWt(mol)
    logp   = Descriptors.MolLogP(mol)
    hbd    = Descriptors.NumHDonors(mol)
    hba    = Descriptors.NumHAcceptors(mol)
    tpsa   = Descriptors.TPSA(mol)
    n_arom = rdMolDescriptors.CalcNumAromaticRings(mol)

    mw_score   = max(0.0, min((mw - 350)  / 450, 1.0))
    logp_score = max(0.0, min((logp - 3)  / 5,   1.0))
    hbd_score  = max(0.0, min((hbd - 2)   / 5,   1.0))
    hba_score  = max(0.0, min((hba - 5)   / 10,  1.0))
    arom_score = min(n_arom / 6, 1.0)
    tpsa_score = max(0.0, min((75 - tpsa) / 75,  1.0))

    raw = (
        0.25 * mw_score
        + 0.25 * logp_score
        + 0.15 * hbd_score
        + 0.15 * hba_score
        + 0.10 * arom_score
        + 0.10 * tpsa_score
    )
    return round(max(0.0, min(raw, 1.0)), 3)


def validate_and_canonicalize(
    smiles: str,
    toxicity_score: float | None = None,
) -> dict | None:
    """
    Validate a SMILES string and return canonical form with descriptors.
    Returns None if the SMILES is invalid.
    """
    normalized = smiles.strip()
    if not normalized:
        return None

    mol = Chem.MolFromSmiles(normalized)
    if mol is None:
        return None

    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0 or mol.GetNumBonds() == 0:
        return None

    canonical_smiles = Chem.MolToSmiles(mol)

    # Round-trip check: re-parse canonical SMILES and verify atom count
    verify_mol = Chem.MolFromSmiles(canonical_smiles)
    if verify_mol is None or verify_mol.GetNumAtoms() != num_atoms:
        return None

    payload = {
        "smiles":            canonical_smiles,
        "molecular_weight":  round(Descriptors.MolWt(mol), 2),
        "logp":              round(Descriptors.MolLogP(mol), 2),
        "num_h_donors":      Descriptors.NumHDonors(mol),
        "num_h_acceptors":   Descriptors.NumHAcceptors(mol),
        "tpsa":              round(Descriptors.TPSA(mol), 2),
    }

    if toxicity_score is not None:
        if not math.isfinite(toxicity_score):
            raise ValueError("toxicity_score must be a finite float")
        payload["toxicity_score"] = float(toxicity_score)
    else:
        payload["toxicity_score"] = compute_toxicity_proxy(mol)

    return payload

Two things worth calling out. When toxicity_score is None (always the case for ZINC data), the function falls through to compute_toxicity_proxy. Every molecule in the index ends up with a toxicity_score in its payload, so filtered searches on that field always have something to work with. The proxy is a weighted combination of six RDKit descriptors. It is a deterministic heuristic, not a toxicity model.
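Once the descriptors are known, the proxy is plain arithmetic. The sketch below pulls the same weights and thresholds out of `compute_toxicity_proxy` into a hypothetical helper, `toxicity_proxy_from_descriptors`, which is not in the repo. It then evaluates the helper on made-up descriptor values so each term can be checked by hand.

```python
def clamp01(x: float) -> float:
    return max(0.0, min(x, 1.0))


def toxicity_proxy_from_descriptors(
    mw: float, logp: float, hbd: int, hba: int, tpsa: float, n_arom: int
) -> float:
    """Same weighting as compute_toxicity_proxy, without the RDKit calls."""
    raw = (
        0.25 * clamp01((mw - 350) / 450)     # heavier than 350 Da
        + 0.25 * clamp01((logp - 3) / 5)     # more lipophilic than LogP 3
        + 0.15 * clamp01((hbd - 2) / 5)      # more than 2 H-bond donors
        + 0.15 * clamp01((hba - 5) / 10)     # more than 5 H-bond acceptors
        + 0.10 * min(n_arom / 6, 1.0)        # aromatic ring count
        + 0.10 * clamp01((75 - tpsa) / 75)   # low polar surface area
    )
    return round(clamp01(raw), 3)


# Made-up descriptors: MW 500, LogP 5, 1 donor, 6 acceptors, TPSA 60, 2 aromatic rings
# 0.25*(150/450) + 0.25*0.4 + 0 + 0.15*0.1 + 0.10*(2/6) + 0.10*0.2 ≈ 0.2517
print(toxicity_proxy_from_descriptors(mw=500, logp=5.0, hbd=1, hba=6, tpsa=60, n_arom=2))
```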

The round-trip check after canonicalization catches edge cases where RDKit generates a canonical form it then cannot parse back. Rare, but it happens.


Step 2: Generate Molecular Embeddings with ChemBERTa

ChemBERTa is a RoBERTa model pre-trained on SMILES strings from ZINC. Checkpoint: seyonec/ChemBERTa-zinc-base-v1 on Hugging Face.

The model does not come with a dedicated pooling layer. To get a single vector per molecule, you pool the token-level outputs yourself. I use mean pooling: averaging the output vectors across all non-padding tokens. This is a common choice for RoBERTa-based models because the first token (RoBERTa's <s>, the equivalent of BERT's [CLS], a special token prepended to every input) was not trained with a sentence-level objective, so using that token alone tends to give weaker representations. Averaging across all tokens works better.
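Here is the masking logic on a toy tensor. This is a NumPy sketch of the same computation `_mean_pool` performs with PyTorch. The batch has one sequence of four positions with hidden size 3, and the last position is padding.

```python
import numpy as np

# Toy output: batch of 1, 4 token positions, hidden size 3; the last position is padding
hidden_states = np.array([[[1.0, 2.0, 3.0],
                           [3.0, 4.0, 5.0],
                           [5.0, 6.0, 7.0],
                           [99.0, 99.0, 99.0]]])
attention_mask = np.array([[1, 1, 1, 0]])

mask = attention_mask[..., None].astype(hidden_states.dtype)  # shape (batch, seq, 1)
summed = (hidden_states * mask).sum(axis=1)                    # padding contributes 0
counts = np.clip(mask.sum(axis=1), 1e-9, None)                 # real tokens per sequence
mean_pooled = summed / counts

first_token = hidden_states[:, 0, :]  # what "use the first token" would return instead

print(mean_pooled)  # averages only the three real tokens
print(first_token)
```

Without the mask, the padding row would drag the average toward 99. That is why the attention mask has to be part of the pooling, not just the forward pass.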

# embedder.py
from __future__ import annotations

import logging
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"
VECTOR_DIM  = 768
BATCH_SIZE  = 32


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class MoleculeEmbedder:
    """Generates L2-normalized ChemBERTa embeddings from SMILES."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.device    = _get_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model     = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        self.vector_dim = self.model.config.hidden_size
        if self.vector_dim != VECTOR_DIM:
            raise ValueError(
                f"Model hidden size ({self.vector_dim}) does not match "
                f"VECTOR_DIM ({VECTOR_DIM})"
            )
        logger.info("Loaded %s on %s", model_name, self.device)

    def _mean_pool(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask   = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        return sum_hidden / sum_mask

    def embed(
        self,
        smiles_list: list[str],
        batch_size: int = BATCH_SIZE,
    ) -> np.ndarray:
        """
        Embed SMILES into dense vectors.
        Returns (n, vector_dim) float32 array, L2-normalized.
        Oversized SMILES (>400 chars) are skipped per-molecule and filled with np.nan.
        """
        n = len(smiles_list)
        if n == 0:
            return np.empty((0, self.vector_dim), dtype=np.float32)

        result = np.empty((n, self.vector_dim), dtype=np.float32)

        for start in range(0, n, batch_size):
            end   = min(start + batch_size, n)
            batch = smiles_list[start:end]

            # Handle oversized SMILES per-molecule, not per-batch
            oversized = [i for i, s in enumerate(batch) if len(s) > 400]
            if oversized:
                for idx in oversized:
                    logger.warning(
                        "SMILES at index %d too long (%d chars), skipping",
                        start + idx, len(batch[idx]),
                    )
                    result[start + idx] = np.nan
                safe_indices = [i for i in range(len(batch)) if i not in oversized]
                if not safe_indices:
                    continue
                batch = [batch[i] for i in safe_indices]
            else:
                safe_indices = list(range(len(batch)))

            encoded = self.tokenizer(
                batch, padding=True, truncation=False, return_tensors="pt"
            )
            if encoded["input_ids"].shape[1] > 512:
                logger.error("Batch %d–%d: context limit exceeded", start, end)
                result[start:end] = np.nan
                continue

            encoded = {k: v.to(self.device) for k, v in encoded.items()}
            with torch.no_grad():
                outputs = self.model(**encoded)

            embeddings = self._mean_pool(
                outputs.last_hidden_state, encoded["attention_mask"]
            )
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            batch_result = embeddings.cpu().numpy()

            for out_i, safe_i in enumerate(safe_indices):
                result[start + safe_i] = batch_result[out_i]

        return result

A few implementation details worth knowing.

truncation=False is deliberate. Silently truncating a SMILES string produces an embedding that represents only part of the molecule, which is worse than no embedding at all. Instead, any SMILES over 400 characters is flagged and skipped per molecule. If the tokenized batch still exceeds 512 tokens after that, the whole batch is rejected. Failed slots are filled with np.nan so downstream code can detect the problem instead of silently working with bad data.

The result array is pre-allocated with np.empty instead of collecting batch results in a list and stacking them at the end. For large datasets, np.vstack temporarily doubles memory because it copies everything into a new contiguous block. Pre-allocation avoids that.

CPU performance: expect somewhere around 100–300 molecules per second depending on SMILES length. Above 100k molecules, GPU inference saves a meaningful amount of time.


Step 3: Index Embeddings in Qdrant

With valid molecules and their embeddings ready, the next step is storing them in Qdrant. Each point holds a vector and a JSON payload with canonical SMILES and all the descriptors from Step 1.

# qdrant_indexer.py
from __future__ import annotations

import uuid
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

COLLECTION_NAME    = "molecules"
VECTOR_DIM         = 768
UPSERT_BATCH_SIZE  = 1000
USE_PERSISTENT_QDRANT = False
QDRANT_HOST        = "localhost"
QDRANT_PORT        = 6333


def get_qdrant_client() -> QdrantClient:
    """In-memory for development, persistent for production."""
    if USE_PERSISTENT_QDRANT:
        client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=10)
        client.get_collections()  # verify connection
        return client
    return QdrantClient(":memory:")


def create_collection(client: QdrantClient) -> None:
    """Non-destructive: reuse existing collection if vector size matches."""
    if client.collection_exists(collection_name=COLLECTION_NAME):
        info = client.get_collection(collection_name=COLLECTION_NAME)
        size = getattr(info.config.params.vectors, "size", None)
        if size is not None and size != VECTOR_DIM:
            raise ValueError(
                f"Existing collection has vector size {size}, expected {VECTOR_DIM}"
            )
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
    )


def _smiles_to_uuid(smiles: str) -> str:
    """Deterministic UUID from canonical SMILES. Safe for incremental upserts."""
    namespace = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
    return str(uuid.uuid5(namespace, smiles))


def upsert_molecules(
    client: QdrantClient,
    molecules: list[dict],
    embeddings: np.ndarray,
    batch_size: int = UPSERT_BATCH_SIZE,
) -> None:
    """Upsert molecule vectors and payloads in batches. Skips nan/inf/zero vectors."""
    n = len(molecules)
    valid_mask = (
        ~np.isnan(embeddings).any(axis=1)
        & ~np.isinf(embeddings).any(axis=1)
    )
    valid_mask &= np.linalg.norm(embeddings, axis=1) > 0.0

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        points = [
            PointStruct(
                id=_smiles_to_uuid(mol["smiles"]),
                vector=emb.tolist(),
                payload=mol,
            )
            for i, (mol, emb) in enumerate(
                zip(molecules[start:end], embeddings[start:end]), start=start
            )
            if valid_mask[i]
        ]
        if points:
            client.upsert(collection_name=COLLECTION_NAME, points=points)


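# Payload indexes on the numeric fields used in filtered search
from qdrant_client.models import PayloadSchemaType


def create_payload_indexes(client: QdrantClient) -> None:
    """Create float payload indexes for the fields that search filters use."""
    for field_name in ["molecular_weight", "logp", "toxicity_score"]:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name=field_name,
            field_schema=PayloadSchemaType.FLOAT,
        )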

Point IDs are generated using uuid.uuid5 derived from canonical SMILES. Re-indexing the same molecule just overwrites the existing entry. Re-running the indexing script without wiping the collection is safe.
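There is one catch: the ID is only stable if the SMILES is canonicalized first. The sketch below reproduces `_smiles_to_uuid` with the same namespace, which happens to be the standard `uuid.NAMESPACE_DNS`. It shows that two spellings of aspirin would get different IDs and therefore become two separate points.

```python
import uuid

NAMESPACE = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")  # same value as uuid.NAMESPACE_DNS


def smiles_to_uuid(smiles: str) -> str:
    """Deterministic ID: identical input strings always map to the same UUID."""
    return str(uuid.uuid5(NAMESPACE, smiles))


first = smiles_to_uuid("CC(=O)Oc1ccccc1C(=O)O")
again = smiles_to_uuid("CC(=O)Oc1ccccc1C(=O)O")
other_spelling = smiles_to_uuid("OC(=O)c1ccccc1OC(C)=O")  # aspirin, written differently

print(first == again)           # same string, same ID
print(first == other_spelling)  # same molecule, different string, different ID
```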

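The payload index step matters at scale. Without it, filtered searches slow down noticeably as the collection grows, because Qdrant has to check payload values without an index to narrow the candidates. In local in-memory mode the client does not use payload indexes, so the benefit only shows up on a Qdrant server.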


Step 4: Similarity Search with Hybrid Ranking

Given a query SMILES, embed it with the same ChemBERTa model and search Qdrant for nearest neighbors.

# qdrant_indexer.py (continued): same module as Step 3, so np, QdrantClient,
# COLLECTION_NAME and the `from __future__ import annotations` at the top of
# that file are already in scope.
import logging
import time

from qdrant_client.models import FieldCondition, Filter, Range
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator

from molsearch.embedder import MoleculeEmbedder

logger = logging.getLogger(__name__)


def search_similar_molecules(
    query_smiles: str,
    embedder: MoleculeEmbedder,
    client: QdrantClient,
    collection_name: str = COLLECTION_NAME,
    top_k: int = 5,
    mw_max: float | None = None,
    logp_max: float | None = None,
    toxicity_max: float | None = None,
) -> list[dict]:
    # Canonicalize the query exactly like the indexed molecules; otherwise a
    # differently written SMILES for the same molecule embeds to a different vector.
    query_mol = Chem.MolFromSmiles(query_smiles.strip())
    if query_mol is None:
        raise ValueError(f"Invalid query SMILES: {query_smiles!r}")
    canonical_query = Chem.MolToSmiles(query_mol)

    try:
        embeddings = embedder.embed([canonical_query])
        if np.any(np.isnan(embeddings)):
            raise ValueError("query vector generation failed (nan)")
        query_vector = embeddings[0].tolist()
    except Exception as exc:
        logger.error("Failed to embed query SMILES: %s", exc)
        raise RuntimeError(f"Query embedding failed: {exc}") from exc

    # Property filters are sent to Qdrant and applied during the search itself
    conditions = []
    if mw_max is not None:
        conditions.append(FieldCondition(key="molecular_weight", range=Range(lte=mw_max)))
    if logp_max is not None:
        conditions.append(FieldCondition(key="logp", range=Range(lte=logp_max)))
    if toxicity_max is not None:
        conditions.append(FieldCondition(key="toxicity_score", range=Range(lte=toxicity_max)))

    query_filter = Filter(must=conditions) if conditions else None
    fetch_limit  = top_k * 5  # fetch wider pool before reranking

    results = None
    for attempt in range(3):
        try:
            results = client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=query_filter,
                limit=fetch_limit,
                with_payload=True,
                timeout=30,
            )
            break
        except Exception as exc:
            if attempt == 2:
                logger.error("Qdrant query failed after 3 attempts: %s", exc)
                return []
            logger.warning("Qdrant query attempt %d failed (%s); retrying", attempt + 1, exc)
            time.sleep(2 ** attempt)  # back off 1 s, then 2 s

    if results is None:
        return []

    # Morgan radius 2, 2048 bits (ECFP4-like) for the Tanimoto half of the score
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
    query_fp  = generator.GetFingerprint(query_mol)

    hits = []
    for point in results.points:
        payload    = point.payload or {}
        hit_smiles = payload.get("smiles", "")
        tanimoto   = 0.0

        if hit_smiles:
            hit_mol = Chem.MolFromSmiles(hit_smiles)
            if hit_mol is not None:
                hit_fp   = generator.GetFingerprint(hit_mol)
                tanimoto = DataStructs.TanimotoSimilarity(query_fp, hit_fp)

        # TODO: make fusion weights configurable
        fused_score = 0.5 * point.score + 0.5 * tanimoto

        hits.append({
            "smiles":           hit_smiles,
            "score":            round(point.score, 4),
            "tanimoto_score":   round(tanimoto, 4),
            "fused_score":      round(fused_score, 4),
            "molecular_weight": payload.get("molecular_weight", 0.0),
            "logp":             payload.get("logp", 0.0),
            "toxicity_score":   (
                payload.get("toxicity_score")
                if isinstance(payload.get("toxicity_score"), (int, float)) else None
            ),
        })

    hits.sort(key=lambda x: x["fused_score"], reverse=True)
    return hits[:top_k]

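The search retrieves top_k * 5 candidates from Qdrant before reranking. The fused score can reorder candidates, so a molecule just outside the cosine top_k may move into the final results once its Tanimoto score is added. Fetching a wider pool gives it that chance. Because HNSW is approximate, the extra headroom also improves the odds of returning the actual best matches after fusion.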

The final ranking combines two signals: ChemBERTa cosine similarity (learned structural patterns) and Tanimoto fingerprint score (explicit substructure overlap). The equal 50/50 weight split is a starting point. The TODO in the code is intentional: for any real use case, tune those weights against your own data.
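Making the weight a parameter is a small change. The sketch below is a standalone linear fusion with a hypothetical `alpha` (1.0 means embedding only, 0.0 means fingerprint only). The `Candidate` records and their scores are made up to show how the weight reorders a pool.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    smiles: str
    cosine: float    # embedding similarity returned by the vector search
    tanimoto: float  # fingerprint similarity computed at rerank time


def rerank(candidates: list[Candidate], top_k: int, alpha: float = 0.5) -> list[tuple[str, float]]:
    """Linear fusion alpha * cosine + (1 - alpha) * tanimoto, best top_k first."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    scored = [(c.smiles, alpha * c.cosine + (1 - alpha) * c.tanimoto) for c in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


pool = [
    Candidate("A", cosine=0.97, tanimoto=0.20),
    Candidate("B", cosine=0.93, tanimoto=0.70),
    Candidate("C", cosine=0.90, tanimoto=0.10),
]
print(rerank(pool, top_k=2, alpha=1.0))  # embedding only
print(rerank(pool, top_k=2, alpha=0.5))  # equal weights, as in the code above
```

The two signals are on different scales: cosine lies in [-1, 1], and Tanimoto in [0, 1]. Whatever weight you settle on, it is worth checking against labelled pairs from your own data rather than picking it by eye.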


Step 5: FastAPI Service

# api_server.py
from __future__ import annotations

import asyncio
import logging
import math
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel, Field
from rdkit import Chem

from molsearch.config import MAX_SMILES_LENGTH
from molsearch.data_loader import load_dataset
from molsearch.embedder import MoleculeEmbedder
from molsearch.molecule_processor import process_smiles_batch
from molsearch.qdrant_indexer import (
    check_system_health, collection_exists_and_populated,
    create_collection, create_payload_indexes,
    get_qdrant_client, search_similar_molecules, upsert_molecules,
)

logger = logging.getLogger(__name__)

_embedder: MoleculeEmbedder | None = None
_client = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global _embedder, _client
    try:
        _embedder = MoleculeEmbedder()
    except Exception as exc:
        logger.error("Failed to load embedder: %s", exc)
        _embedder = None

    try:
        _client = get_qdrant_client()
        create_collection(_client)
        create_payload_indexes(_client)
        if not collection_exists_and_populated(_client):
            if _embedder is not None:
                smiles_list, toxicity_scores = load_dataset()
                molecules   = process_smiles_batch(smiles_list, toxicity_scores=toxicity_scores)
                embeddings  = _embedder.embed([m["smiles"] for m in molecules])
                upsert_molecules(_client, molecules, embeddings)
                logger.info("Indexed %d molecules.", len(molecules))
        else:
            logger.info("Collection already populated — skipping indexing.")
    except Exception as exc:
        logger.error("Failed to initialize Qdrant: %s", exc)
        _client = None

    yield
    _client   = None
    _embedder = None


app = FastAPI(title="Molecule Similarity Search API", lifespan=lifespan)


class SearchRequest(BaseModel):
    smiles:       str         = Field(..., min_length=1, max_length=MAX_SMILES_LENGTH)
    top_k:        int         = Field(default=5, ge=1, le=100)
    mw_max:       float | None = Field(default=None, ge=0, allow_inf_nan=False)
    logp_max:     float | None = Field(default=None, allow_inf_nan=False)
    toxicity_max: float | None = Field(default=None, ge=0, allow_inf_nan=False)


class MoleculeHit(BaseModel):
    smiles:           str
    score:            float
    molecular_weight: float
    logp:             float
    toxicity_score:   float | None = None
    tanimoto_score:   float | None = None
    fused_score:      float | None = None


class SearchResponse(BaseModel):
    query_smiles:    str
    canonical_smiles: str
    results:         list[MoleculeHit]


@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    if _embedder is None or _client is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    query_smiles = request.smiles.strip()
    if not query_smiles:
        raise HTTPException(status_code=400, detail="empty smiles")

    mol = Chem.MolFromSmiles(query_smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        raise HTTPException(status_code=400, detail="invalid smiles")

    canonical = Chem.MolToSmiles(mol)
    if len(canonical) > 400:
        raise HTTPException(status_code=400, detail={
            "error":   "too_long",
            "message": "smiles exceeds length limit",
            "length":  len(canonical),
        })

    loop = asyncio.get_running_loop()
    try:
        hits = await loop.run_in_executor(
            None,
            partial(
                search_similar_molecules,
                query_smiles=canonical,
                embedder=_embedder,
                client=_client,
                top_k=request.top_k,
                mw_max=request.mw_max,
                logp_max=request.logp_max,
                toxicity_max=request.toxicity_max,
            ),
        )
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=f"Search failed: {exc}") from exc

    return SearchResponse(
        query_smiles=request.smiles,
        canonical_smiles=canonical,
        results=[MoleculeHit(**h) for h in hits],
    )


@app.get("/health")
def health(response: Response):
    health_status = check_system_health(_embedder, _client)
    response.status_code = 200 if health_status["status"] == "ok" else 503
    return health_status

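ChemBERTa inference and the Qdrant client call are both synchronous, blocking work (and CPU-bound on a machine without a GPU). Calling them directly inside the async endpoint would stall the event loop, so every other request would wait. run_in_executor(None, ...) hands the transformer inference and Qdrant search to the loop's default thread pool, keeping the event loop free for other requests.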
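A stdlib-only sketch of why the executor matters. `blocking_inference` is a hypothetical stand-in for `search_similar_molecules` (it just sleeps), and `heartbeat` represents other work the event loop should keep serving. Because the blocking call runs in the default thread pool, the heartbeat keeps ticking while the request waits; calling `blocking_inference` directly inside the coroutine would freeze the loop for the whole duration.

```python
from __future__ import annotations

import asyncio
import time
from functools import partial


def blocking_inference(smiles: str, delay: float) -> str:
    """Hypothetical stand-in for model inference plus a synchronous Qdrant call."""
    time.sleep(delay)  # blocks the calling thread, like a synchronous forward pass
    return f"results for {smiles}"


async def heartbeat(stop: asyncio.Event, interval: float) -> int:
    """Counts how many times the event loop got to run this coroutine."""
    ticks = 0
    while not stop.is_set():
        ticks += 1
        await asyncio.sleep(interval)
    return ticks


async def handle_request(smiles: str, delay: float = 0.3) -> tuple[str, int]:
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()
    beat = asyncio.create_task(heartbeat(stop, interval=0.02))
    # None selects the loop's default ThreadPoolExecutor, as in the endpoint.
    result = await loop.run_in_executor(None, partial(blocking_inference, smiles, delay))
    stop.set()
    ticks = await beat
    return result, ticks


if __name__ == "__main__":
    result, ticks = asyncio.run(handle_request("CC(=O)Oc1ccccc1C(=O)O"))
    print(result, "| heartbeat ticks while waiting:", ticks)
```

The exact tick count varies with scheduling; what matters is that it is well above zero while the blocking call is in flight.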

uvicorn molsearch.api_server:app --host 0.0.0.0 --port 8000 --reload

Test it:

curl -X POST http://localhost:8000/search \
  -H "Content-Type: application/json" \
  -d '{"smiles": "CC(=O)Oc1ccccc1C(=O)O", "top_k": 3}'

Step 6: Streamlit UI

# streamlit_app.py
from __future__ import annotations

import streamlit as st
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw

from molsearch.data_loader import load_dataset
from molsearch.embedder import MoleculeEmbedder
from molsearch.molecule_processor import process_smiles_batch
from molsearch.qdrant_indexer import (
    collection_exists_and_populated, create_collection,
    create_payload_indexes, get_qdrant_client,
    search_similar_molecules, upsert_molecules,
)

st.set_page_config(page_title="Molecule Similarity Search", layout="wide")
st.title("Molecule Similarity Search")
st.caption("Powered by ChemBERTa embeddings and Qdrant vector search")


@st.cache_resource
def load_resources():
    embedder = MoleculeEmbedder()
    client   = get_qdrant_client()
    create_collection(client)
    create_payload_indexes(client)
    if not collection_exists_and_populated(client):
        smiles_list, toxicity_scores = load_dataset()
        molecules  = process_smiles_batch(smiles_list, toxicity_scores=toxicity_scores)
        embeddings = embedder.embed([m["smiles"] for m in molecules])
        upsert_molecules(client, molecules, embeddings)
    return embedder, client


embedder, client = load_resources()

st.sidebar.header("Search Parameters")
query_smiles = st.sidebar.text_input(
    "SMILES string",
    value="CC(=O)Oc1ccccc1C(=O)O",
    help="Enter a valid SMILES string",
)
top_k = st.sidebar.slider("Number of results", min_value=1, max_value=20, value=5)

use_mw_filter = st.sidebar.checkbox("Filter by molecular weight")
mw_filter     = (
    st.sidebar.number_input("Max molecular weight", value=500.0, step=50.0, min_value=0.0)
    if use_mw_filter else None
)

use_logp_filter = st.sidebar.checkbox("Filter by LogP")
logp_filter     = (
    st.sidebar.number_input("Max LogP", value=5.0, step=0.5)
    if use_logp_filter else None
)

use_tox_filter  = st.sidebar.checkbox("Filter by toxicity")
toxicity_filter = (
    st.sidebar.number_input("Max toxicity score", value=0.5, step=0.1, min_value=0.0)
    if use_tox_filter else None
)

search_clicked = st.sidebar.button("Search", type="primary", use_container_width=True)

if search_clicked:
    mol = Chem.MolFromSmiles(query_smiles.strip())
    if mol is None or mol.GetNumAtoms() == 0:
        st.error(f"Invalid SMILES: {query_smiles}")
    else:
        canonical_query = Chem.MolToSmiles(mol)
        col_query, col_info = st.columns([1, 2])
        with col_query:
            st.subheader("Query Molecule")
            st.image(Draw.MolToImage(mol, size=(300, 300)), caption=canonical_query)
        with col_info:
            st.subheader("Query Info")
            st.metric("Molecular Weight", f"{Descriptors.MolWt(mol):.2f}")
            st.metric("LogP", f"{Descriptors.MolLogP(mol):.2f}")
        st.divider()

        results = search_similar_molecules(
            query_smiles=canonical_query, embedder=embedder, client=client,
            top_k=top_k, mw_max=mw_filter, logp_max=logp_filter, toxicity_max=toxicity_filter,
        )
        st.subheader(f"Top {len(results)} Similar Molecules")

        if not results:
            st.info("No results found. Try relaxing the filters.")
        else:
            for i, hit in enumerate(results):
                with st.container():
                    c1, c2 = st.columns([1, 2])
                    with c1:
                        hit_mol = Chem.MolFromSmiles(hit["smiles"])
                        if hit_mol:
                            st.image(Draw.MolToImage(hit_mol, size=(250, 250)))
                    with c2:
                        st.markdown(
                            f"**Rank {i + 1}** | "
                            f"Fused: **{hit['fused_score']}** "
                            f"(Tanimoto: {hit['tanimoto_score']}, Latent: {hit['score']})"
                        )
                        st.code(hit["smiles"], language=None)
                        toxicity_text = (
                            f"{hit['toxicity_score']:.3f}"
                            if isinstance(hit.get("toxicity_score"), (int, float)) else "n/a"
                        )
                        st.write(f"MW: {hit['molecular_weight']} | LogP: {hit['logp']} | Toxicity: {toxicity_text}")
                    st.divider()
else:
    st.info("Enter a SMILES string in the sidebar and click Search.")
streamlit run src/molsearch/streamlit_app.py

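Streamlit reruns the entire script on every widget interaction, so without caching, adjusting any slider would reload the whole model. @st.cache_resource keeps the 400MB ChemBERTa model (and the Qdrant client) in memory across reruns. It also means the load_dataset() call inside load_resources() runs at most once per server process, and only when the collection is empty.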
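As a rough stdlib analogy (not Streamlit's implementation), `functools.cache` on a zero-argument loader gives the same build-once behavior across simulated reruns. `load_resources`, `rerun_script`, and `LOAD_CALLS` here are hypothetical names for illustration.

```python
from __future__ import annotations

import functools

LOAD_CALLS = 0  # counts how often the expensive loader actually runs


@functools.cache
def load_resources() -> dict:
    """Hypothetical loader standing in for model download plus Qdrant setup."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return {"embedder": object(), "client": object()}


def rerun_script(top_k: int) -> dict:
    """Simulates one script rerun triggered by moving a widget."""
    resources = load_resources()  # cache hit after the first call
    return {"top_k": top_k, **resources}


runs = [rerun_script(k) for k in (5, 10, 3)]
print(LOAD_CALLS)                                  # 1
print(runs[0]["embedder"] is runs[2]["embedder"])  # True
```

One difference worth knowing: st.cache_resource returns the same object to every user session on the server, so whatever it caches (here the embedder and client) must be safe to share between concurrent sessions.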


End-to-End Standalone Script

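If you want to run the full pipeline without the modular project structure, the script below keeps everything in one file except the dataset loader: main() still imports load_dataset() from molsearch.data_loader, so run it from the repo or replace that call with your own list of SMILES strings.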

"""
molecule_search_pipeline.py
End-to-end molecular similarity search.
Usage: python molecule_search_pipeline.py
"""
from __future__ import annotations

import uuid
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct

MODEL_NAME      = "seyonec/ChemBERTa-zinc-base-v1"
COLLECTION_NAME = "molecules"
BATCH_SIZE      = 32


def process_smiles(smiles_list: list[str]) -> list[dict]:
    results = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        results.append({
            "smiles":           Chem.MolToSmiles(mol),
            "molecular_weight": round(Descriptors.MolWt(mol), 2),
            "logp":             round(Descriptors.MolLogP(mol), 2),
            "num_h_donors":     Descriptors.NumHDonors(mol),
            "num_h_acceptors":  Descriptors.NumHAcceptors(mol),
            "tpsa":             round(Descriptors.TPSA(mol), 2),
        })
    return results


def load_model(model_name: str = MODEL_NAME):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model     = AutoModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model, model.config.hidden_size


def embed_smiles(smiles_list, tokenizer, model, vector_dim=768) -> np.ndarray:
    if not smiles_list:
        return np.empty((0, vector_dim), dtype=np.float32)
    all_embeddings = []
    for i in range(0, len(smiles_list), BATCH_SIZE):
        batch   = smiles_list[i:i + BATCH_SIZE]
        encoded = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoded)
        mask     = encoded["attention_mask"].unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
        summed   = torch.sum(outputs.last_hidden_state * mask, dim=1)
        counted  = torch.clamp(mask.sum(dim=1), min=1e-9)
        embeddings = summed / counted
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        all_embeddings.append(embeddings.cpu().numpy())
    return np.vstack(all_embeddings)


def index_molecules(client, molecules, embeddings, vector_dim):
    if client.collection_exists(collection_name=COLLECTION_NAME):
        client.delete_collection(collection_name=COLLECTION_NAME)
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),
    )
    points = [
        PointStruct(
            id=str(uuid.uuid5(uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8"), mol["smiles"])),
            vector=emb.tolist(),
            payload=mol,
        )
        for mol, emb in zip(molecules, embeddings)
    ]
    client.upsert(collection_name=COLLECTION_NAME, points=points)
    print(f"  Indexed {len(points)} molecules.")


def search(client, query_smiles, tokenizer, model, vector_dim=768, top_k=5):
    mol = Chem.MolFromSmiles(query_smiles)
    if mol is None:
        raise ValueError(f"Invalid query SMILES: {query_smiles}")
    canonical  = Chem.MolToSmiles(mol)
    query_vec  = embed_smiles([canonical], tokenizer, model, vector_dim)[0].tolist()
    results    = client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vec,
        limit=top_k,
        with_payload=True,
    )
    return results.points


def main():
    from molsearch.data_loader import load_dataset

    print("Loading ZINC-250k dataset...")
    raw_smiles, _ = load_dataset()
    print(f"  {len(raw_smiles)} molecules loaded.\n")

    print("Step 1: Processing SMILES...")
    molecules = process_smiles(raw_smiles)
    print(f"  {len(molecules)} valid molecules.\n")

    print("Step 2: Loading model and generating embeddings...")
    tokenizer, model, vector_dim = load_model()
    embeddings = embed_smiles([m["smiles"] for m in molecules], tokenizer, model, vector_dim)
    print(f"  Embedding matrix shape: {embeddings.shape}\n")

    print("Step 3: Indexing in Qdrant...")
    client = QdrantClient(":memory:")
    index_molecules(client, molecules, embeddings, vector_dim)

    print("\nStep 4: Searching...")
    query = "CC(=O)Oc1ccccc1C(=O)O"  # Aspirin
    print(f"  Query: Aspirin ({query})")
    hits = search(client, query, tokenizer, model, vector_dim, top_k=5)
    print(f"\n  {'Rank':<6}{'SMILES':<50}{'Score':<10}")
    print(f"  {'-' * 66}")
    for i, hit in enumerate(hits, 1):
        print(f"  {i:<6}{hit.payload['smiles']:<50}{hit.score:<10.4f}")


if __name__ == "__main__":
    main()

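Illustrative output (after the dataset-loading lines; the hits and scores depend on which molecules are in the sample, and rank 1 is a perfect self-match only if aspirin itself was indexed):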

Step 1: Processing SMILES...
  2000 valid molecules.

Step 2: Loading model and generating embeddings...
  Embedding matrix shape: (2000, 768)

Step 3: Indexing in Qdrant...
  Indexed 2000 molecules.

Step 4: Searching...
  Query: Aspirin (CC(=O)Oc1ccccc1C(=O)O)

  Rank  SMILES                                            Score
  ------------------------------------------------------------------
  1     CC(=O)Oc1ccccc1C(=O)O                             1.0000
  2     OC(=O)c1ccccc1O                                   0.95xx
  3     OC(=O)c1ccccc1                                    0.93xx
  ...
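The pooling step in embed_smiles is where most of the embedding logic lives. Here is the same computation restated in pure Python on toy numbers, as a sketch rather than the script's code (both function names are hypothetical): padding positions (attention_mask == 0) are excluded from the mean, and the result is L2-normalized so cosine similarity reduces to a dot product.

```python
from __future__ import annotations

import math


def masked_mean_pool(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Average token vectors where mask == 1, mirroring summed / counted in embed_smiles."""
    dim = len(hidden[0])
    summed = [0.0] * dim
    for vec, keep in zip(hidden, mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
    count = max(sum(mask), 1e-9)  # same guard as torch.clamp(..., min=1e-9)
    return [s / count for s in summed]


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


# Two real tokens plus one padding token whose large values must be ignored.
hidden = [[1.0, 3.0], [3.0, 1.0], [100.0, -100.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(hidden, mask)
print(pooled)                              # [2.0, 2.0]
unit = l2_normalize(pooled)
print(round(sum(u * u for u in unit), 6))  # 1.0
```

Without the mask, the padding row would dominate the average, and the embedding of a short SMILES would depend on how long the other strings in its batch happened to be.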

Scaling Considerations

The example above indexes 2000 molecules. Once the pipeline was working end to end, the obvious question was: what changes at millions?

HNSW Tuning

from qdrant_client.models import (
    Distance, HnswConfigDiff, OptimizersConfigDiff, VectorParams,
)

# assumes `client` is a QdrantClient connected to a running Qdrant server
client.create_collection(
    collection_name="molecules_prod",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    hnsw_config=HnswConfigDiff(
        m=16,             # edges per node (default 16)
        ef_construct=128, # search depth during construction (default 100)
    ),
    optimizers_config=OptimizersConfigDiff(
        indexing_threshold=20000,  # in KB of vectors; ~6.7k vectors at 768-dim float32
    ),
)
Parameter       Low (fast, lower recall)    High (slower, higher recall)
m               8                           32
ef_construct    64                          256
Query ef        64                          256

For similarity search, recall matters more than shaving milliseconds. A missed candidate is a missed result.
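The indexing_threshold comment in the config is just arithmetic, and the same arithmetic gives a first answer to "what changes at millions?" for raw storage. A back-of-envelope sketch, assuming float32 vectors and Qdrant's kilobyte-based threshold (taken here as 1 KB = 1,024 bytes). It ignores HNSW links, payloads, and any quantization, so treat the numbers as lower bounds.

```python
from __future__ import annotations

BYTES_PER_FLOAT32 = 4
KB = 1024  # assumption: indexing_threshold is counted in 1,024-byte kilobytes


def vector_bytes(dim: int) -> int:
    return dim * BYTES_PER_FLOAT32


def vectors_before_hnsw(threshold_kb: int, dim: int) -> int:
    """Approximate vector count at which vector data exceeds indexing_threshold."""
    return threshold_kb * KB // vector_bytes(dim)


def raw_vector_gib(n_vectors: int, dim: int) -> float:
    """Raw float32 vector storage only: no HNSW links, payloads, or overhead."""
    return n_vectors * vector_bytes(dim) / 1024**3


print(vectors_before_hnsw(20_000, 768))           # 6666
print(round(raw_vector_gib(1_000_000, 768), 2))   # 2.86
print(round(raw_vector_gib(10_000_000, 768), 2))  # 28.61
```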

Set ef at query time:

from qdrant_client import models

results = client.query_points(
    collection_name=COLLECTION_NAME,
    query=query_vec,
    search_params=models.SearchParams(hnsw_ef=128),
    limit=top_k,
    with_payload=True,
)
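To pick ef from evidence rather than by feel, compare approximate hits against exact ones. Qdrant can run a brute-force search with models.SearchParams(exact=True); the recall calculation on top is simple. The function names below are hypothetical and the IDs are toy values.

```python
from __future__ import annotations


def recall_at_k(approx_ids: list[str], exact_ids: list[str], k: int) -> float:
    """Fraction of the exact top-k that the approximate top-k also returned."""
    if k <= 0:
        raise ValueError("k must be positive")
    truth = set(exact_ids[:k])
    if not truth:
        return 1.0
    return len(truth & set(approx_ids[:k])) / len(truth)


def mean_recall(runs: list[tuple[list[str], list[str]]], k: int) -> float:
    """Average recall@k over (approx_ids, exact_ids) pairs, one per query."""
    return sum(recall_at_k(a, e, k) for a, e in runs) / len(runs)


runs = [
    (["m1", "m2", "m3"], ["m1", "m2", "m3"]),  # HNSW found everything
    (["m1", "m4", "m3"], ["m1", "m2", "m3"]),  # HNSW missed m2
]
print(round(mean_recall(runs, k=3), 4))  # 0.8333
```

In practice you would collect approx_ids with search_params=models.SearchParams(hnsw_ef=ef) and exact_ids with models.SearchParams(exact=True) for a sample of queries, then raise ef until mean recall stops improving.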

Embedding at Scale

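For datasets over 100k molecules, run inference on a GPU when one is available by moving the model and each tokenized batch to the same device: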

import torch

# inside embed_smiles: tokenizer, model and batch are the objects defined there
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

encoded = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
encoded = {k: v.to(device) for k, v in encoded.items()}

Other optimizations that help:

  • Increase the batch size to 128 or 256.
  • Use torch.amp.autocast("cuda") for mixed-precision inference.
  • Pre-compute embeddings and store them in Parquet files so you do not rerun the model on every experiment iteration.
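The pre-compute suggestion works best when the cache is keyed by canonical SMILES, so a rerun only embeds molecules it has not seen before. Below is a minimal stdlib sketch of that idea. It swaps Parquet for an in-memory dict (persist it however you like), and `cached_embed`, `fake_embed`, and the batch size are hypothetical; in the real pipeline the embed function would wrap embed_smiles.

```python
from __future__ import annotations

from typing import Callable

Vector = list[float]


def cached_embed(
    smiles: list[str],
    cache: dict[str, Vector],
    embed_fn: Callable[[list[str]], list[Vector]],
    batch_size: int = 256,
) -> list[Vector]:
    """Return one vector per input SMILES, calling embed_fn only for cache misses."""
    missing = list(dict.fromkeys(s for s in smiles if s not in cache))  # dedupe, keep order
    for i in range(0, len(missing), batch_size):
        batch = missing[i:i + batch_size]
        for smi, vec in zip(batch, embed_fn(batch)):
            cache[smi] = vec
    return [cache[s] for s in smiles]


calls: list[int] = []


def fake_embed(batch: list[str]) -> list[Vector]:
    """Hypothetical embedder: records batch sizes, returns length-based vectors."""
    calls.append(len(batch))
    return [[float(len(s))] for s in batch]


cache: dict[str, Vector] = {}
cached_embed(["CCO", "c1ccccc1", "CCO"], cache, fake_embed, batch_size=2)
cached_embed(["CCO", "CCN"], cache, fake_embed, batch_size=2)
print(calls)  # [2, 1]
```

Because canonical SMILES and embeddings both depend on the RDKit and model versions, it is worth storing those versions next to the cache and discarding it when either changes.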

Persistent Deployment

docker run -p 6333:6333 -v $(pwd)/qdrant_storage:/qdrant/storage qdrant/qdrant
client = QdrantClient(host="localhost", port=6333)

Moving from the in-memory client in the standalone script to a persistent server is a one-line change. The rest of the code stays the same.
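One way to keep even that one line out of the code is to choose the client location from the environment. This is a sketch under assumptions: the QDRANT_URL and QDRANT_PATH variable names are made up for illustration, and the function only builds keyword arguments; url, path, and location are existing QdrantClient parameters.

```python
from __future__ import annotations

from collections.abc import Mapping


def qdrant_client_kwargs(env: Mapping[str, str]) -> dict[str, str]:
    """Choose server, on-disk local, or in-memory mode from (hypothetical) env vars."""
    if env.get("QDRANT_URL"):
        return {"url": env["QDRANT_URL"]}    # e.g. the Docker container above
    if env.get("QDRANT_PATH"):
        return {"path": env["QDRANT_PATH"]}  # local mode persisted to a directory
    return {"location": ":memory:"}          # what the standalone script uses


print(qdrant_client_kwargs({"QDRANT_URL": "http://localhost:6333"}))  # {'url': 'http://localhost:6333'}
print(qdrant_client_kwargs({}))                                       # {'location': ':memory:'}
```

Usage would be client = QdrantClient(**qdrant_client_kwargs(os.environ)), so the same script runs against memory in a notebook and against the container in deployment.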


Honest Limitations

No experiment writeup is complete without the part where you explain where it breaks.

The toxicity score is not a toxicity model. compute_toxicity_proxy produces a number between 0 and 1 based on six RDKit descriptors. It is useful for demonstrating filtered retrieval. Do not make safety decisions with it.

ChemBERTa operates in 1D and 2D. It reads SMILES strings and learns topological patterns. It has no concept of 3D geometry. If shape complementarity matters for your target, this pipeline alone is not enough. You would need to layer in 3D-aware models or shape-based screening.

Similarity scores are structural estimates. The base model was trained on ZINC using masked language modeling with no bioactivity labels. Two molecules that look close in vector space can behave completely differently in an assay. Vector retrieval is a starting point, not a replacement for experimental validation.

Implementation-level gotchas:

  • Canonical SMILES can differ across RDKit versions. Pin your version and keep it consistent between indexing and query environments.
  • HNSW is approximate. If you need high recall, increase ef at query time.
  • Salt forms such as sodium benzoate (O=C([O-])c1ccccc1.[Na+]) parse into a single molecule object containing disconnected fragments, so the counterion is canonicalized, counted in the descriptors, and embedded along with the parent compound. Add rdkit.Chem.SaltRemover.SaltRemover to preprocessing if that matters for your data.
  • ChemBERTa's token limit is 512. The FastAPI endpoint rejects canonical SMILES longer than 400 characters before tokenization, while the standalone script uses truncation=True, which silently cuts long inputs. Know which behavior you are running.
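The version-pinning and length-limit gotchas are cheap to enforce in code. A minimal sketch, assuming you store the RDKit version used at indexing time alongside the collection (the index_meta dict and both function names are hypothetical): reject over-length input explicitly instead of relying on truncation=True, and refuse to query across RDKit versions.

```python
from __future__ import annotations

MAX_SMILES_CHARS = 400  # mirrors the limit used by the FastAPI endpoint


def check_length(smiles: str, max_chars: int = MAX_SMILES_CHARS) -> str:
    """Fail loudly on long input instead of letting the tokenizer truncate it."""
    smiles = smiles.strip()
    if not smiles:
        raise ValueError("empty SMILES")
    if len(smiles) > max_chars:
        raise ValueError(f"SMILES has {len(smiles)} chars, limit is {max_chars}")
    return smiles


def check_rdkit_version(index_meta: dict[str, str], runtime_version: str) -> None:
    """Canonical SMILES may differ across RDKit versions, so require a match."""
    indexed = index_meta.get("rdkit_version")
    if indexed != runtime_version:
        raise RuntimeError(
            f"index built with RDKit {indexed}, running {runtime_version}; re-index first"
        )
```

At query time, runtime_version would be rdkit.__version__, and index_meta would be whatever record you wrote when the collection was built.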

Where to Take This Next

Full code: github.com/dvy246/qdrant

Three directions worth thinking about.

Fine-tuning. The base model has no activity labels, so the similarity it measures is purely structural. Fine-tuning ChemBERTa on assay data for a specific target with a contrastive objective, so that active molecules pull together in vector space, is the most direct way to make retrieval useful for that target. Without that, you are doing structural exploration, not activity prediction.
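For intuition about what a contrastive objective would optimize, here is a pure-Python InfoNCE-style loss for a single anchor. It is one common choice, not something the repository implements; the vectors, the temperature, and the function names are illustrative assumptions. The loss is small when an active anchor sits closer to another active than to inactives, which is the pull-together, push-apart behavior fine-tuning would reward.

```python
from __future__ import annotations

import math


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(
    anchor: list[float],
    positive: list[float],
    negatives: list[list[float]],
    temperature: float = 0.1,
) -> float:
    """-log of the softmax probability assigned to the positive pair."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    m = max(logits)  # subtract the max for numerical stability
    log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denominator - logits[0]


active_a, active_b = [1.0, 0.1], [0.9, 0.2]
inactives = [[-1.0, 0.3], [0.1, -1.0]]
aligned = info_nce(active_a, active_b, inactives)
misaligned = info_nce(active_a, inactives[0], [active_b, inactives[1]])
print(aligned < misaligned)  # True
```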

Automated ingestion. Connect Qdrant updates to your compound registration workflow so new molecules are indexed as they arrive rather than in periodic batch jobs. The deterministic point IDs (uuid5 of the canonical SMILES) make this safe: re-indexing the same molecule overwrites its existing point instead of creating a duplicate.
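The idempotency comes from uuid.uuid5: hashing the same canonical SMILES under the same namespace always yields the same ID. A quick stdlib check (point_id is a hypothetical helper; the namespace string used in index_molecules happens to be the standard uuid.NAMESPACE_DNS constant):

```python
from __future__ import annotations

import uuid

NAMESPACE = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")  # same as index_molecules


def point_id(canonical_smiles: str) -> str:
    """Deterministic Qdrant point ID derived from a canonical SMILES string."""
    return str(uuid.uuid5(NAMESPACE, canonical_smiles))


first = point_id("CC(=O)Oc1ccccc1C(=O)O")
again = point_id("CC(=O)Oc1ccccc1C(=O)O")
print(first == again)                        # True: an upsert overwrites
print(first == point_id("OC(=O)c1ccccc1O"))  # False: new molecule, new point
print(NAMESPACE == uuid.NAMESPACE_DNS)       # True
```

The ID is only as stable as the canonicalization feeding it, which is another reason to pin the RDKit version between indexing and ingestion.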

Target-aware search. Add protein pocket embeddings alongside molecular embeddings so retrieval accounts for both the molecule and the target. That starts to look more like a real virtual screening tool.


Conclusion

The experiment worked, and honestly it worked better than I expected.

The pipeline loads 2000 molecules from ZINC-250k, validates and canonicalizes each SMILES with RDKit (computing a heuristic toxicity score per molecule along the way), converts each structure into a 768-dimensional vector using ChemBERTa, stores the vectors in Qdrant with molecular property metadata, and retrieves structurally similar candidates using cosine similarity. Filters on molecular weight, LogP, and toxicity score apply during retrieval rather than after it. The final ranking combines embedding similarity with a Tanimoto fingerprint score, so each result reflects both the learned structural space and explicit substructure overlap.

This does not replace fingerprint search. For exact scaffold matching, Tanimoto is faster, simpler, and more interpretable. What embeddings add is a way to find candidates that share structural and physicochemical patterns even when their fingerprints look different, which is exactly the case where fingerprint search tends to miss them.

The hard parts were not the transformer or the vector database. They were the data cleaning, the SMILES canonicalization edge cases, and deciding when to trust the similarity scores. Those are the things worth understanding before scaling this up.

If you are a developer exploring vector search beyond text, molecules turn out to be a surprisingly good playground. The data is structured, the search requirements are concrete, and the gap between what fingerprints can do and what embeddings can do is real enough to make the experiment worthwhile.


References

  1. ChemBERTa paper — arXiv:2010.09885
  2. ChemBERTa model card — seyonec/ChemBERTa-zinc-base-v1
  3. ZINC-250k dataset
  4. RDKit documentation
  5. Qdrant documentation
  6. Qdrant Query API
  7. qdrant-client on PyPI
  8. Hugging Face Transformers
  9. FastAPI documentation
  10. Streamlit documentation
  11. PyTorch documentation
  12. ZINC database
  13. HNSW algorithm — arXiv:1603.09320
  14. MoleculeNet benchmarks
  15. Sentence-BERT — arXiv:1908.10084
  16. PEP 604 — Union types
