Introduction
How to track AI model costs in real-time using Strands hooks and avoid budget surprises
The Problem: AI Costs Can Spiral Out of Control
You've built an amazing AI agent using Strands, and it's working beautifully. Users are chatting with it, getting helpful responses, and everything seems great. But then you check your AWS bill and see hundreds of dollars in unexpected charges. Sound familiar?
That's exactly what happened to me when I first deployed a Strands agent to production. Without proper cost monitoring, I had no visibility into:
- How many tokens each conversation was consuming
- Which interactions were the most expensive
- Whether costs were trending upward
- If there were any cost anomalies
I needed a way to track model usage costs in real-time, and Strands hooks provided the perfect solution.
The Solution: Real-Time Cost Monitoring with Hooks
Strands hooks give you the ability to intercept and monitor every model call your agent makes. By using BeforeModelCallEvent and AfterModelCallEvent, we can track token usage and calculate costs in real-time.
The beauty of this approach is that it's completely non-intrusive - your agent doesn't need to know about cost tracking, and you can add or remove it without changing your core agent code.
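To see just how little the agent has to know about any of this, here's a stripped-down sketch (my own illustration, reusing the same imports and registration pattern you'll see in the full implementation below) of a hook provider that merely observes model calls:

from strands.hooks import HookProvider, HookRegistry
from strands.hooks.events import BeforeModelCallEvent, AfterModelCallEvent

class ModelCallLogger(HookProvider):
    """Observes every model call without changing the agent's behavior."""
    def register_hooks(self, registry: HookRegistry) -> None:
        registry.add_callback(BeforeModelCallEvent, lambda event: print("➡️ model call starting"))
        registry.add_callback(AfterModelCallEvent, lambda event: print("⬅️ model call finished"))

Swap a provider like this in or out of the agent's hooks list and nothing else in your code changes.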
Why Build Custom Cost Monitoring When AWS Has It?
You might be wondering: "AWS Console already provides cost tracking, so why build custom monitoring?" That's an excellent question, and the answer reveals why custom cost monitoring is essential for AI applications.
Real-Time vs. Delayed Reporting
- AWS Console: Shows costs with 24-48 hour delays
- Our Hooks: Real-time cost tracking during development and testing
- Benefit: Catch expensive operations immediately, not days later
Granular vs. Aggregate Data
- AWS Console: Shows total costs by service/region
- Our Hooks: Shows cost per conversation, per user, per interaction
- Benefit: Understand which specific conversations are expensive
Development vs. Production Monitoring
- AWS Console: Great for production cost analysis
- Our Hooks: Essential for development and testing
- Benefit: Prevent cost surprises during development
Actionable vs. Historical Data
- AWS Console: "You spent $500 last month"
- Our Hooks: "This conversation just cost $0.15 - should we optimize it?"
- Benefit: Make real-time decisions about conversation flow
Custom Alerts and Budgets
- AWS Console: Basic budget alerts
- Our Hooks: Custom logic like "alert if single conversation > $1"
- Benefit: Prevent runaway costs from specific user interactions
The Architecture: How Cost Monitoring Works
Let me walk you through the complete cost monitoring system I built:
1. The CostTrackingHooks Class
class CostTrackingHooks(HookProvider):
"""
A HookProvider that tracks model usage costs and token consumption.
Demonstrates BeforeModelCallEvent and AfterModelCallEvent usage.
"""
def __init__(self, cost_file="model_costs.json"):
self.cost_file = cost_file
self.tokens_sent = 0
self.tokens_received = 0
self.total_cost = 0.0
self.model_calls = 0
self.current_user_input = None
self._initialize_cost_file()
# Strands Agents default model pricing (as of 2025)
self.cost_per_1k_tokens = {
# Claude 4 Sonnet (Strands default model)
"us.anthropic.claude-sonnet-4-20250514-v1:0": 0.003,
"claude-sonnet-4-20250514": 0.003,
"claude-4-sonnet": 0.003,
# Claude 3.5 models
"claude-3-5-sonnet-20241022": 0.003,
"claude-3-5-haiku-20241022": 0.00025,
"claude-3-opus-20240229": 0.015,
# Default rate
"default": 0.003,
"unknown": 0.003
}
This is the heart of the cost monitoring system. It extends Strands' HookProvider class and manages all the cost tracking logic.
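Before moving on, it helps to see the arithmetic the hooks apply with this pricing table. Worked for a hypothetical call (the numbers below are made up purely for illustration):

# Hypothetical: 1,200 input tokens + 300 output tokens at the default
# Claude Sonnet rate of $0.003 per 1K tokens.
total_tokens = 1200 + 300
cost_rate = 0.003                       # dollars per 1K tokens
call_cost = (total_tokens / 1000) * cost_rate
print(f"${call_cost:.6f}")              # -> $0.004500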
2. The Before Hook: Capturing Input Tokens
def before_model_call(self, event: BeforeModelCallEvent) -> None:
"""Hook called before model invocation."""
print("🔧 MODEL HOOK: Tracking model call...")
# Try to get prompt from event
prompt_text = ""
if hasattr(event, 'prompt') and event.prompt:
prompt_text = str(event.prompt)
elif hasattr(event, 'message') and event.message:
prompt_text = str(event.message)
elif hasattr(event, 'messages') and event.messages:
prompt_text = " ".join([str(msg) for msg in event.messages])
if prompt_text:
self.tokens_sent = self._estimate_tokens(prompt_text)
print(f" Estimated input tokens: {self.tokens_sent}")
else:
# Fallback: use stored user input
if self.current_user_input:
self.tokens_sent = self._estimate_tokens(self.current_user_input)
print(f" Estimated input tokens (from stored input): {self.tokens_sent}")
else:
self.tokens_sent = 0
    # Get model name (remember it for the after hook) and its cost rate
    self.current_model = self._get_model_name(event)
    cost_rate = self.cost_per_1k_tokens.get(self.current_model, self.cost_per_1k_tokens['unknown'])
    print(f"   Model: {self.current_model}")
    print(f"   Cost rate: ${cost_rate:.6f} per 1K tokens")
This hook fires right before your agent calls the model. It captures the input tokens and identifies which model is being used.
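The token counts come from a deliberately simple heuristic of roughly 4 characters per token (shown in _estimate_tokens below). A quick illustration of what that heuristic produces - real tokenizers will give different numbers:

text = "What is the capital of France?"
estimated_tokens = max(1, len(text) // 4)
print(estimated_tokens)  # 30 characters // 4 -> 7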
3. The After Hook: Calculating Costs
def after_model_call(self, event: AfterModelCallEvent) -> None:
"""Hook called after model invocation."""
print("📊 MODEL HOOK: Calculating costs...")
if event.stop_response:
# Extract response content
response_content = self._extract_response_content(event.stop_response)
if response_content:
self.tokens_received = self._estimate_tokens(response_content)
print(f" Estimated output tokens: {self.tokens_received}")
# Calculate cost
total_tokens = self.tokens_sent + self.tokens_received
model_name = getattr(self, 'current_model', 'unknown')
cost_rate = self.cost_per_1k_tokens.get(model_name, self.cost_per_1k_tokens['unknown'])
call_cost = (total_tokens / 1000) * cost_rate
self.total_cost += call_cost
self.model_calls += 1
print(f" Total tokens: {total_tokens}")
print(f" Cost for this call: ${call_cost:.6f}")
print(f" Running total cost: ${self.total_cost:.6f}")
# Log to file
self._log_cost_data(model_name, total_tokens, call_cost)
This hook fires after the model responds. It calculates the total cost and logs everything to a JSON file.
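Each call becomes one entry in model_costs.json. Given the fields written by _log_cost_data, an entry looks roughly like this (the values here are illustrative):

{
  "id": 1,
  "timestamp": "2025-01-15 10:32:07",
  "model": "us.anthropic.claude-sonnet-4-20250514-v1:0",
  "tokens_sent": 120,
  "tokens_received": 280,
  "total_tokens": 400,
  "cost": 0.0012,
  "running_total": 0.0012
}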
The Implementation: Complete Working Code
Here's the complete cost monitoring implementation:
import json
import time
import os
from datetime import datetime
from strands.agent import Agent
from strands.hooks import HookProvider, HookRegistry
from strands.hooks.events import (
BeforeInvocationEvent,
AfterInvocationEvent,
BeforeModelCallEvent,
AfterModelCallEvent
)
class CostTrackingHooks(HookProvider):
def __init__(self, cost_file="model_costs.json"):
self.cost_file = cost_file
self.tokens_sent = 0
self.tokens_received = 0
self.total_cost = 0.0
self.model_calls = 0
self.current_user_input = None
self.current_model = None
self._initialize_cost_file()
# Strands Agents default model pricing
self.cost_per_1k_tokens = {
"us.anthropic.claude-sonnet-4-20250514-v1:0": 0.003,
"claude-sonnet-4-20250514": 0.003,
"claude-4-sonnet": 0.003,
"claude-3-5-sonnet-20241022": 0.003,
"claude-3-5-haiku-20241022": 0.00025,
"claude-3-opus-20240229": 0.015,
"default": 0.003,
"unknown": 0.003
}
def register_hooks(self, registry: HookRegistry) -> None:
"""Register hooks for model call events."""
registry.add_callback(BeforeModelCallEvent, self.before_model_call)
registry.add_callback(AfterModelCallEvent, self.after_model_call)
print("💰 Cost tracking hooks registered successfully!")
def before_model_call(self, event: BeforeModelCallEvent) -> None:
"""Hook called before model invocation."""
print("🔧 MODEL HOOK: Tracking model call...")
# Try to get prompt from event
prompt_text = ""
if hasattr(event, 'prompt') and event.prompt:
prompt_text = str(event.prompt)
elif hasattr(event, 'message') and event.message:
prompt_text = str(event.message)
elif hasattr(event, 'messages') and event.messages:
prompt_text = " ".join([str(msg) for msg in event.messages])
if prompt_text:
self.tokens_sent = self._estimate_tokens(prompt_text)
print(f" Estimated input tokens: {self.tokens_sent}")
else:
# Fallback: use stored user input
if self.current_user_input:
self.tokens_sent = self._estimate_tokens(self.current_user_input)
print(f" Estimated input tokens (from stored input): {self.tokens_sent}")
else:
self.tokens_sent = 0
# Get model name
self.current_model = self._get_model_name(event)
cost_rate = self.cost_per_1k_tokens.get(self.current_model, self.cost_per_1k_tokens['unknown'])
print(f" Model: {self.current_model}")
print(f" Cost rate: ${cost_rate:.6f} per 1K tokens")
def after_model_call(self, event: AfterModelCallEvent) -> None:
"""Hook called after model invocation."""
print("📊 MODEL HOOK: Calculating costs...")
if event.stop_response:
# Extract response content
response_content = self._extract_response_content(event.stop_response)
if response_content:
self.tokens_received = self._estimate_tokens(response_content)
print(f" Estimated output tokens: {self.tokens_received}")
# Calculate cost
total_tokens = self.tokens_sent + self.tokens_received
model_name = getattr(self, 'current_model', 'unknown')
cost_rate = self.cost_per_1k_tokens.get(model_name, self.cost_per_1k_tokens['unknown'])
call_cost = (total_tokens / 1000) * cost_rate
self.total_cost += call_cost
self.model_calls += 1
print(f" Total tokens: {total_tokens}")
print(f" Cost for this call: ${call_cost:.6f}")
print(f" Running total cost: ${self.total_cost:.6f}")
# Log to file
self._log_cost_data(model_name, total_tokens, call_cost)
def _estimate_tokens(self, text: str) -> int:
"""Simple token estimation (rough approximation)."""
if not text:
return 0
# Rough estimation: ~4 characters per token for English text
return max(1, len(text) // 4)
def _get_model_name(self, event) -> str:
"""Extract model name from event."""
if hasattr(event, 'model'):
model_obj = event.model
if hasattr(model_obj, 'name'):
return model_obj.name
elif hasattr(model_obj, 'model_id'):
return model_obj.model_id
elif hasattr(model_obj, 'id'):
return model_obj.id
else:
return str(model_obj)
return 'unknown'
def _extract_response_content(self, stop_response) -> str:
"""Extract text content from model response."""
if hasattr(stop_response, 'content'):
return str(stop_response.content)
elif hasattr(stop_response, 'text'):
return str(stop_response.text)
elif hasattr(stop_response, 'message'):
return str(stop_response.message)
else:
return str(stop_response)
def _log_cost_data(self, model_name: str, total_tokens: int, call_cost: float):
"""Log cost data to JSON file."""
try:
with open(self.cost_file, 'r') as f:
cost_data = json.load(f)
cost_entry = {
"id": len(cost_data) + 1,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"model": model_name,
"tokens_sent": self.tokens_sent,
"tokens_received": self.tokens_received,
"total_tokens": total_tokens,
"cost": round(call_cost, 6),
"running_total": round(self.total_cost, 6)
}
cost_data.append(cost_entry)
with open(self.cost_file, 'w') as f:
json.dump(cost_data, f, indent=2)
except Exception as e:
print(f"❌ Error logging cost data: {e}")
def set_user_input(self, user_input: str) -> None:
"""Set the current user input for cost calculation."""
self.current_user_input = user_input
def get_cost_summary(self):
"""Get a summary of cost tracking data."""
return {
"total_model_calls": self.model_calls,
"total_cost": round(self.total_cost, 6),
"average_cost_per_call": round(self.total_cost / max(1, self.model_calls), 6)
}
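One gap worth flagging: __init__ calls _initialize_cost_file(), which isn't shown above. A minimal version consistent with how _log_cost_data reads the file (this is my assumption of the intent) simply creates the log as an empty JSON list:

def _initialize_cost_file(self) -> None:
    """Create the cost log as an empty JSON list if it doesn't exist yet."""
    if not os.path.exists(self.cost_file):
        with open(self.cost_file, 'w') as f:
            json.dump([], f)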
Integration with Your Agent
To use cost monitoring with your agent, simply attach the hooks:
# Create the cost tracking hooks
cost_hooks = CostTrackingHooks()
# Create agent with cost monitoring
agent = Agent(
system_prompt="You are a helpful assistant.",
hooks=[cost_hooks] # Attach cost monitoring
)
# Use the agent normally
response = agent("Hello! How are you?")
# Check costs anytime
summary = cost_hooks.get_cost_summary()
print(f"Total cost so far: ${summary['total_cost']}")
Real-Time Cost Monitoring in Action
Here's what you'll see when the cost monitoring is working:
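Based on the print statements in the hooks, a single turn produces output along these lines (the token counts and costs are illustrative; exact numbers depend on your model and prompt):

💰 Cost tracking hooks registered successfully!
🔧 MODEL HOOK: Tracking model call...
   Estimated input tokens: 7
   Model: us.anthropic.claude-sonnet-4-20250514-v1:0
   Cost rate: $0.003000 per 1K tokens
📊 MODEL HOOK: Calculating costs...
   Estimated output tokens: 212
   Total tokens: 219
   Cost for this call: $0.000657
   Running total cost: $0.000657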
Advanced Features
1. Interactive Cost Commands
Add these commands to your chatbot for real-time cost monitoring:
# In your main loop
if user_input.lower() == 'costs':
summary = cost_hooks.get_cost_summary()
print(f"💰 Cost Summary:")
print(f" Total model calls: {summary['total_model_calls']}")
print(f" Total cost: ${summary['total_cost']}")
print(f" Average cost per call: ${summary['average_cost_per_call']}")
2. Cost Alerts
You can easily add cost alerts:
def after_model_call(self, event: AfterModelCallEvent) -> None:
    # ... existing cost calculation code ...

    # Alert if a single call exceeds a per-call threshold
    if call_cost > 0.01:  # $0.01 per call
        print(f"⚠️ HIGH COST ALERT: ${call_cost:.6f} for this call!")

    # Alert if the running total for this session exceeds your budget
    if self.total_cost > 10.0:  # $10 budget (a running total, not a daily figure)
        print(f"🚨 BUDGET ALERT: Running cost ${self.total_cost:.2f} exceeded the budget!")
3. Cost Analytics
Generate cost reports:
def generate_cost_report(self):
"""Generate a detailed cost report."""
if not os.path.exists(self.cost_file):
return "No cost data available."
with open(self.cost_file, 'r') as f:
cost_data = json.load(f)
if not cost_data:
return "No cost data available."
# Calculate statistics
total_calls = len(cost_data)
total_cost = sum(entry['cost'] for entry in cost_data)
avg_cost = total_cost / total_calls
max_cost = max(entry['cost'] for entry in cost_data)
# Find most expensive calls
expensive_calls = sorted(cost_data, key=lambda x: x['cost'], reverse=True)[:5]
report = f"""
📊 Cost Report
=============
Total Calls: {total_calls}
Total Cost: ${total_cost:.6f}
Average Cost per Call: ${avg_cost:.6f}
Most Expensive Call: ${max_cost:.6f}
Top 5 Most Expensive Calls:
"""
for i, call in enumerate(expensive_calls, 1):
report += f"{i}. {call['timestamp']} - ${call['cost']:.6f} ({call['total_tokens']} tokens)\n"
return report
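If you add generate_cost_report as a method on CostTrackingHooks, exposing it works the same way as the costs command above:

if user_input.lower() == "report":
    print(cost_hooks.generate_cost_report())
    continue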
Why This Approach Works
- Non-intrusive: Your agent code doesn't need to change at all
- Real-time: Costs are calculated and displayed immediately
- Transparent: Uses published per-1K-token rates, so you can see exactly how each estimate was computed (keep in mind the token counts are estimates, and real pricing differs for input vs. output tokens)
- Persistent: All cost data is saved to JSON files
- Extensible: Easy to add alerts, reports, and analytics
- Official: Uses the official Strands hooks system
Production Considerations
1. Token Estimation Accuracy
The current implementation uses a simple character-based token estimation. For production use, consider:
def _estimate_tokens(self, text: str) -> int:
    """More accurate token estimation using tiktoken."""
    try:
        import tiktoken
        # cl100k_base is an OpenAI encoding; Claude uses its own tokenizer,
        # so this is still an approximation - just a better one than chars // 4
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    except ImportError:
        # Fallback to character-based estimation
        return max(1, len(text) // 4)
2. Database Storage
For high-volume production systems, consider using a database instead of JSON files:
def _log_cost_data(self, model_name: str, total_tokens: int, call_cost: float):
"""Log cost data to database."""
# Use SQLite, PostgreSQL, or your preferred database
# This allows for better querying and analytics
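As a rough sketch of what that could look like with SQLite from the standard library (the schema and names here are my own, not from the original project):

import sqlite3
from datetime import datetime

def _log_cost_data(self, model_name: str, total_tokens: int, call_cost: float):
    """Log cost data to a SQLite database instead of a JSON file."""
    conn = sqlite3.connect("model_costs.db")
    try:
        conn.execute(
            """CREATE TABLE IF NOT EXISTS model_costs (
                   id INTEGER PRIMARY KEY AUTOINCREMENT,
                   timestamp TEXT,
                   model TEXT,
                   tokens_sent INTEGER,
                   tokens_received INTEGER,
                   total_tokens INTEGER,
                   cost REAL,
                   running_total REAL
               )"""
        )
        conn.execute(
            "INSERT INTO model_costs "
            "(timestamp, model, tokens_sent, tokens_received, total_tokens, cost, running_total) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                model_name,
                self.tokens_sent,
                self.tokens_received,
                total_tokens,
                round(call_cost, 6),
                round(self.total_cost, 6),
            ),
        )
        conn.commit()
    finally:
        conn.close()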
3. Cost Budgeting
Implement cost budgets and automatic shutdowns:
def check_budget(self, daily_budget: float = 50.0):
"""Check if daily budget has been exceeded."""
if self.total_cost > daily_budget:
print(f"🚨 BUDGET EXCEEDED: ${self.total_cost:.2f} > ${daily_budget}")
# Optionally shutdown or switch to cheaper model
return False
return True
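The method above only reports the overrun. One conservative way to actually enforce it (my own sketch; it doesn't rely on any particular hook behavior) is to check the budget in your chat loop before each agent call:

# Inside the chat loop, before calling the agent:
if not cost_hooks.check_budget(daily_budget=50.0):
    print("Stopping: cost budget exceeded for this session.")
    break
response = agent(user_input)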
The Bottom Line
Cost monitoring with Strands hooks gives you complete visibility into your AI spending. You can:
- Track costs in real-time during development and production
- Identify expensive interactions and optimize them
- Set budgets and alerts to prevent cost overruns
- Generate detailed reports for stakeholders
- Make data-driven decisions about model usage
The key insight is that hooks let you monitor your agent's costs without changing how it works. It's like having a financial dashboard that doesn't interfere with your agent's operation.
Thanks
Sreeni Ramadorai