Pydantic V2 Discriminated Unions in FastAPI: Modeling Polymorphic AI Feature Configs Without Schema Sprawl
I've built nine AI features into CitizenApp, and each one has a wildly different configuration shape. A summarization feature needs max_tokens and style. A classifier needs labels and confidence_threshold. A generator needs temperature, system_prompt, and output_format.
For months, I solved this with if/elif chains in my route handlers. It was a disaster. Schema mismatches made it to production. Clients sent invalid configs that slipped past validation, and I'd only catch them at runtime inside the Claude API call: expensive, embarrassing, and hard to debug.
Then I switched to Pydantic V2's discriminated unions. Now request validation is a single type annotation on the route. Configs read back from the database get re-validated against the same union. And every schema mismatch in a request gets caught at the HTTP layer, not buried in a traceback three API calls deep.
This is how I'd tell my past self to do it.
The Problem: Polymorphism Without a Type System Is Just if/elif
Here's what my old code looked like:
```python
# ❌ Before: Polymorphism via conditionals
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class FeatureConfig(BaseModel):
    feature_type: str
    config: dict  # God object antipattern


@app.post("/features")
async def create_feature(payload: FeatureConfig):
    if payload.feature_type == "summarization":
        if "max_tokens" not in payload.config:
            raise HTTPException(400, "missing max_tokens")
        max_tokens = payload.config["max_tokens"]
        # ...
    elif payload.feature_type == "classification":
        if "labels" not in payload.config:
            raise HTTPException(400, "missing labels")
        labels = payload.config["labels"]
        # ...
    else:
        raise HTTPException(400, "unknown feature type")
```
This burns in three ways:
- No type safety. `payload.config` is a `dict`. The type checker doesn't know what keys exist. You discover missing fields at runtime.
- Validation logic lives in handlers. Every endpoint that touches features repeats the same checks. One team member forgets a validation step, and garbage flows into the database.
- Clients guess the schema. Your API docs don't tell them what a `classification` config actually needs. They resort to trial and error until something works.
The Solution: Discriminated Unions Force the Shape
Pydantic V2's discriminated unions make polymorphism a first-class citizen. You define each variant as its own model, tag each one with a `Literal` discriminator field, and point `Field(discriminator=...)` at that field. Pydantic does the rest:
```python
# ✅ After: Discriminated unions
from typing import Annotated, Literal, Union

from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()


# Each feature type is its own model
class SummarizationConfig(BaseModel):
    feature_type: Literal["summarization"]
    max_tokens: int = Field(gt=0, le=4096)
    style: Literal["bullet", "paragraph", "executive"] = "paragraph"


class ClassificationConfig(BaseModel):
    feature_type: Literal["classification"]
    # Pydantic V2 spells list-length limits min_length/max_length (min_items/max_items are deprecated)
    labels: list[str] = Field(min_length=2, max_length=50)
    confidence_threshold: float = Field(ge=0, le=1)
    allow_multi_label: bool = False


class GenerationConfig(BaseModel):
    feature_type: Literal["generation"]
    temperature: float = Field(ge=0, le=2)
    system_prompt: str = Field(min_length=10)
    output_format: Literal["text", "json", "markdown"] = "text"
    max_tokens: int = Field(gt=0, le=8000)


# Union of all variants, discriminated by feature_type
AIFeatureConfig = Annotated[
    Union[SummarizationConfig, ClassificationConfig, GenerationConfig],
    Field(discriminator="feature_type"),
]


@app.post("/features")
async def create_feature(config: AIFeatureConfig):
    # config is already the correct type
    # Type checker knows exactly what fields exist
    if isinstance(config, SummarizationConfig):
        print(config.max_tokens)  # Type checker sees this exists
        print(config.style)
    elif isinstance(config, ClassificationConfig):
        print(config.labels)
        print(config.confidence_threshold)
    elif isinstance(config, GenerationConfig):
        print(config.temperature)
        print(config.system_prompt)
```
What just happened:
- Validation is automatic. Pydantic reads `feature_type`, routes to the correct model, and validates all fields. If a client sends `classification` with missing `labels`, they get a `422` response with a clear error message, before your handler runs.
- Type narrowing works. Once you check `isinstance(config, SummarizationConfig)`, the type checker knows `config.max_tokens` exists. No `dict` casting, no runtime guessing.
- The API contract is self-documenting. Your OpenAPI schema now has three distinct input shapes, each with its own validation rules. Clients can read the docs and know exactly what they need.
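You can watch the routing without starting a server by using Pydantic's `TypeAdapter`. It validates any type, including an `Annotated` union that has no `model_validate()` of its own. The sketch below re-declares a trimmed, two-variant copy of the models so it runs on its own. `adapter` and `error_types` are illustrative names, not part of CitizenApp.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class SummarizationConfig(BaseModel):
    feature_type: Literal["summarization"]
    max_tokens: int = Field(gt=0, le=4096)
    style: Literal["bullet", "paragraph", "executive"] = "paragraph"


class ClassificationConfig(BaseModel):
    feature_type: Literal["classification"]
    labels: list[str] = Field(min_length=2, max_length=50)
    confidence_threshold: float = Field(ge=0, le=1)


# Trimmed two-variant union, discriminated by feature_type
AIFeatureConfig = Annotated[
    Union[SummarizationConfig, ClassificationConfig],
    Field(discriminator="feature_type"),
]

# TypeAdapter validates non-model types such as this Annotated union
adapter = TypeAdapter(AIFeatureConfig)


def error_types(payload: dict[str, Any]) -> list[str]:
    """Return the Pydantic error type for each problem in payload ([] if valid)."""
    try:
        adapter.validate_python(payload)
    except ValidationError as exc:
        return [err["type"] for err in exc.errors()]
    return []


# Pydantic reads feature_type first, then validates only the matching model
config = adapter.validate_python({"feature_type": "summarization", "max_tokens": 300})
print(type(config).__name__, config.style)  # SummarizationConfig paragraph

# Missing labels: reported against ClassificationConfig only
print(error_types({"feature_type": "classification", "confidence_threshold": 0.8}))

# Unknown tag: rejected before any model is tried
print(error_types({"feature_type": "translation"}))

# The generated JSON schema carries the discriminator that OpenAPI docs render
print(adapter.json_schema()["discriminator"]["propertyName"])
```

FastAPI runs the same validation on request bodies. The error types printed here should match the `type` entries in the `422` response's `detail` list.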
Real-World: Database Storage and Retrieval
In CitizenApp, feature configs live in PostgreSQL as JSONB. Here's how I handle polymorphic queries:
```python
# models.py
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class AIFeature(Base):
    __tablename__ = "ai_features"

    id = Column(String, primary_key=True)
    feature_type = Column(String, nullable=False, index=True)
    # The generic JSON type maps to PostgreSQL's json; JSONB needs the dialect type
    config = Column(JSONB, nullable=False)
```
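```python
# crud.py
from pydantic import TypeAdapter
from sqlalchemy.orm import Session

from models import AIFeature
# AIFeatureConfig is the discriminated union defined earlier; import it from wherever it lives

# AIFeatureConfig is an Annotated alias, not a model, so it has no .model_validate();
# a TypeAdapter provides validate_python() for it
feature_config_adapter = TypeAdapter(AIFeatureConfig)


# Plain def: a synchronous Session would block the event loop inside async def
def create_feature(db: Session, config: AIFeatureConfig) -> AIFeature:
    """
    The discriminated union validated the config shape.
    Now we just store it.
    """
    db_feature = AIFeature(
        id=generate_id(),  # your ID helper (not shown)
        feature_type=config.feature_type,
        config=config.model_dump(mode="json"),  # Always valid, JSON-safe
    )
    db.add(db_feature)
    db.commit()
    return db_feature


def get_feature(db: Session, feature_id: str) -> AIFeatureConfig:
    """
    Retrieve from DB and re-validate against the union.
    """
    row = db.query(AIFeature).filter(AIFeature.id == feature_id).one()
    # This raises pydantic.ValidationError if the DB contains a schema mismatch
    # (which shouldn't happen, but you catch accidental updates)
    return feature_config_adapter.validate_python(
        {"feature_type": row.feature_type, **row.config}
    )
```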
Why this matters: If someone accidentally writes a malformed config to the database, the re-validation step in `get_feature` catches it the next time that row is read, instead of letting bad data reach the Claude call. You don't ship a broken feature to production because a schema migration went sideways.
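Because every read goes back through the union, the same check also works as an audit after a migration: run every stored row through the adapter and collect the failures. This is a minimal sketch. It assumes rows arrive as plain `(id, feature_type, config)` tuples rather than ORM objects. `find_invalid_rows` is a hypothetical helper, and the models are trimmed copies so the block runs on its own.

```python
from collections.abc import Iterable
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class SummarizationConfig(BaseModel):
    feature_type: Literal["summarization"]
    max_tokens: int = Field(gt=0, le=4096)


class ClassificationConfig(BaseModel):
    feature_type: Literal["classification"]
    labels: list[str] = Field(min_length=2, max_length=50)
    confidence_threshold: float = Field(ge=0, le=1)


AIFeatureConfig = Annotated[
    Union[SummarizationConfig, ClassificationConfig],
    Field(discriminator="feature_type"),
]
feature_config_adapter = TypeAdapter(AIFeatureConfig)


def find_invalid_rows(
    rows: Iterable[tuple[str, str, dict[str, Any]]],
) -> dict[str, list[str]]:
    """Hypothetical audit helper: map each failing row id to its Pydantic error types."""
    problems: dict[str, list[str]] = {}
    for row_id, feature_type, config in rows:
        try:
            # Same merge as get_feature: discriminator column plus stored JSON
            feature_config_adapter.validate_python({"feature_type": feature_type, **config})
        except ValidationError as exc:
            problems[row_id] = [err["type"] for err in exc.errors()]
    return problems


rows = [
    ("f1", "summarization", {"max_tokens": 500}),
    ("f2", "summarization", {"max_tokens": -1}),  # violates gt=0
    ("f3", "classification", {"labels": ["spam"], "confidence_threshold": 0.5}),  # one label
]
print(find_invalid_rows(rows))
```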
The Claude Integration: Type-Safe Prompts
Now that config is type-safe, Claude calls are cleaner:
```python
# ai_service.py
import anthropic
from typing_extensions import assert_never

# SummarizationConfig, ClassificationConfig, GenerationConfig, and AIFeatureConfig
# come from the models defined earlier


async def invoke_feature(config: AIFeatureConfig, input_text: str) -> str:
    """
    Type narrowing means we know exactly what fields exist.
    No runtime schema lookups, no dict key checks.
    """
    # Async client, so the API call doesn't block FastAPI's event loop
    client = anthropic.AsyncAnthropic()
    if isinstance(config, SummarizationConfig):
        prompt = f"""Summarize the following in {config.style} format.
Max output: {config.max_tokens} tokens.
{input_text}"""
        max_tokens = config.max_tokens
    elif isinstance(config, ClassificationConfig):
        prompt = f"""Classify the following text into one of these categories:
{', '.join(config.labels)}
Confidence threshold: {config.confidence_threshold}
Multi-label allowed: {config.allow_multi_label}
{input_text}"""
        max_tokens = 1024  # ClassificationConfig has no max_tokens field
    elif isinstance(config, GenerationConfig):
        prompt = f"""{config.system_prompt}
User input: {input_text}
Respond in {config.output_format} format."""
        max_tokens = config.max_tokens
    else:
        # The type checker flags any variant added to the union but not handled above
        assert_never(config)

    response = await client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt}],
    )
    return response.content[0].text
```
This is the real win: once validation passes, your business logic doesn't need defensive coding. No checking if fields exist. No runtime type guessing. Just type-safe field access.
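Gotcha: Nested Unions Still Need the Outer Discriminator
I learned this the hard way. You can nest one discriminated union inside another, and each level can discriminate on a different field name. What you can't skip is the outer discriminator: every leaf model under the outer union must still declare it, even when an inner union does the final routing:
```python
# ✅ This works: inner union routes on mode, outer union on feature_type
class FastSummary(BaseModel):
    feature_type: Literal["summarization"]  # outer discriminator, still required
    mode: Literal["fast"]                   # inner discriminator

class DeepSummary(BaseModel):
    feature_type: Literal["summarization"]
    mode: Literal["deep"]

SummaryConfig = Annotated[Union[FastSummary, DeepSummary], Field(discriminator="mode")]
NestedFeatureConfig = Annotated[
    Union[SummaryConfig, ClassificationConfig, GenerationConfig],
    Field(discriminator="feature_type"),
]
# ❌ Remove feature_type from FastSummary and Pydantic refuses to build the schema
```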
Also: the discriminator value is the routing key, so two variants in the same union can't share a `Literal` value. Pydantic doesn't guess between them. It raises an error when it builds the schema. Keep tag values unique within each union level.
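Pydantic checks discriminated unions when it builds the schema, so tag problems fail at import time or in a unit test rather than on a live request. The sketch below probes three shapes: a nested union where the inner level routes on `mode` and the outer on `feature_type`, two same-level variants that share a tag, and a nested leaf missing the outer tag. All names are hypothetical. `builds` simply tries to construct a `TypeAdapter` and reports whether Pydantic accepted the type.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class FastSummary(BaseModel):
    feature_type: Literal["summarization"]  # outer tag, required on every leaf
    mode: Literal["fast"]                   # inner tag


class DeepSummary(BaseModel):
    feature_type: Literal["summarization"]
    mode: Literal["deep"]


class ClassificationConfig(BaseModel):
    feature_type: Literal["classification"]
    labels: list[str]


class Orphan(BaseModel):
    mode: Literal["orphan"]  # no feature_type at all


def builds(tp: object) -> bool:
    """True if Pydantic can build a validator for tp."""
    try:
        TypeAdapter(tp)
    except Exception:  # schema-build errors (PydanticUserError, TypeError)
        return False
    return True


SummaryConfig = Annotated[Union[FastSummary, DeepSummary], Field(discriminator="mode")]

# Nested: inner union on "mode", outer union on "feature_type"
Nested = Annotated[
    Union[SummaryConfig, ClassificationConfig], Field(discriminator="feature_type")
]
# Same level, same tag: both variants claim "summarization"
Duplicate = Annotated[
    Union[FastSummary, DeepSummary], Field(discriminator="feature_type")
]
# A leaf inside the nested union lacks the outer discriminator
MissingOuterTag = Annotated[
    Union[
        Annotated[Union[FastSummary, Orphan], Field(discriminator="mode")],
        ClassificationConfig,
    ],
    Field(discriminator="feature_type"),
]

print(builds(Nested), builds(Duplicate), builds(MissingOuterTag))
```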
Missing: Versioning Polymorphic Schemas
I didn't plan for schema evolution. A year in, I needed to add a new field to ClassificationConfig. I should have baked in a schema_version field from the start:
```python
class ClassificationConfig(BaseModel):
    feature_type: Literal["classification"]
    schema_version: Literal[1] = 1
    labels: list[str]
    confidence_threshold: float
```
Then when v2 ships, I can fork the union and migrate gradually.
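Here is one way that fork could look, sketched under assumptions. `schema_version` becomes the tag of an inner union. A plain `Field(discriminator="schema_version")` would reject rows written before versioning, because a string discriminator needs the key present in the input, and a field default doesn't count. A callable `Discriminator` with `Tag` markers (available since Pydantic 2.5) can route a missing version to v1 instead. `ClassificationConfigV2`, its `top_k` field, and `classification_version` are hypothetical.

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Field, Tag, TypeAdapter


class ClassificationConfigV1(BaseModel):
    feature_type: Literal["classification"]
    schema_version: Literal[1] = 1
    labels: list[str]
    confidence_threshold: float


class ClassificationConfigV2(BaseModel):
    feature_type: Literal["classification"]
    schema_version: Literal[2] = 2
    labels: list[str]
    confidence_threshold: float
    top_k: int = Field(default=1, ge=1)  # hypothetical field added in v2


def classification_version(value: Any) -> str:
    """Pick a tag for raw dicts or already-built models; a missing version means v1."""
    if isinstance(value, dict):
        version = value.get("schema_version", 1)
    else:
        version = getattr(value, "schema_version", 1)
    return f"v{version}"


VersionedClassificationConfig = Annotated[
    Union[
        Annotated[ClassificationConfigV1, Tag("v1")],
        Annotated[ClassificationConfigV2, Tag("v2")],
    ],
    Discriminator(classification_version),
]

adapter = TypeAdapter(VersionedClassificationConfig)

legacy_row = {"feature_type": "classification", "labels": ["a", "b"], "confidence_threshold": 0.5}
new_row = {**legacy_row, "schema_version": 2, "top_k": 3}

print(type(adapter.validate_python(legacy_row)).__name__)
print(type(adapter.validate_python(new_row)).__name__)
```

With this shape, old rows keep validating as v1, new writes carry `schema_version: 2`, and a backfill can rewrite v1 rows at whatever pace suits.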