DEV Community

Joseph Joshua
Joseph Joshua

Posted on

FastAPI + PydanticAI + a2a-protocol

This post gives a full implementation of an AI agent built with Pydantic AI.

An AI agent combines the capabilities of an LLM with tools that help the AI interact with the real world.

First, create your AI agent implementation in `agent.py`:

# Python import
import logging
import os
from typing import List, Optional
from uuid import uuid4

# Library import
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from dotenv import load_dotenv
from fastapi.exceptions import HTTPException

# Module import
from models import A2AMessage, GrammarResponse, MessageConfiguration, MessagePart, TaskResult, TaskStatus

load_dotenv()

# Module logger used to record failed model calls together with their traceback.
logger = logging.getLogger(__name__)


class GrammarAgent:
    """Grammar-correction agent wrapping a Pydantic AI ``Agent`` backed by Gemini.

    The model is asked to return a structured ``GrammarResponse``; ``run`` packages
    that answer as a completed A2A ``TaskResult``.
    """

    # Adjacent string literals are joined with no separator, so each sentence
    # carries its own closing punctuation and trailing space.
    SYSTEM_INSTRUCTIONS = (
        "You are a specialized assistant that helps users correct grammar, spelling, "
        "and phrasing mistakes in text. "
        "Your goal is to return the corrected sentence and an explanation. "
        "If users provide unrelated topics, politely state that you can only help "
        "with grammar or writing tasks."
    )

    def __init__(self):
        """Build the Gemini-backed agent using ``GOOGLE_API_KEY`` from the environment."""
        # NOTE(review): the "no Key" fallback lets construction succeed without a key;
        # presumably the first model call then fails at the provider — confirm intended.
        provider = GoogleProvider(api_key=os.getenv("GOOGLE_API_KEY", "no Key"))

        model = GoogleModel("gemini-2.0-flash", provider=provider)

        self.agent = Agent(
            model=model,
            output_type=GrammarResponse,
            system_prompt=self.SYSTEM_INSTRUCTIONS,
        )

    @staticmethod
    def _extract_text(part: MessagePart) -> str:
        """Return the stripped user text carried by one message part, or ``""``.

        A ``text`` part carries its text directly. Otherwise, if the part has
        ``data``, only the last data entry is inspected and it must be a dict
        of the form ``{"kind": "text", "text": ...}``.
        """
        if getattr(part, "kind", None) == "text":
            # ``text`` is optional on MessagePart, so guard against None.
            return (getattr(part, "text", None) or "").strip()

        data = getattr(part, "data", None)
        if data:
            last_entry = data[-1]
            if isinstance(last_entry, dict) and last_entry.get("kind") == "text":
                return str(last_entry.get("text") or "").strip()

        return ""

    async def run(
        self,
        message: A2AMessage,
        context_id: Optional[str] = None,
        task_id: Optional[str] = None,
        config: Optional[MessageConfiguration] = None,
    ):
        """Correct the text in ``message`` and wrap the answer in a ``TaskResult``.

        Args:
            message: Incoming A2A message; only its last part is read.
            context_id: Conversation id; a new UUID4 string is generated when falsy.
            task_id: Task id; a new UUID4 string is generated when falsy.
            config: Accepted for interface compatibility; not used by this method.

        Returns:
            A ``TaskResult`` with state ``"completed"``, whose status message is the
            agent reply (the ``GrammarResponse`` serialized as JSON in a text part)
            and whose history is ``[message, reply]``.

        Raises:
            ValueError: ``message`` has no parts, or its last part carries no text.
            HTTPException: status 500 when the model call or result packaging fails.
        """
        context_id = context_id or str(uuid4())
        task_id = task_id or str(uuid4())

        if not message.parts:
            raise ValueError("No message provided")

        # Only the most recent part of the message is sent to the model.
        user_text = self._extract_text(message.parts[-1])
        if not user_text:
            raise ValueError("No text provided")

        try:
            response = await self.agent.run(user_prompt=user_text)

            response_message = A2AMessage(
                role="agent",
                # The structured GrammarResponse travels as JSON inside a text part.
                parts=[MessagePart(kind="text", text=response.output.model_dump_json())],
                taskId=task_id,
            )

            return TaskResult(
                id=task_id,
                contextId=context_id,
                status=TaskStatus(state="completed", message=response_message),
                history=[message, response_message],
            )
        except Exception as e:
            # Keep the traceback in the server log and chain the original cause.
            logger.exception("Grammar agent run failed")
            raise HTTPException(status_code=500, detail=f"internal server error: {str(e)}") from e

Then implement the API endpoint that exposes the agent in `main.py`:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import os
from models import A2AMessage, JSONRPCRequest, JSONRPCResponse
from agent import GrammarAgent

# Module-level handle to the shared agent; assigned when the application starts.
grammar_agent = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the GrammarAgent for the lifetime of the FastAPI application.

    Startup constructs one shared ``GrammarAgent``; shutdown drops the reference.
    """
    global grammar_agent

    grammar_agent = GrammarAgent()  # startup
    yield
    grammar_agent = None  # shutdown: release the agent


app = FastAPI(
    title="Grammar Agent",
    description="Ai agent for grammatical correction",
    version="1.0.0",
    lifespan=lifespan,
)

def _jsonrpc_error(status_code: int, request_id, code: int, message: str) -> JSONResponse:
    """Build a JSON-RPC 2.0 error object wrapped in an HTTP response with ``status_code``."""
    return JSONResponse(
        status_code=status_code,
        content={
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": code, "message": message},
        },
    )


@app.post("/a2a/grammar-check")
async def grammar_check(request: Request):
    """Handle one A2A JSON-RPC 2.0 request and return the grammar agent's task result.

    Supported methods:
      * ``message/send``: ``params.message`` is forwarded together with
        ``params.configuration``.
      * ``execute``: the last entry of ``params.messages`` is forwarded together
        with ``params.contextId`` and ``params.taskId``.

    On failure a JSON-RPC error object is returned:
      * -32700 / HTTP 400: the body is not valid JSON.
      * -32600 / HTTP 400: the body is not a JSON-RPC 2.0 request object.
      * -32601 / HTTP 400: the method is not supported.
      * -32602 / HTTP 400: params fail validation or carry no usable text.
      * -32000 / HTTP 500: any other failure (including agent errors).
    The request id is echoed in error responses whenever it could be read.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return _jsonrpc_error(400, None, -32700, "Parse error: request body is not valid JSON")
    except Exception as e:
        return _jsonrpc_error(500, None, -32000, str(e))

    if not isinstance(body, dict):
        return _jsonrpc_error(400, None, -32600, "Invalid Request: body must be a JSON object")

    request_id = body.get("id")

    if body.get("jsonrpc") != "2.0" or "id" not in body:
        return _jsonrpc_error(
            400, request_id, -32600,
            "Invalid Request: jsonrpc must be '2.0' and id is required",
        )

    if body.get("method") not in ("message/send", "execute"):
        return _jsonrpc_error(400, request_id, -32601, f"Method not found: {body.get('method')!r}")

    try:
        rpc_request = JSONRPCRequest(**body)
        params = rpc_request.params

        context_id = None
        task_id = None
        config = None

        # getattr: the params union is resolved from its shape, so it may not
        # match the declared method.
        if rpc_request.method == "message/send":
            message = getattr(params, "message", None)
            config = getattr(params, "configuration", None)
        else:  # "execute": the agent expects a single message, so send the newest one
            messages = getattr(params, "messages", None) or []
            message = messages[-1] if messages else None
            context_id = getattr(params, "contextId", None)
            task_id = getattr(params, "taskId", None)

        if message is None:
            raise ValueError(f"params contain no message for method '{rpc_request.method}'")

        result = await grammar_agent.run(
            message=message,
            context_id=context_id,
            task_id=task_id,
            config=config,
        )

        return JSONRPCResponse(id=rpc_request.id, result=result).model_dump()

    except ValueError as e:
        # pydantic ValidationError subclasses ValueError, as do the agent's
        # "No message provided" / "No text provided" errors.
        return _jsonrpc_error(400, request_id, -32602, f"Invalid params: {e}")
    except Exception as e:
        return _jsonrpc_error(500, request_id, -32000, str(e))


if __name__ == "__main__":
    # Development server: bound to localhost with auto-reload enabled.
    # The listening port is taken from $PORT and falls back to 5000.
    uvicorn.run(
        "main:app",
        host="127.0.0.1",
        port=int(os.environ.get("PORT", 5000)),
        reload=True,
    )

Then define the Pydantic schemas used for request and response validation in `models.py`:

from pydantic import BaseModel, Field
from typing import Literal, Optional, List, Dict, Any
from datetime import datetime
from uuid import uuid4

class GrammarResponse(BaseModel):
    """Structured output the grammar agent asks the model to produce."""

    response: str  # corrected text
    explanation: str  # explanation of the corrections


class MessagePart(BaseModel):
    """One part of an A2A message: inline ``text`` or a list of ``data`` dicts."""

    kind: Literal["text", "data"]
    text: str | None = None
    data: list[dict[str, Any]] | None = None


class A2AMessage(BaseModel):
    """An A2A message exchanged between a user/system and the agent."""

    kind: Literal["message"] = "message"
    role: Literal["user", "agent", "system"]
    parts: list[MessagePart]
    # A fresh UUID4 string per message unless the sender supplies one.
    messageId: str = Field(default_factory=lambda: str(uuid4()))
    taskId: str | None = None
    metadata: dict[str, Any] | None = None

class PushNotificationConfig(BaseModel):
    """Target URL and optional credentials for push notifications."""

    url: str
    token: Optional[str] = None
    authentication: Optional[Dict[str, Any]] = None


class MessageConfiguration(BaseModel):
    """Client preferences sent alongside a ``message/send`` request."""

    blocking: bool = True
    # default_factory builds a new list per instance instead of sharing one literal.
    acceptedOutputModes: List[str] = Field(
        default_factory=lambda: ["text/plain", "image/png", "image/svg+xml"]
    )
    pushNotificationConfig: Optional[PushNotificationConfig] = None


class MessageParams(BaseModel):
    """``params`` of a ``message/send`` request: one message plus configuration."""

    message: A2AMessage
    configuration: MessageConfiguration = Field(default_factory=MessageConfiguration)


class ExecuteParams(BaseModel):
    """``params`` of an ``execute`` request: a list of messages plus optional ids."""

    contextId: Optional[str] = None
    taskId: Optional[str] = None
    messages: List[A2AMessage]


class JSONRPCRequest(BaseModel):
    """JSON-RPC 2.0 request envelope accepted by the grammar endpoint."""

    jsonrpc: Literal["2.0"]
    # JSON-RPC 2.0 allows string or numeric ids; many clients send integers.
    id: str | int
    method: Literal["message/send", "execute"]
    # NOTE(review): the union member is presumably chosen from the params shape,
    # not from ``method``; callers should not assume they match.
    params: MessageParams | ExecuteParams

def _utc_timestamp() -> str:
    """Return the current time as a timezone-aware UTC ISO 8601 string."""
    from datetime import timezone  # only ``datetime`` is imported at the top of models.py

    return datetime.now(timezone.utc).isoformat()


class TaskStatus(BaseModel):
    """State of a task at a point in time, with an optional message."""

    state: Literal["working", "completed", "input-required", "failed"]
    # Aware UTC timestamp ending in "+00:00". datetime.utcnow() is deprecated
    # since Python 3.12 and yields a naive, offset-less value.
    timestamp: str = Field(default_factory=_utc_timestamp)
    message: Optional[A2AMessage] = None


class Artifact(BaseModel):
    """A named output produced by a task."""

    artifactId: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    parts: List[MessagePart]


class TaskResult(BaseModel):
    """A2A task object returned as the JSON-RPC ``result``."""

    id: str
    contextId: str
    status: TaskStatus
    artifacts: List[Artifact] = Field(default_factory=list)
    history: List[A2AMessage] = Field(default_factory=list)
    kind: Literal["task"] = "task"


class JSONRPCResponse(BaseModel):
    """JSON-RPC 2.0 response envelope.

    Per JSON-RPC 2.0 a response carries either ``result`` or ``error``; this
    model does not enforce that.
    """

    jsonrpc: Literal["2.0"] = "2.0"
    # Mirrors the request id (string or number); null is allowed for errors
    # where the request id could not be determined.
    id: str | int | None
    result: Optional[TaskResult] = None
    error: Optional[Dict[str, Any]] = None
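
Next, add your API key to a `.env` file:

GOOGLE_API_KEY=YOUR_API_KEY

Finally, install the required dependencies (fastapi, uvicorn, pydantic-ai, python-dotenv) and start the server with `python main.py`. The agent is then served at `http://127.0.0.1:5000/a2a/grammar-check`; set the `PORT` environment variable to use a different port.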

Top comments (0)