Human-Aligned Decision Transformers for Coastal Climate Resilience Planning Under Extreme Data Sparsity
Introduction: The Data Desert Dilemma
I remember the exact moment the problem crystallized for me. I was sitting in a coastal community center in Southeast Asia, laptop open, trying to help local planners model storm surge impacts. They had decades of handwritten tide gauge records, a few scattered satellite images, and profound indigenous knowledge about seasonal patterns—but virtually no structured, machine-readable data. My standard deep learning models, hungry for gigabytes of labeled examples, were useless. This experience, repeated across vulnerable coastal regions from the Mekong Delta to small island states, sparked my multi-year investigation into how we can make AI work when data is desperately scarce.
Through my research into reinforcement learning and sequential decision-making, I discovered that traditional approaches to climate resilience planning fail spectacularly in data-sparse environments. Most coastal communities facing the brunt of climate change have limited monitoring infrastructure, inconsistent historical records, and resources for only occasional measurements. Yet, they must make billion-dollar decisions about seawalls, managed retreat, ecosystem restoration, and early warning systems.
My exploration led me to an unexpected convergence: Decision Transformers (originally developed for offline RL) combined with human-alignment techniques could potentially solve this extreme data sparsity problem. What began as theoretical curiosity transformed into a practical framework that I've since tested in simulation environments and am now working to deploy in actual coastal planning contexts.
Technical Background: Bridging Two Worlds
The Core Challenge of Extreme Data Sparsity
In my investigation of coastal climate data, I found that sparsity manifests in three dimensions:
- Temporal sparsity: Measurements taken irregularly (monthly or seasonally rather than continuously)
- Spatial sparsity: Limited sensor coverage across vast coastal areas
- Feature sparsity: Incomplete observations of relevant variables (e.g., having wave height but not direction)
Traditional approaches like Gaussian processes or spatial interpolation fail when the underlying dynamics are non-stationary—which climate systems decidedly are. During my experimentation with various imputation methods, I realized we needed a paradigm shift: instead of trying to fill missing data, we should build models that explicitly reason about uncertainty from sparse observations.
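To make that shift concrete, here is a minimal sketch (the variable names and toy numbers are my own, purely for illustration) of representing a sparse record as a value-plus-presence pair instead of interpolating the gaps:

```python
import numpy as np

# Hypothetical monthly tide-gauge record: NaN marks months with no reading
readings = np.array([1.2, np.nan, np.nan, 1.5, np.nan, 1.9])

# Instead of imputing the gaps, pair the values with an explicit presence mask
presence = ~np.isnan(readings)            # True where a measurement exists
values = np.where(presence, readings, 0)  # zero-fill only as a neutral placeholder

# A downstream model sees both arrays, so a filled 0.0 is never
# mistaken for an actual zero-metre reading
```

The model can then condition on `presence` directly, which is exactly the role the presence embedding plays in the encoder introduced below.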
Decision Transformers: A Sequential Decision-Making Revolution
While studying the evolution of reinforcement learning, I came across Decision Transformers—a fascinating architecture that frames sequential decision-making as conditional sequence modeling. Unlike traditional RL that learns a policy through reward maximization, Decision Transformers learn to generate actions conditioned on desired returns (rewards-to-go) and past states.
The key insight I gained from implementing Decision Transformers was their suitability for offline learning. They can leverage historical decision trajectories without requiring online interaction with the environment—perfect for our coastal planning scenario where we cannot "experiment" with real ecosystems.
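The returns-to-go conditioning is straightforward to compute from a logged trajectory: at each step it is simply the sum of all future rewards. A minimal sketch:

```python
def returns_to_go(rewards):
    """Suffix sums of a reward sequence: rtg[t] = sum(rewards[t:])."""
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]

# A toy 4-step trajectory of planning rewards
returns_to_go([1.0, 0.0, 2.0, 1.0])  # -> [4.0, 3.0, 3.0, 1.0]
```

At inference time, we instead *set* the initial return-to-go to a desired target outcome and let the model generate actions consistent with achieving it.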
```python
import torch
import torch.nn as nn
import numpy as np


class DecisionTransformerBlock(nn.Module):
    """A single transformer block for decision modeling."""

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Self-attention with residual connection
        attn_out, _ = self.attention(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Feed-forward with residual connection
        mlp_out = self.mlp(x)
        x = self.norm2(x + self.dropout(mlp_out))
        return x


class SparseObservationEncoder(nn.Module):
    """Encodes sparse observations with uncertainty quantification."""

    def __init__(self, obs_dim, hidden_dim, missing_token=-1):
        super().__init__()
        self.obs_embedding = nn.Linear(obs_dim, hidden_dim)
        self.missing_token = missing_token
        self.presence_embedding = nn.Embedding(2, hidden_dim)  # 0: missing, 1: present
        self.uncertainty_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, observations):
        # observations shape: (batch, seq_len, obs_dim)
        batch_size, seq_len, obs_dim = observations.shape
        # Presence mask: a timestep counts as present if any feature was observed
        presence_mask = (observations != self.missing_token).any(dim=-1).float()
        # Replace missing values with zeros before embedding
        obs_filled = observations.clone()
        obs_filled[obs_filled == self.missing_token] = 0
        # Embed observations
        obs_embedded = self.obs_embedding(obs_filled)
        # Add presence information
        presence_embedded = self.presence_embedding(presence_mask.long())
        combined = obs_embedded + presence_embedded
        # Estimate uncertainty; force it to zero where data is actually present
        uncertainty = self.uncertainty_net(combined)
        uncertainty = uncertainty * (1 - presence_mask.unsqueeze(-1))
        return combined, uncertainty, presence_mask
```
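One subtlety worth checking in isolation: because the presence mask uses `.any(dim=-1)`, a timestep with even one observed feature counts as present. A quick standalone check of that masking logic (with the same `-1` missing-token convention):

```python
import torch

# One sequence of 3 timesteps with 2 features; -1 marks missing entries
obs = torch.tensor([[[0.8, 0.3], [-1.0, -1.0], [0.5, -1.0]]])

presence = (obs != -1).any(dim=-1).float()  # partially observed step still "present"
filled = obs.clone()
filled[filled == -1] = 0
```

Note that timestep 2, with only wave height observed, is marked present; the per-feature gaps survive only as the zero-fill. A per-feature presence mask would be a natural extension.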
Human Alignment: Beyond Reward Maximization
One of the most profound realizations from my research was that pure reward maximization fails catastrophically in climate planning. Coastal communities have complex, often conflicting objectives: economic development, cultural preservation, ecological sustainability, and intergenerational equity. A model optimizing for a single metric (like economic cost) would make disastrous recommendations.
Through studying human-in-the-loop systems and value learning, I developed an approach that aligns Decision Transformers with human preferences and multiple value systems:
```python
class MultiObjectiveAlignmentLayer(nn.Module):
    """Aligns model outputs with multiple human value systems."""

    def __init__(self, hidden_dim, num_objectives):
        super().__init__()
        self.num_objectives = num_objectives
        # Value heads for different objectives
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(hidden_dim // 2, 1)
            ) for _ in range(num_objectives)
        ])
        # Preference network learns to balance objectives
        self.preference_network = nn.Sequential(
            nn.Linear(num_objectives + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_objectives),
            nn.Softmax(dim=-1)
        )

    def forward(self, hidden_states, human_preferences=None):
        batch_size, seq_len, _ = hidden_states.shape
        # Compute values for each objective
        objective_values = []
        for head in self.value_heads:
            values = head(hidden_states)  # (batch, seq_len, 1)
            objective_values.append(values)
        objective_tensor = torch.cat(objective_values, dim=-1)  # (batch, seq_len, num_objectives)
        # If human preferences are provided, use them; otherwise infer from context
        if human_preferences is not None:
            preferences = human_preferences
        else:
            # Infer preferences from the context
            context_features = hidden_states.mean(dim=1)  # (batch, hidden_dim)
            preference_input = torch.cat([
                objective_tensor.mean(dim=1),  # Average objective values
                context_features
            ], dim=-1)
            preferences = self.preference_network(preference_input)  # (batch, num_objectives)
        # Compute aligned value (weighted sum according to preferences)
        aligned_value = torch.sum(objective_tensor * preferences.unsqueeze(1), dim=-1)
        return aligned_value, preferences, objective_tensor
```
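The weighted sum at the end is ordinary preference scalarization, and it can be sanity-checked in isolation. A standalone sketch with three objectives (the economic/ecological/cultural labels and numbers are illustrative, not from any real deployment):

```python
import torch

# Per-objective values for one candidate action: [economic, ecological, cultural]
objective_values = torch.tensor([[0.9, 0.2, 0.5]])

# A preference distribution over objectives (softmax guarantees it sums to 1)
preferences = torch.softmax(torch.tensor([[2.0, 0.0, 1.0]]), dim=-1)

aligned = (objective_values * preferences).sum(dim=-1)
# A community weighting ecology highest would instead up-weight the 0.2 term,
# pulling the aligned value down for this economically attractive action
```

Because the preferences form a proper distribution, the aligned value is always a convex combination of the per-objective values, which keeps it interpretable to stakeholders.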
Implementation Details: Building for Data Scarcity
Architecture for Sparse Coastal Data
My experimentation led to a specialized architecture that handles the unique challenges of coastal climate data:
```python
class CoastalDecisionTransformer(nn.Module):
    """Decision Transformer adapted for sparse coastal climate data."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Encoders for different data modalities
        self.observation_encoder = SparseObservationEncoder(
            obs_dim=config.obs_dim,
            hidden_dim=config.hidden_dim
        )
        # Action and return embeddings
        self.action_embedding = nn.Linear(config.action_dim, config.hidden_dim)
        self.return_embedding = nn.Linear(1, config.hidden_dim)
        # Temporal encoding (critical for irregular time series);
        # TemporalEncoder is a learned timestep-embedding module defined separately
        self.temporal_encoder = TemporalEncoder(
            hidden_dim=config.hidden_dim,
            max_timesteps=config.max_timesteps
        )
        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            DecisionTransformerBlock(
                hidden_dim=config.hidden_dim,
                num_heads=config.num_heads,
                dropout=config.dropout
            ) for _ in range(config.num_layers)
        ])
        # Alignment layer for human values
        self.alignment_layer = MultiObjectiveAlignmentLayer(
            hidden_dim=config.hidden_dim,
            num_objectives=config.num_objectives
        )
        # Output heads
        self.action_head = nn.Linear(config.hidden_dim, config.action_dim)
        self.uncertainty_head = nn.Linear(config.hidden_dim, config.action_dim)
        # Causal mask for autoregressive generation (True = blocked position,
        # the convention nn.MultiheadAttention expects for boolean masks)
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool),
                diagonal=1
            )
        )

    def forward(self, observations, actions, returns_to_go, timesteps, attention_mask=None):
        batch_size, seq_len = observations.shape[:2]
        # Encode sparse observations with uncertainty
        obs_embedded, obs_uncertainty, presence_mask = self.observation_encoder(observations)
        # Embed actions and returns
        action_embedded = self.action_embedding(actions)
        return_embedded = self.return_embedding(returns_to_go.unsqueeze(-1))
        # Add temporal encoding
        temporal_encoding = self.temporal_encoder(timesteps)
        # Combine embeddings (sequence: [return, observation, action] per timestep)
        token_embeddings = torch.zeros(
            batch_size, seq_len * 3, self.config.hidden_dim,
            device=observations.device
        )
        # Arrange tokens in the Decision Transformer pattern
        token_embeddings[:, 0::3] = return_embedded + temporal_encoding
        token_embeddings[:, 1::3] = obs_embedded + temporal_encoding
        token_embeddings[:, 2::3] = action_embedded + temporal_encoding
        # Crop the causal mask to the current token sequence length
        causal_mask = self.causal_mask[:seq_len * 3, :seq_len * 3]
        if attention_mask is not None:
            # Expand the per-timestep padding mask to the 3-token pattern
            padding = attention_mask.repeat_interleave(3, dim=1) == 0  # True = padded
            combined_mask = causal_mask.unsqueeze(0) | padding.unsqueeze(1)
            # nn.MultiheadAttention takes 3-D masks as (batch * num_heads, L, L)
            combined_mask = combined_mask.repeat_interleave(self.config.num_heads, dim=0)
            # Keep the diagonal open so fully padded query rows still attend somewhere
            idx = torch.arange(seq_len * 3, device=observations.device)
            combined_mask[:, idx, idx] = False
        else:
            combined_mask = causal_mask
        # Pass through transformer blocks
        x = token_embeddings
        for block in self.transformer_blocks:
            x = block(x, attn_mask=combined_mask)
        # Extract action predictions (every third token is an action position)
        x_actions = x[:, 2::3]
        # Predict actions and uncertainties
        action_preds = self.action_head(x_actions)
        uncertainty_preds = torch.sigmoid(self.uncertainty_head(x_actions))
        # Apply alignment to get human-aligned values
        aligned_values, preferences, objective_values = self.alignment_layer(x_actions)
        return {
            'action_preds': action_preds,
            'uncertainty_preds': uncertainty_preds,
            'aligned_values': aligned_values,
            'preferences': preferences,
            'objective_values': objective_values,
            'obs_uncertainty': obs_uncertainty,
            'presence_mask': presence_mask
        }
```
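The `[return, observation, action]` interleaving is easy to get wrong, so it helps to verify the strided-assignment pattern in isolation. The toy tensors below (my own, with constant values so positions are distinguishable) mirror the indexing used in the model:

```python
import torch

seq_len, hidden = 4, 8
r = torch.zeros(1, seq_len, hidden)       # stand-in return embeddings
o = torch.ones(1, seq_len, hidden)        # stand-in observation embeddings
a = torch.full((1, seq_len, hidden), 2.)  # stand-in action embeddings

tokens = torch.zeros(1, seq_len * 3, hidden)
tokens[:, 0::3], tokens[:, 1::3], tokens[:, 2::3] = r, o, a

# Token order per timestep is R, O, A, so flat position 3t+2 is the
# action token of timestep t — e.g. position 5 belongs to timestep 1
```

The same arithmetic explains why action predictions are read back with `x[:, 2::3]`: the model's output at each action position is trained to predict that timestep's action.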
Training with Sparse, Noisy Data
The training methodology was perhaps the most challenging aspect of my research. Standard supervised learning fails when 80-90% of your "labels" (optimal actions) are unknown. Through experimentation, I developed a multi-stage training approach:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseDataTraining:
    """Training procedure for extreme data sparsity scenarios."""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        # Multiple loss functions for different objectives
        self.action_loss_fn = nn.MSELoss(reduction='none')
        self.value_loss_fn = nn.HuberLoss()
        self.uncertainty_loss_fn = self._create_uncertainty_loss()

    def _create_uncertainty_loss(self):
        """Creates a loss that encourages higher uncertainty for missing data."""
        def uncertainty_loss(pred_uncertainty, presence_mask, obs_uncertainty):
            # Predicted uncertainty should be high where data is missing
            target_uncertainty = (1 - presence_mask.float()).unsqueeze(-1)
            target_uncertainty = target_uncertainty.expand_as(pred_uncertainty)
            # Weight by the encoder's own observation uncertainty
            weight = 1 + obs_uncertainty.squeeze(-1)
            loss = F.mse_loss(pred_uncertainty, target_uncertainty, reduction='none')
            weighted_loss = loss * weight.unsqueeze(-1)
            return weighted_loss.mean()
        return uncertainty_loss

    def train_step(self, batch, human_feedback=None):
        observations, actions, returns, timesteps, masks = batch
        # Forward pass
        outputs = self.model(observations, actions, returns, timesteps)
        # Compute losses
        total_loss = 0
        loss_dict = {}
        # 1. Action prediction loss (only where actions are known)
        action_mask = masks['action_mask']
        if action_mask.sum() > 0:  # Only compute if we have some action labels
            action_loss = self.action_loss_fn(outputs['action_preds'], actions)
            action_loss = (action_loss * action_mask.unsqueeze(-1)).sum() / action_mask.sum()
            total_loss += self.config.action_weight * action_loss
            loss_dict['action_loss'] = action_loss.item()
        # 2. Uncertainty calibration loss
        uncertainty_loss = self.uncertainty_loss_fn(
            outputs['uncertainty_preds'],
            outputs['presence_mask'],
            outputs['obs_uncertainty']
        )
        total_loss += self.config.uncertainty_weight * uncertainty_loss
        loss_dict['uncertainty_loss'] = uncertainty_loss.item()
        # 3. Human alignment loss (if feedback provided)
        if human_feedback is not None:
            alignment_loss = self._compute_alignment_loss(
                outputs['aligned_values'],
                outputs['preferences'],
                human_feedback
            )
            total_loss += self.config.alignment_weight * alignment_loss
            loss_dict['alignment_loss'] = alignment_loss.item()
        # 4. Consistency loss (encourage similar predictions for similar contexts)
        consistency_loss = self._compute_consistency_loss(outputs, batch)
        total_loss += self.config.consistency_weight * consistency_loss
        loss_dict['consistency_loss'] = consistency_loss.item()
        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        self.optimizer.step()
        return total_loss.item(), loss_dict

    def _compute_alignment_loss(self, aligned_values, preferences, human_feedback):
        """Computes a loss from human feedback, which can be:
        1. Preference rankings between trajectories
        2. Direct ratings of actions
        3. Corrections to model predictions
        """
        if 'preference_ranking' in human_feedback:
            # Bradley-Terry model for preference learning
            rankings = human_feedback['preference_ranking']
            loss = self._preference_loss(aligned_values, rankings)
        elif 'action_ratings' in human_feedback:
            # Direct supervision on action quality
            ratings = human_feedback['action_ratings']
            loss = F.mse_loss(aligned_values, ratings)
        else:
            # Default: encourage diverse consideration of objectives by
            # maximizing the entropy of the preference distribution
            entropy = -(preferences * torch.log(preferences + 1e-8)).sum(dim=-1)
            loss = -entropy.mean()
        return loss

    def _compute_consistency_loss(self, outputs, batch):
        """Encourages consistent predictions for similar states."""
        # This is critical for sparse data - it helps generalize from limited examples
        observations, actions, returns, timesteps, _ = batch
        # Perturb only observed entries so the -1 missing-token sentinel stays intact
        present = (observations != -1).float()
        noise = torch.randn_like(observations) * 0.01 * present
        perturbed_obs = torch.clamp(observations + noise, -10, 10)
        # Get predictions for the perturbed observations and penalize divergence
        perturbed_outputs = self.model(perturbed_obs, actions, returns, timesteps)
        return F.mse_loss(
            perturbed_outputs['action_preds'],
            outputs['action_preds'].detach()
        )
```
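The `_preference_loss` called above is referenced but not shown in the class. Under a Bradley-Terry model, a pairwise preference loss would look roughly like the following sketch (my own minimal version, assuming feedback arrives as pairs of trajectory values where the first is human-preferred):

```python
import torch
import torch.nn.functional as F

def bradley_terry_loss(value_preferred, value_rejected):
    """P(preferred beats rejected) = sigmoid(v_p - v_r); minimize its negative log."""
    return -F.logsigmoid(value_preferred - value_rejected).mean()

# Aligned values for two feedback pairs; the first entry of each pair
# is the trajectory the community panel preferred
v_pref = torch.tensor([1.5, 0.8])
v_rej = torch.tensor([0.2, 0.9])
loss = bradley_terry_loss(v_pref, v_rej)
```

The loss shrinks as the model assigns higher aligned value to the preferred trajectory of each pair, which is exactly the signal we want sparse human feedback to provide.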