Explainable Causal Reinforcement Learning for Precision Oncology Clinical Workflows Under Real-Time Policy Constraints
A Personal Journey into the Heart of AI-Driven Medicine
My journey into this fascinating intersection of AI and oncology began during a late-night research session about three years ago. I was experimenting with standard reinforcement learning (RL) for treatment optimization when I encountered a particularly troubling case study. The model had recommended a treatment regimen that, statistically, should have worked—but the patient experienced severe adverse effects. While exploring the model's decision-making process, I discovered a fundamental limitation: traditional RL could identify correlations between treatments and outcomes, but it couldn't distinguish causation from mere association. This realization sent me down a rabbit hole of causal inference literature, structural equation modeling, and eventually to the frontier of explainable AI.
In my research into precision oncology workflows, I realized that clinicians weren't just asking "what treatment works?" but "why does this treatment work for this specific patient?" and "what would happen if we tried something different?" These questions demanded more than predictive accuracy: they required causal understanding and transparent reasoning. Through studying dozens of clinical trials and electronic health records, I learned that oncology decisions operate under severe real-time constraints: treatment windows measured in days, rapidly changing patient conditions, and the constant pressure of tumor progression.
Technical Background: The Convergence of Three Disciplines
The Causal Revolution in Machine Learning
While exploring causal inference frameworks, I found that Judea Pearl's ladder of causation is a natural conceptual fit for medical AI. Most machine learning operates on the first rung, association: we observe that treatment A correlates with outcome B. Causal models ascend to the second rung, intervention, and answer "what if" questions: what would happen if we administered treatment A? The highest rung, counterfactuals, addresses what would have happened had we chosen a different treatment.
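To make the gap between the first two rungs concrete, here is a minimal simulation sketch. Everything in it is an assumption for illustration: the variable names, the structural equations, and the treatment effect of 1.0 are invented and do not come from clinical data. Disease severity drives both who gets treated and how they fare, so the observed association and the interventional effect disagree.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000
TRUE_EFFECT = 1.0  # assumed benefit of treatment on a generic outcome score

# Hypothetical confounder: disease severity.
severity = rng.normal(size=n)

def outcome(treated, severity, noise):
    """Assumed structural equation: treatment helps, severity hurts."""
    return TRUE_EFFECT * treated - 2.0 * severity + noise

# Observational regime: sicker patients are more likely to be treated.
treated_obs = (severity + rng.normal(size=n) > 0).astype(float)
y_obs = outcome(treated_obs, severity, rng.normal(size=n))

# Rung 1 (association): naive difference in mean outcome, treated vs untreated.
assoc = y_obs[treated_obs == 1].mean() - y_obs[treated_obs == 0].mean()

# Rung 2 (intervention): set treatment by fiat for everyone, do(T=1) vs do(T=0).
# This removes the severity -> treatment arrow.
noise = rng.normal(size=n)
interv = (outcome(np.ones(n), severity, noise)
          - outcome(np.zeros(n), severity, noise)).mean()

print(f"observed association:   {assoc:+.2f}")
print(f"interventional effect:  {interv:+.2f}")
```

In this toy setup the naive comparison makes a beneficial treatment look harmful, because treated patients were sicker to begin with. That is the pattern a purely correlational RL agent can fall for.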
One interesting finding from my experimentation with do-calculus was that many medical "best practices" were actually causal relationships waiting to be formalized. For instance, the relationship between chemotherapy dosage and white blood cell count isn't just correlational—it's fundamentally causal, with clear directionality and potential confounders.
Reinforcement Learning with Policy Constraints
During my investigation of constrained RL, I found that oncology presents unique challenges. Unlike games where agents can explore freely, medical RL must operate within ethical and clinical boundaries. Real-time policy constraints aren't just optimization boundaries—they're hard requirements derived from pharmacokinetics, toxicity thresholds, and clinical guidelines.
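As a rough sketch of what a state-dependent hard constraint could look like, the snippet below derives dose limits from the current patient state and projects a proposed dose into them. The `PatientState` fields, the `dose_bounds` and `project_dose` helpers, and every threshold are hypothetical placeholders chosen for illustration; they are not clinical guidance and not taken from any guideline.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class PatientState:
    liver_function: float    # same 40-120 "normal range" scale used later in the causal graph
    neutrophil_count: float  # hypothetical units, 10^9 cells per litre


def dose_bounds(state: PatientState, max_dose: float = 100.0) -> Tuple[float, float]:
    """Return (low, high) dose limits for the current state.

    All thresholds are illustrative assumptions.
    """
    high = max_dose
    if state.liver_function > 120.0:   # outside the assumed normal range
        high = min(high, 0.5 * max_dose)
    if state.neutrophil_count < 1.0:   # assumed hold-treatment threshold
        high = 0.0
    return 0.0, high


def project_dose(proposed: float, state: PatientState) -> float:
    """Clip a policy's proposed dose into the currently allowed interval."""
    low, high = dose_bounds(state)
    return min(max(proposed, low), high)


print(project_dose(80.0, PatientState(liver_function=60.0, neutrophil_count=2.5)))
print(project_dose(80.0, PatientState(liver_function=150.0, neutrophil_count=2.5)))
print(project_dose(80.0, PatientState(liver_function=150.0, neutrophil_count=0.5)))
```

The point of the design is that the bounds are recomputed at every decision step, so the feasible action set moves with the patient rather than being fixed at training time.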
As I was experimenting with constrained policy optimization, I came across the fundamental tension between exploration and safety. In oncology, exploration doesn't mean trying random treatments—it means carefully navigating the space of evidence-based options while respecting individual patient tolerances.
Explainability as a Clinical Necessity
My exploration of explainable AI in medical contexts revealed that clinicians don't just want feature importance scores. They need causal explanations: "The model recommends reducing the dosage because the patient's recent liver function tests show decreased metabolic capacity, and pharmacokinetic simulations predict toxic accumulation at the current dose."
Implementation Details: Building an Explainable Causal RL System
Causal Model Representation
Through studying structural causal models (SCMs), I developed a hybrid approach that combines domain knowledge with data-driven discovery. Here's a simplified representation of how we structure the causal graph for oncology decision-making:
import networkx as nx
import pandas as pd
import numpy as np
from typing import Dict, List
from sklearn.linear_model import LinearRegression


class OncologyCausalGraph:
    def __init__(self, patient_data: pd.DataFrame):
        """Initialize causal graph for oncology decision-making"""
        self.graph = nx.DiGraph()
        self.patient_data = patient_data
        self._build_base_structure()

    def _build_base_structure(self):
        """Build the fundamental causal relationships in oncology"""
        # Treatment nodes
        self.graph.add_node("chemotherapy_dose",
                            node_type="treatment",
                            constraints={"min": 0, "max": 100})
        self.graph.add_node("immunotherapy",
                            node_type="treatment",
                            constraints={"binary": True})
        # Patient state nodes
        self.graph.add_node("tumor_burden",
                            node_type="state",
                            measurement_unit="mm^3")
        self.graph.add_node("immune_response",
                            node_type="state",
                            measurement_unit="cell_count")
        # Physiological nodes
        self.graph.add_node("liver_function",
                            node_type="physiology",
                            constraints={"normal_range": (40, 120)})
        self.graph.add_node("kidney_function",
                            node_type="physiology")
        # Outcome nodes
        self.graph.add_node("toxicity",
                            node_type="outcome",
                            severity_levels=["mild", "moderate", "severe"])
        self.graph.add_node("progression_free_survival",
                            node_type="outcome",
                            measurement_unit="days")
        # Causal relationships (edges with effect direction and confidence)
        self.graph.add_edge("chemotherapy_dose", "tumor_burden",
                            effect_type="negative",
                            confidence=0.85)
        self.graph.add_edge("chemotherapy_dose", "immune_response",
                            effect_type="negative",
                            confidence=0.75)
        self.graph.add_edge("chemotherapy_dose", "toxicity",
                            effect_type="positive",
                            confidence=0.90)
        self.graph.add_edge("immunotherapy", "immune_response",
                            effect_type="positive",
                            confidence=0.80)
        self.graph.add_edge("liver_function", "chemotherapy_dose",
                            effect_type="moderator",
                            confidence=0.95)  # Liver function affects dose tolerance

    def estimate_causal_effect(self, treatment: str, outcome: str,
                               adjustment_set: List[str]) -> Dict:
        """Estimate the treatment-outcome relationship after backdoor adjustment"""
        # Simplified implementation - in practice would use do-calculus
        # or double machine learning.
        # Note: the value returned is a partial correlation in [-1, 1],
        # not an effect size in the outcome's units.
        treatment_data = self.patient_data[treatment]
        outcome_data = self.patient_data[outcome]
        if adjustment_set:
            # Adjust for confounders by regressing them out of both variables
            # (linear regression for simplicity; real implementation would be more robust)
            confounders = self.patient_data[adjustment_set]
            treatment_model = LinearRegression().fit(confounders, treatment_data)
            treatment_residuals = treatment_data - treatment_model.predict(confounders)
            outcome_model = LinearRegression().fit(confounders, outcome_data)
            outcome_residuals = outcome_data - outcome_model.predict(confounders)
            effect = np.corrcoef(treatment_residuals, outcome_residuals)[0, 1]
        else:
            # No adjustment: plain correlation, which is only causal if unconfounded
            effect = np.corrcoef(treatment_data, outcome_data)[0, 1]
        effect = float(effect)
        return {
            "treatment": treatment,
            "outcome": outcome,
            "causal_effect": effect,
            "adjustment_set": adjustment_set,
            "interpretation": self._interpret_effect(effect, treatment, outcome)
        }

    def _interpret_effect(self, effect: float, treatment: str, outcome: str) -> str:
        """Turn the estimate into a short plain-language summary"""
        direction = "higher" if effect > 0 else "lower"
        return (f"After adjustment, higher {treatment} goes with "
                f"{direction} {outcome} (partial r = {effect:.2f})")
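The `adjustment_set` argument above has to come from somewhere. One simple and well-known choice, sketched below, is to adjust for the direct causes (parents) of the treatment, which satisfies the backdoor criterion when the graph is a DAG and all of those parents are observed. The sketch uses a plain dictionary rather than networkx so it runs on its own; with the graph above, the same lookup would be `graph.predecessors(treatment)`. The `performance_status` node and the `parent_adjustment_set` helper are hypothetical additions for illustration.

```python
from typing import Dict, List

# Parent lists (child -> direct causes). Edges among chemotherapy_dose,
# liver_function, toxicity, tumor_burden and immune_response mirror the graph
# above; "performance_status" is a hypothetical confounder added only to make
# the example non-trivial.
PARENTS: Dict[str, List[str]] = {
    "chemotherapy_dose": ["liver_function", "performance_status"],
    "tumor_burden": ["chemotherapy_dose", "performance_status"],
    "toxicity": ["chemotherapy_dose"],
    "immune_response": ["chemotherapy_dose", "immunotherapy"],
}


def parent_adjustment_set(treatment: str, outcome: str) -> List[str]:
    """Return the treatment's direct causes as a backdoor adjustment set.

    Assumes a DAG in which every parent of the treatment is observed and the
    outcome is not itself a parent of the treatment.
    """
    parents = PARENTS.get(treatment, [])
    if outcome in parents:
        raise ValueError(f"{outcome} is a cause of {treatment}, not an effect")
    return sorted(parents)


print(parent_adjustment_set("chemotherapy_dose", "tumor_burden"))
print(parent_adjustment_set("immunotherapy", "immune_response"))
```

Adjusting for parents is not always the smallest or most efficient set, but it is easy to audit, which matters when a clinician asks why a particular variable was controlled for.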
Constrained Reinforcement Learning Implementation
One of the most challenging aspects I encountered was implementing real-time policy constraints. While learning about constrained Markov decision processes (CMDPs), I observed that oncology constraints are often dynamic—they change based on patient response.
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from typing import Tuple


class ConstrainedOncologyAgent(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, constraint_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.constraint_dim = constraint_dim
        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim * 2)  # Mean and log_std for each action
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        # Constraint network - predicts constraint values for a (state, action) pair
        self.constraint_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, constraint_dim)
        )

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through policy network"""
        policy_output = self.policy_net(state)
        mean = policy_output[:, :self.action_dim]
        # Clamp log_std so exp() stays in a numerically safe range
        log_std = policy_output[:, self.action_dim:].clamp(-5.0, 2.0)
        std = torch.exp(log_std)
        return mean, std

    def get_action(self, state: torch.Tensor,
                   current_constraints: torch.Tensor,
                   n_candidates: int = 100) -> torch.Tensor:
        """Sample action with constraint satisfaction.

        Expects a single state of shape (1, state_dim) and constraint
        limits of shape (constraint_dim,).
        """
        with torch.no_grad():
            mean, std = self.forward(state)
            normal_dist = torch.distributions.Normal(mean, std)
            # Sample candidates: (n_candidates, 1, action_dim) -> (n_candidates, action_dim)
            candidate_actions = normal_dist.sample((n_candidates,)).squeeze(1)
            # Predict constraint values for all candidates in one batch
            state_repeated = state.expand(n_candidates, -1)
            state_action = torch.cat([state_repeated, candidate_actions], dim=-1)
            constraint_violations = self.constraint_net(state_action)
            # Find actions that satisfy all constraints
            constraint_satisfied = (constraint_violations <= current_constraints).all(dim=-1)
            if constraint_satisfied.any():
                valid_indices = torch.where(constraint_satisfied)[0]
                # value_net scores states only, so it cannot rank actions for a
                # fixed state; rank valid candidates by policy log-probability instead
                log_probs = normal_dist.log_prob(candidate_actions).sum(dim=-1)
                best_idx = valid_indices[torch.argmax(log_probs[valid_indices])]
            else:
                # No action satisfies all constraints - choose the one that
                # exceeds the limits by the least in total
                excess = torch.relu(constraint_violations - current_constraints)
                best_idx = torch.argmin(excess.sum(dim=-1))
            return candidate_actions[best_idx]
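class ConstrainedPPO:
    """Constraint-aware actor-critic update.

    Despite the name, this simplified version uses a plain advantage-weighted
    policy gradient rather than PPO's clipped surrogate objective.
    """

    def __init__(self, agent: ConstrainedOncologyAgent,
                 constraint_tolerance: float = 0.1):
        self.agent = agent
        self.constraint_tolerance = constraint_tolerance
        self.optimizer = optim.Adam(agent.parameters(), lr=3e-4)
        self.memory = deque(maxlen=10000)

    def update(self, states, actions, rewards, next_states,
               constraints, constraint_violations, dones):
        """Update policy with constraint-aware loss.

        constraints and constraint_violations are assumed to have shape
        (batch, constraint_dim); rewards and dones have shape (batch,).
        """
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        constraints = torch.FloatTensor(constraints)
        constraint_violations = torch.FloatTensor(constraint_violations)
        dones = torch.FloatTensor(dones)
        # Calculate TD targets and advantages (squeeze so every term is (batch,))
        with torch.no_grad():
            values = self.agent.value_net(states).squeeze(-1)
            next_values = self.agent.value_net(next_states).squeeze(-1)
            targets = rewards + 0.99 * next_values * (1 - dones)
            advantages = targets - values
        # Constraint violation penalty: total excess over the limits per sample
        constraint_penalty = torch.relu(constraint_violations - constraints).sum(dim=-1)
        penalty_weight = 10.0  # Hyperparameter tuned through experimentation
        # The penalty comes from logged data and carries no gradient, so it is
        # folded into the advantage to actually influence the policy update
        penalized_advantages = advantages - penalty_weight * constraint_penalty
        # Policy loss
        mean, std = self.agent(states)
        normal_dist = torch.distributions.Normal(mean, std)
        log_probs = normal_dist.log_prob(actions).sum(dim=-1)
        policy_loss = -(log_probs * penalized_advantages).mean()
        # Critic loss, so the value estimates used above are actually trained
        value_loss = (self.agent.value_net(states).squeeze(-1) - targets).pow(2).mean()
        # Train the constraint network to predict the observed constraint values
        predicted = self.agent.constraint_net(torch.cat([states, actions], dim=-1))
        constraint_loss = (predicted - constraint_violations).pow(2).mean()
        # Combined loss
        total_loss = policy_loss + 0.5 * value_loss + constraint_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
        self.optimizer.step()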
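The update above uses a fixed `penalty_weight`. A common alternative for constrained MDPs is to treat that weight as a Lagrange multiplier and adapt it by dual ascent: raise it while the average constraint cost exceeds its limit, and let it decay toward zero otherwise. The sketch below shows only that multiplier logic in plain Python. The function names, the learning rate, and the cost sequence are assumptions for illustration, not part of the agent above.

```python
from typing import List


def update_multiplier(lmbda: float, avg_cost: float,
                      cost_limit: float, lr: float = 0.5) -> float:
    """One dual-ascent step: grow the multiplier when the constraint is violated."""
    return max(0.0, lmbda + lr * (avg_cost - cost_limit))


def penalized_reward(reward: float, cost: float, lmbda: float) -> float:
    """Lagrangian reward the policy would be trained on."""
    return reward - lmbda * cost


def run_dual_ascent(avg_costs: List[float], cost_limit: float) -> List[float]:
    """Track the multiplier across a sequence of per-batch average costs."""
    lmbda, history = 0.0, []
    for cost in avg_costs:
        lmbda = update_multiplier(lmbda, cost, cost_limit)
        history.append(lmbda)
    return history


# Two batches over the limit, then two batches with no constraint cost at all.
print(run_dual_ascent([0.3, 0.3, 0.0, 0.0], cost_limit=0.1))
```

Compared with a hand-tuned constant, an adaptive multiplier tightens automatically when the policy drifts into unsafe territory and relaxes when it does not, at the cost of one more learning rate to choose.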
Explainability Module
During my experimentation with explanation generation, I found that clinicians respond best to explanations that mirror clinical reasoning patterns. While exploring various explanation techniques, I developed a hybrid approach:
from typing import Dict, List, Any
from dataclasses import dataclass
import networkx as nx
import numpy as np
import pandas as pd


@dataclass
class ClinicalExplanation:
    recommendation: str
    confidence: float
    primary_reasoning: List[str]
    alternative_options: List[Dict[str, Any]]
    counterfactual_scenarios: List[Dict[str, Any]]
    evidence_level: str  # "strong", "moderate", "weak"
    uncertainty_acknowledgement: str
class CausalExplanationGenerator:
    # Helper methods referenced in this class but not shown in this listing:
    # _load_templates, _describe_pathway, _identify_adjustment_set,
    # _estimate_counterfactual_outcome, _assess_evidence, _compose_explanation,
    # _format_recommendation, _calculate_confidence, _acknowledge_uncertainties

    def __init__(self, causal_graph: OncologyCausalGraph,
                 treatment_history: pd.DataFrame):
        self.causal_graph = causal_graph
        self.treatment_history = treatment_history
        self.explanation_templates = self._load_templates()

    def generate_explanation(self, current_state: Dict[str, float],
                             recommended_action: Dict[str, float],
                             alternative_actions: List[Dict[str, float]]) -> ClinicalExplanation:
        """Generate clinical explanation for treatment recommendation"""
        # 1. Identify primary causal pathways
        primary_pathways = self._identify_causal_pathways(
            current_state, recommended_action
        )
        # 2. Estimate effect sizes
        effect_sizes = self._estimate_effects(
            current_state, recommended_action, primary_pathways
        )
        # 3. Generate counterfactuals
        counterfactuals = self._generate_counterfactuals(
            current_state, recommended_action, alternative_actions
        )
        # 4. Assess evidence strength
        evidence_strength = self._assess_evidence(
            primary_pathways, effect_sizes
        )
        # 5. Compose explanation (computed here but not attached to the
        # returned dataclass in this simplified version)
        explanation_text = self._compose_explanation(
            primary_pathways, effect_sizes, counterfactuals, evidence_strength
        )
        return ClinicalExplanation(
            recommendation=self._format_recommendation(recommended_action),
            confidence=self._calculate_confidence(effect_sizes),
            primary_reasoning=primary_pathways[:3],  # First 3 pathways
            alternative_options=alternative_actions[:2],  # First 2 alternatives
            counterfactual_scenarios=counterfactuals,
            evidence_level=evidence_strength,
            uncertainty_acknowledgement=self._acknowledge_uncertainties(
                effect_sizes, counterfactuals
            )
        )
    def _identify_causal_pathways(self, state: Dict, action: Dict) -> List[str]:
        """Identify the most relevant causal pathways for this decision"""
        pathways = []
        # Analyze treatment -> outcome pathways
        for treatment, dose in action.items():
            if dose > 0:  # Active treatment
                # Find outcomes affected by this treatment
                for outcome in self.causal_graph.graph.nodes:
                    if outcome.startswith("progression") or outcome == "toxicity":
                        # Check if there's a causal path
                        try:
                            paths = list(nx.all_simple_paths(
                                self.causal_graph.graph,
                                source=treatment,
                                target=outcome,
                                cutoff=3  # Maximum path length
                            ))
                        except nx.NodeNotFound:
                            # The action names a treatment that is not in the graph
                            continue
                        # all_simple_paths returns paths in no particular order,
                        # so sort by length before taking the 2 shortest
                        for path in sorted(paths, key=len)[:2]:
                            pathway_desc = self._describe_pathway(path)
                            pathways.append(pathway_desc)
        # Placeholder ranking: longest descriptions first. A real system would
        # rank by estimated pathway strength instead.
        return sorted(pathways, key=len, reverse=True)[:5]
    def _estimate_effects(self, state: Dict, action: Dict,
                          pathways: List[str]) -> Dict[str, float]:
        """Estimate effect sizes for each causal pathway"""
        effects = {}
        # Simplified effect estimation
        # In practice, this would use the causal model's do-calculus
        for pathway in pathways:
            # Extract treatment and outcome from the pathway description
            # (assumes _describe_pathway joins node names with " → ")
            components = pathway.split(" → ")
            if len(components) >= 2:
                treatment = components[0]
                outcome = components[-1]
                # Use the causal graph's effect estimation
                adjustment_set = self._identify_adjustment_set(treatment, outcome)
                effect_result = self.causal_graph.estimate_causal_effect(
                    treatment, outcome, adjustment_set
                )
                effects[pathway] = effect_result["causal_effect"]
        return effects
    def _generate_counterfactuals(self, state: Dict,
                                  chosen_action: Dict,
                                  alternatives: List[Dict]) -> List[Dict]:
        """Generate counterfactual scenarios for alternative treatments"""
        counterfactuals = []
        for alt_action in alternatives[:3]:  # Consider the first 3 alternatives
            scenario = {
                "alternative_action": alt_action,
                "expected_outcomes": {},
                "comparison_to_chosen": {},
                "risk_assessment": {}  # Left empty in this simplified version
            }
            # Estimate outcomes for alternative (simplified)
            for outcome in ["progression_free_survival", "toxicity"]:
                # This would use the causal model for proper counterfactual inference
                alt_outcome = self._estimate_counterfactual_outcome(
                    state, alt_action, outcome
                )
                chosen_outcome = self._estimate_counterfactual_outcome(
                    state, chosen_action, outcome
                )
                scenario["expected_outcomes"][outcome] = alt_outcome
                # Positive values mean the alternative scores higher than the chosen action
                scenario["comparison_to_chosen"][outcome] = (
                    alt_outcome - chosen_outcome
                )
            counterfactuals.append(scenario)
        return counterfactuals
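The listing leaves `_estimate_counterfactual_outcome` undefined. As a hedged illustration of what proper counterfactual inference involves, here is Pearl's three-step recipe (abduction, action, prediction) for a single linear structural equation. The linear form, the coefficients, and the function name `linear_counterfactual` are assumptions made for this sketch; a real oncology model would be nonlinear and multivariate.

```python
def linear_counterfactual(observed_dose: float, observed_outcome: float,
                          alt_dose: float, slope: float, intercept: float) -> float:
    """Counterfactual outcome under the assumed model
    outcome = intercept + slope * dose + noise.
    """
    # 1. Abduction: recover this patient's noise term from what was observed.
    noise = observed_outcome - (intercept + slope * observed_dose)
    # 2. Action: replace the dose with the alternative, do(dose = alt_dose).
    # 3. Prediction: re-evaluate the equation, keeping the same noise term.
    return intercept + slope * alt_dose + noise


# Hypothetical numbers: tumor burden of 30 was observed at dose 50 under an
# assumed model with intercept 60 and slope -0.4.
cf = linear_counterfactual(observed_dose=50.0, observed_outcome=30.0,
                           alt_dose=70.0, slope=-0.4, intercept=60.0)
print(f"counterfactual tumor burden at dose 70: {cf:.1f}")
```

Keeping the patient-specific noise term is what separates a counterfactual from a population-level prediction: the answer is about this patient, who responded better or worse than the model's average, not about a generic one.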