
How do I implement multi-turn memory in AI chatbots?

Here’s a practical, production-ready pattern for multi-turn memory that you can drop into a small API. It combines:

Short-term memory: rolling chat window (last K turns)

Conversation summaries: to keep context small

Long-term memory: vector store of user facts & past topics

Entity memory: lightweight key→value store (name, timezone, preferences)

Below is a complete FastAPI service with a simple SQLite + FAISS store. It’s model-agnostic, but an OpenAI adapter is included for convenience.

1) Files & setup

requirements.txt

fastapi
uvicorn
pydantic
openai>=1.30.0
python-dotenv
sqlalchemy
faiss-cpu
sentence-transformers


.env

OPENAI_API_KEY=sk-...

# change these if you like:
OPENAI_CHAT_MODEL=gpt-4o-mini
OPENAI_EMBED_MODEL=text-embedding-3-small


2) Data model

  • messages: all user/assistant turns (for short-term + summarization)
  • entities: simple key/value facts per user (entity memory)
  • memories: vector index of long-term memories with embedding + text

3) App code

app.py

import os
import time
from typing import List, Optional, Tuple
from dataclasses import dataclass

from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import (create_engine, Column, Integer, String, Text, Float,
                        ForeignKey, select)
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
import numpy as np

# Embeddings & LLM
import faiss
from sentence_transformers import SentenceTransformer
from openai import OpenAI

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
CHAT_MODEL = os.getenv("OPENAI_CHAT_MODEL", "gpt-4o-mini")
EMBED_MODEL_NAME = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")  # only used if you swap to API embeddings

# --- DB setup ---
Base = declarative_base()
engine = create_engine("sqlite:///memory.db", echo=False, future=True)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

class User(Base):
    __tablename__ = "users"
    id = Column(String, primary_key=True)          # app-level user id
    created_at = Column(Float, default=lambda: time.time())
    messages = relationship("Message", back_populates="user", cascade="all, delete-orphan")
    entities = relationship("Entity", back_populates="user", cascade="all, delete-orphan")
    memories = relationship("Memory", back_populates="user", cascade="all, delete-orphan")

class Message(Base):
    __tablename__ = "messages"
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(String, ForeignKey("users.id"), index=True)
    role = Column(String)  # "user" or "assistant" or "system"
    text = Column(Text)
    created_at = Column(Float, default=lambda: time.time())
    user = relationship("User", back_populates="messages")

class Entity(Base):
    __tablename__ = "entities"
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(String, ForeignKey("users.id"), index=True)
    key = Column(String, index=True)
    value = Column(Text)
    updated_at = Column(Float, default=lambda: time.time())
    user = relationship("User", back_populates="entities")

class Memory(Base):
    __tablename__ = "memories"
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(String, ForeignKey("users.id"), index=True)
    text = Column(Text)
    kind = Column(String, default="fact")  # "fact", "topic", "preference", etc.
    vector = Column(Text)  # store as comma-joined string
    score = Column(Float, default=0.0)
    updated_at = Column(Float, default=lambda: time.time())
    user = relationship("User", back_populates="memories")

Base.metadata.create_all(engine)

# --- Embeddings ---
# We use local sentence-transformers for FAISS indexing (fast + private),
# and we store the same texts; you can swap to API embeddings if you prefer.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# FAISS index per process (we rebuild from DB on boot and when needed)
@dataclass
class FaissBundle:
    index: faiss.IndexFlatIP
    ids: List[int]  # Memory.id order aligned to FAISS

def normalize(v: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(v, axis=1, keepdims=True) + 1e-10
    return v / norms

def build_faiss_for_user(db, user_id: str) -> FaissBundle:
    mems = db.execute(select(Memory).where(Memory.user_id == user_id)).scalars().all()
    if not mems:
        return FaissBundle(faiss.IndexFlatIP(384), [])
    texts = [m.text for m in mems]
    vecs = embedder.encode(texts, convert_to_numpy=True)
    vecs = normalize(vecs.astype("float32"))
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    return FaissBundle(index=index, ids=[m.id for m in mems])

# --- LLM client (OpenAI as example; swap to your provider easily) ---
client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

def chat_completion(messages: List[dict], model: str = CHAT_MODEL, max_tokens: int = 600) -> str:
    if not client:
        # Offline/dev fallback so the server runs without a key
        return "(LLM disabled) You asked: " + messages[-1]["content"]
    resp = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.4,
        max_tokens=max_tokens,
    )
    return resp.choices[0].message.content

def extract_salient_memories(latest_user_text: str) -> List[str]:
    """
    Ask the LLM to extract durable facts/preferences to remember.
    Keep it minimal; you can add rules per domain.
    """
    prompt = [
        {"role": "system", "content": (
            "You extract durable, reusable facts from a single user message. "
            "Return 0–5 bullet-worthy snippets, each a single short sentence. "
            "Only include things that would still be useful weeks later "
            "(preferences, profile facts, long-term goals, stable constraints). "
            "If nothing durable, return 'NONE'."
        )},
        {"role": "user", "content": latest_user_text}
    ]
    text = chat_completion(prompt, max_tokens=120)
    if "NONE" in text.strip().upper():
        return []
    # very light parsing; assumes bullets or sentences
    lines = [l.strip(" -•\t") for l in text.splitlines() if l.strip()]
    return [l for l in lines if len(l) > 2][:5]

def summarize_conversation(history: List[Tuple[str, str]]) -> str:
    """
    Summarize a long chat when token budget is tight.
    history: list of (role, text)
    """
    parts = "\n".join([f"{r.upper()}: {t}" for r, t in history[-30:]])
    prompt = [
        {"role": "system", "content": (
            "Summarize the conversation so far into 5–8 crisp bullets, "
            "preserving decisions, plans, constraints, and unresolved questions."
        )},
        {"role": "user", "content": parts}
    ]
    return chat_completion(prompt, max_tokens=180)

# --- Memory manager ---
class MemoryManager:
    def __init__(self):
        self._faiss_cache = {}  # user_id -> FaissBundle

    def ensure_user(self, db, user_id: str):
        user = db.get(User, user_id)
        if not user:
            user = User(id=user_id)
            db.add(user)
            db.commit()
        if user_id not in self._faiss_cache:
            self._faiss_cache[user_id] = build_faiss_for_user(db, user_id)
        return user

    def upsert_entity(self, db, user_id: str, key: str, value: str):
        ent = db.execute(select(Entity).where(
            Entity.user_id == user_id, Entity.key == key
        )).scalar_one_or_none()
        if ent:
            ent.value = value
            ent.updated_at = time.time()
        else:
            ent = Entity(user_id=user_id, key=key, value=value)
            db.add(ent)
        db.commit()

    def search_memories(self, db, user_id: str, query: str, k: int = 5) -> List[Memory]:
        bundle = self._faiss_cache.get(user_id)
        if not bundle or bundle.index.ntotal == 0:
            return []
        qv = embedder.encode([query], convert_to_numpy=True).astype("float32")
        qv = normalize(qv)
        scores, idx = bundle.index.search(qv, k)
        results = []
        for rank in idx[0]:
            if rank == -1: 
                continue
            mem_id = bundle.ids[rank]
            mem = db.get(Memory, mem_id)
            if mem:
                results.append(mem)
        return results

    def add_memories(self, db, user_id: str, texts: List[str], kind: str = "fact"):
        if not texts:
            return
        vecs = embedder.encode(texts, convert_to_numpy=True).astype("float32")
        vecs = normalize(vecs)
        new_ids = []
        for text, vec in zip(texts, vecs):
            m = Memory(
                user_id=user_id,
                text=text.strip(),
                kind=kind,
                vector=",".join(map(str, vec.tolist())),
                score=0.0
            )
            db.add(m)
            db.flush()
            new_ids.append(m.id)
        db.commit()
        # rebuild FAISS for this user (simple + safe)
        self._faiss_cache[user_id] = build_faiss_for_user(db, user_id)

memory_mgr = MemoryManager()

# --- FastAPI ---
app = FastAPI(title="Chat with Multi-Turn Memory")

class ChatRequest(BaseModel):
    user_id: str
    message: str
    # optional hints for entity memory
    user_name: Optional[str] = None
    timezone: Optional[str] = None

class ChatResponse(BaseModel):
    reply: str
    used_memories: List[str]
    summary_used: bool

SYSTEM_GUARDRAILS = (
    "You are a helpful assistant. "
    "Use retrieved memories if relevant. "
    "Be concise and avoid repeating the user."
)

MAX_WINDOW = 8   # last 8 turns kept verbatim before summarizing
MAX_TOKENS_BUDGETED = 4096  # conceptual budget; summarization below is actually triggered by turn count, not tokens

@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    db = SessionLocal()
    try:
        user = memory_mgr.ensure_user(db, req.user_id)

        # Optional: update entity memory
        if req.user_name:
            memory_mgr.upsert_entity(db, req.user_id, "name", req.user_name)
        if req.timezone:
            memory_mgr.upsert_entity(db, req.user_id, "timezone", req.timezone)

        # Save incoming user message
        db.add(Message(user_id=req.user_id, role="user", text=req.message))
        db.commit()

        # 1) Extract durable facts from the new message and store to long-term memory
        facts = extract_salient_memories(req.message)
        memory_mgr.add_memories(db, req.user_id, facts, kind="fact")

        # 2) Retrieve relevant long-term memories for the current query
        retrieved = memory_mgr.search_memories(db, req.user_id, req.message, k=5)
        retrieved_texts = [f"- {m.text}" for m in retrieved]

        # 3) Build short-term context: last K turns (user+assistant)
        all_msgs = db.execute(select(Message).where(
            Message.user_id == req.user_id
        ).order_by(Message.created_at.asc())).scalars().all()
        turns = [(m.role, m.text) for m in all_msgs]

        window = turns[-(MAX_WINDOW*2):]  # roughly last K exchanges
        summary_used = False
        summary_text = ""
        # Optional: summarize if conversation is getting long
        if len(turns) > MAX_WINDOW * 2 + 2:
            summary_text = summarize_conversation(turns[:- (MAX_WINDOW*2)])
            summary_used = True

        # 4) Pull entity memory
        entities = db.execute(select(Entity).where(Entity.user_id == req.user_id)).scalars().all()
        entity_lines = [f"{e.key}: {e.value}" for e in entities]

        # 5) Compose final prompt
        messages = [{"role": "system", "content": SYSTEM_GUARDRAILS}]
        if entity_lines:
            messages.append({"role": "system", "content": "Known user entities:\n" + "\n".join(entity_lines)})
        if retrieved_texts:
            messages.append({"role": "system", "content": "Relevant long-term memories:\n" + "\n".join(retrieved_texts)})
        if summary_text:
            messages.append({"role": "system", "content": "Conversation summary so far:\n" + summary_text})

        for r, t in window:
            messages.append({"role": r, "content": t})

        # 6) Generate answer
        reply = chat_completion(messages)

        # 7) Save assistant reply
        db.add(Message(user_id=req.user_id, role="assistant", text=reply))
        db.commit()

        return ChatResponse(
            reply=reply,
            used_memories=[m.text for m in retrieved],
            summary_used=summary_used
        )
    finally:
        db.close()



4) Run it

pip install -r requirements.txt
uvicorn app:app --reload --port 8000

Query it:

curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "user_id": "u123",
    "user_name": "Sanket",
    "timezone": "Asia/Kolkata",
    "message": "I prefer concise answers and dark UI themes. Also remind me to ship the Odoo article on Friday."
  }'

Then another turn:

curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "user_id": "u123",
    "message": "What were my preferences again? And help me plan the Odoo article outline."
  }'


You’ll see used_memories include the preference facts pulled from long-term memory, while the rolling window + summary keep the LLM context tight.
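
If you prefer Python over curl, a quick client-side check might look like the sketch below. It assumes the requests package is installed (it is not in requirements.txt):

# quick_check.py - a sketch; assumes `pip install requests`
import requests

resp = requests.post(
    "http://localhost:8000/chat",
    json={"user_id": "u123", "message": "What were my preferences again?"},
    timeout=60,
)
data = resp.json()
print("reply:", data["reply"])
print("used_memories:", data["used_memories"])  # facts retrieved from FAISS
print("summary_used:", data["summary_used"])    # True once the history outgrows the window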

How it works (brief)

1. On each user turn

  • Save the raw message.
  • Ask the LLM to extract durable facts (preferences, profile, long-term goals). Store them in memories + FAISS.

2. Before responding

  • Retrieve top-k memories by semantic similarity to the new message.
  • Build the prompt from:
      • Guardrails
      • Entity memory
      • Retrieved long-term memories
      • Conversation summary (if history is long)
      • Recent short-term turns (last K exchanges)

3. After responding

  • Save the assistant message.

Tweaks you’ll likely add

  • Expiry & decay: lower score or prune old memories unless re-used (see the sketch after this list).
  • Memory categories: preference, identity, project, task, etc., and retrieve by type.
  • Task memory: store “open loops” and have the bot proactively follow up.
  • Privacy switches: only store memories if the user opted in.
  • RAG: add a document retriever alongside personal memories.
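
For the expiry & decay bullet, here is a rough sketch of how pruning could work against the Memory table above. The touch_memory and prune_memories helpers are hypothetical (not part of app.py), and the 30-day and score thresholds are placeholder assumptions:

# Hypothetical decay/pruning helpers; a sketch against the Memory model above.
import time
from sqlalchemy import select

DECAY_AFTER_DAYS = 30      # assumption: untouched memories older than this are candidates
MIN_SCORE_TO_KEEP = 1.0    # assumption: keep anything that has been reused at least once

def touch_memory(db, mem: Memory) -> None:
    """Call whenever search_memories returns a memory, so reuse bumps its score."""
    mem.score += 1.0
    mem.updated_at = time.time()
    db.commit()

def prune_memories(db, user_id: str) -> int:
    """Delete stale, low-score memories; returns the number of rows removed."""
    cutoff = time.time() - DECAY_AFTER_DAYS * 86400
    stale = db.execute(
        select(Memory).where(
            Memory.user_id == user_id,
            Memory.updated_at < cutoff,
            Memory.score < MIN_SCORE_TO_KEEP,
        )
    ).scalars().all()
    for m in stale:
        db.delete(m)
    db.commit()
    return len(stale)

After pruning, rebuild that user's FAISS bundle (memory_mgr._faiss_cache[user_id] = build_faiss_for_user(db, user_id)) so the index stays in sync with the database.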

Minimal no-DB variant (for prototypes)

If you just need short-term memory with summarization (no vector store), keep only a deque of messages and periodically compress it with the summarize_conversation function. That alone handles many chat UX cases.
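
A minimal sketch of that variant, reusing summarize_conversation and chat_completion from the code above. Everything lives in process memory, so state is lost on restart and it only handles a single session:

# Minimal no-DB variant: rolling deque + periodic summarization (a sketch).
from collections import deque

MAX_WINDOW = 8       # exchanges kept verbatim
history = deque()    # (role, text) tuples for a single session
summary = ""         # compressed older context

def chat_turn(user_text: str) -> str:
    global summary
    history.append(("user", user_text))

    # Once the verbatim window overflows, fold the oldest turns into the summary
    if len(history) > MAX_WINDOW * 2:
        overflow = len(history) - MAX_WINDOW * 2
        older = [history.popleft() for _ in range(overflow)]
        if summary:
            older = [("system", "Earlier summary: " + summary)] + older
        summary = summarize_conversation(older)

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    if summary:
        messages.append({"role": "system", "content": "Conversation summary so far:\n" + summary})
    messages.extend({"role": r, "content": t} for r, t in history)

    reply = chat_completion(messages)
    history.append(("assistant", reply))
    return reply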

Multi-turn memory is one of the biggest challenges in building production-grade AI assistants. While this guide shows a practical FastAPI pattern, in real-world projects we often combine this with advanced pipelines and deployment workflows. For end-to-end solutions, you can explore AI development that applies these techniques at scale.
