DEV Community

Rikin Patel
Human-Aligned Decision Transformers for Coastal Climate Resilience Planning Under Extreme Data Sparsity

Introduction: The Data Desert Dilemma

I remember the exact moment the problem crystallized for me. I was sitting in a coastal community center in Southeast Asia, laptop open, trying to help local planners model storm surge impacts. They had decades of handwritten tide gauge records, a few scattered satellite images, and profound indigenous knowledge about seasonal patterns—but virtually no structured, machine-readable data. My standard deep learning models, hungry for gigabytes of labeled examples, were useless. This experience, repeated across vulnerable coastal regions from the Mekong Delta to small island states, sparked my multi-year investigation into how we can make AI work when data is desperately scarce.

Through my research into reinforcement learning and sequential decision-making, I discovered that traditional approaches to climate resilience planning fail spectacularly in data-sparse environments. Most coastal communities facing the brunt of climate change have limited monitoring infrastructure, inconsistent historical records, and resources for only occasional measurements. Yet, they must make billion-dollar decisions about seawalls, managed retreat, ecosystem restoration, and early warning systems.

My exploration led me to an unexpected convergence: Decision Transformers (originally developed for offline RL) combined with human-alignment techniques could potentially solve this extreme data sparsity problem. What began as theoretical curiosity transformed into a practical framework that I've since tested in simulation environments and am now working to deploy in actual coastal planning contexts.

Technical Background: Bridging Two Worlds

The Core Challenge of Extreme Data Sparsity

In my investigation of coastal climate data, I found that sparsity manifests in three dimensions:

  1. Temporal sparsity: Measurements taken irregularly (monthly or seasonally rather than continuously)
  2. Spatial sparsity: Limited sensor coverage across vast coastal areas
  3. Feature sparsity: Incomplete observations of relevant variables (e.g., having wave height but not direction)

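As a concrete illustration of these three dimensions, here is a minimal sketch of a sparse coastal observation matrix (the readings and the `MISSING` sentinel are illustrative, not from a real gauge):

```python
import numpy as np

MISSING = -1.0  # sentinel for absent readings (illustrative)

# 4 monthly records x 3 features: tide level, wave height, wave direction
obs = np.array([
    [1.1, 0.8, MISSING],          # feature sparsity: direction never measured
    [MISSING, MISSING, MISSING],  # temporal sparsity: no site visit this month
    [1.3, MISSING, MISSING],
    [0.9, 0.7, MISSING],
])

presence = obs != MISSING  # boolean mask of observed entries
coverage = presence.mean() # fraction of the grid actually observed
print(coverage)            # 5 of 12 entries ≈ 0.417
```

Even this tiny example has under half its entries observed; real multi-decade records from low-resource coastlines are often far sparser.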
Traditional approaches like Gaussian processes or spatial interpolation fail when the underlying dynamics are non-stationary—which climate systems decidedly are. During my experimentation with various imputation methods, I realized we needed a paradigm shift: instead of trying to fill missing data, we should build models that explicitly reason about uncertainty from sparse observations.

Decision Transformers: A Sequential Decision-Making Revolution

While studying the evolution of reinforcement learning, I came across Decision Transformers—a fascinating architecture that frames sequential decision-making as conditional sequence modeling. Unlike traditional RL that learns a policy through reward maximization, Decision Transformers learn to generate actions conditioned on desired returns (rewards-to-go) and past states.

The key insight I gained from implementing Decision Transformers was their suitability for offline learning. They can leverage historical decision trajectories without requiring online interaction with the environment—perfect for our coastal planning scenario where we cannot "experiment" with real ecosystems.
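To make the return-conditioning concrete, here is a minimal sketch of how rewards-to-go are computed from one historical trajectory (the reward values are made up for illustration):

```python
# Rewards observed along one historical planning trajectory (illustrative)
rewards = [0.0, -2.0, 5.0, 1.0]

# Returns-to-go: at each step, the sum of all rewards from that step onward.
# A Decision Transformer conditions each generated action on this quantity,
# so at inference time we can simply ask for a high target return.
returns_to_go = [sum(rewards[t:]) for t in range(len(rewards))]
print(returns_to_go)  # [4.0, 4.0, 6.0, 1.0]
```

This is the quantity fed to the `return_embedding` in the architecture below.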

import torch
import torch.nn as nn
import numpy as np

class DecisionTransformerBlock(nn.Module):
    """A single transformer block for decision modeling"""
    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Self-attention with residual connection
        attn_out, _ = self.attention(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward with residual connection
        mlp_out = self.mlp(x)
        x = self.norm2(x + self.dropout(mlp_out))
        return x

class SparseObservationEncoder(nn.Module):
    """Encodes sparse observations with uncertainty quantification"""
    def __init__(self, obs_dim, hidden_dim, missing_token=-1):
        super().__init__()
        self.obs_embedding = nn.Linear(obs_dim, hidden_dim)
        self.missing_token = missing_token
        self.presence_embedding = nn.Embedding(2, hidden_dim)  # 0: missing, 1: present
        self.uncertainty_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, observations):
        # observations shape: (batch, seq_len, obs_dim)
        batch_size, seq_len, obs_dim = observations.shape

        # Create presence mask (1 where data exists, 0 where missing)
        presence_mask = (observations != self.missing_token).any(dim=-1).float()

        # Replace missing values with zeros for embedding
        obs_filled = observations.clone()
        obs_filled[obs_filled == self.missing_token] = 0

        # Embed observations
        obs_embedded = self.obs_embedding(obs_filled)

        # Add presence information
        presence_embedded = self.presence_embedding(presence_mask.long())
        combined = obs_embedded + presence_embedded

        # Estimate uncertainty (higher for missing observations)
        uncertainty = self.uncertainty_net(combined)
        uncertainty = uncertainty * (1 - presence_mask.unsqueeze(-1))  # Zero uncertainty for present obs

        return combined, uncertainty, presence_mask

Human Alignment: Beyond Reward Maximization

One of the most profound realizations from my research was that pure reward maximization fails catastrophically in climate planning. Coastal communities have complex, often conflicting objectives: economic development, cultural preservation, ecological sustainability, and intergenerational equity. A model optimizing for a single metric (like economic cost) would make disastrous recommendations.

Through studying human-in-the-loop systems and value learning, I developed an approach that aligns Decision Transformers with human preferences and multiple value systems:

class MultiObjectiveAlignmentLayer(nn.Module):
    """Aligns model outputs with multiple human value systems"""
    def __init__(self, hidden_dim, num_objectives):
        super().__init__()
        self.num_objectives = num_objectives

        # Value heads for different objectives
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(hidden_dim // 2, 1)
            ) for _ in range(num_objectives)
        ])

        # Preference network learns to balance objectives
        self.preference_network = nn.Sequential(
            nn.Linear(num_objectives + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_objectives),
            nn.Softmax(dim=-1)
        )

    def forward(self, hidden_states, human_preferences=None):
        batch_size, seq_len, _ = hidden_states.shape

        # Compute values for each objective
        objective_values = []
        for head in self.value_heads:
            values = head(hidden_states)  # (batch, seq_len, 1)
            objective_values.append(values)

        objective_tensor = torch.cat(objective_values, dim=-1)  # (batch, seq_len, num_objectives)

        # If human preferences provided, use them; otherwise infer from context
        if human_preferences is not None:
            preferences = human_preferences
        else:
            # Infer preferences from the context
            context_features = hidden_states.mean(dim=1)  # (batch, hidden_dim)
            preference_input = torch.cat([
                objective_tensor.mean(dim=1),  # Average objective values
                context_features
            ], dim=-1)
            preferences = self.preference_network(preference_input)  # (batch, num_objectives)

        # Compute aligned value (weighted sum according to preferences)
        aligned_value = torch.sum(objective_tensor * preferences.unsqueeze(1), dim=-1)

        return aligned_value, preferences, objective_tensor
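The core of the alignment layer is a simple preference-weighted scalarization. With toy numbers (the objectives and preference logits below are illustrative, not learned values):

```python
import torch

# Per-objective values for one state: economy, ecology, equity (toy numbers)
objective_values = torch.tensor([[0.8, -0.2, 0.5]])

# Preference logits, e.g. elicited from planners, normalized by softmax
preferences = torch.softmax(torch.tensor([[1.0, 2.0, 0.5]]), dim=-1)

# Weighted sum: the same aggregation MultiObjectiveAlignmentLayer applies
aligned_value = (objective_values * preferences).sum(dim=-1)
print(aligned_value)  # ≈ 0.129: ecology's negative value drags the score down
```

Because the weights come from a softmax, they always sum to one, so the aligned value stays on the same scale as the individual objectives.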

Implementation Details: Building for Data Scarcity

Architecture for Sparse Coastal Data

My experimentation led to a specialized architecture that handles the unique challenges of coastal climate data:

class CoastalDecisionTransformer(nn.Module):
    """Decision Transformer adapted for sparse coastal climate data"""
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoders for different data modalities
        self.observation_encoder = SparseObservationEncoder(
            obs_dim=config.obs_dim,
            hidden_dim=config.hidden_dim
        )

        # Action and return embeddings
        self.action_embedding = nn.Linear(config.action_dim, config.hidden_dim)
        self.return_embedding = nn.Linear(1, config.hidden_dim)

        # Temporal encoding (critical for irregular time series)
        self.temporal_encoder = TemporalEncoder(
            hidden_dim=config.hidden_dim,
            max_timesteps=config.max_timesteps
        )

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            DecisionTransformerBlock(
                hidden_dim=config.hidden_dim,
                num_heads=config.num_heads,
                dropout=config.dropout
            ) for _ in range(config.num_layers)
        ])

        # Alignment layer for human values
        self.alignment_layer = MultiObjectiveAlignmentLayer(
            hidden_dim=config.hidden_dim,
            num_objectives=config.num_objectives
        )

        # Output heads
        self.action_head = nn.Linear(config.hidden_dim, config.action_dim)
        self.uncertainty_head = nn.Linear(config.hidden_dim, config.action_dim)

        # Causal mask for autoregressive generation.
        # nn.MultiheadAttention expects a boolean mask where True means
        # "do not attend", so we mask everything above the diagonal.
        causal = torch.triu(
            torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool),
            diagonal=1
        )
        self.register_buffer("causal_mask", causal)

    def forward(self, observations, actions, returns_to_go, timesteps, attention_mask=None):
        batch_size, seq_len = observations.shape[:2]

        # Encode sparse observations with uncertainty
        obs_embedded, obs_uncertainty, presence_mask = self.observation_encoder(observations)

        # Embed actions and returns
        action_embedded = self.action_embedding(actions)
        return_embedded = self.return_embedding(returns_to_go.unsqueeze(-1))

        # Add temporal encoding
        temporal_encoding = self.temporal_encoder(timesteps)

        # Combine embeddings (sequence: [return, observation, action] for each timestep)
        token_embeddings = torch.zeros(
            batch_size, seq_len * 3, self.config.hidden_dim,
            device=observations.device
        )

        # Arrange tokens in the Decision Transformer pattern
        token_embeddings[:, 0::3] = return_embedded + temporal_encoding
        token_embeddings[:, 1::3] = obs_embedded + temporal_encoding
        token_embeddings[:, 2::3] = action_embedded + temporal_encoding

        # Slice the causal mask to the token sequence length (True = masked)
        causal_mask = self.causal_mask[:seq_len*3, :seq_len*3]

        # Fold in the padding mask if provided
        if attention_mask is not None:
            # Expand the per-timestep mask to the 3-token pattern (1 = valid)
            pad = attention_mask.repeat_interleave(3, dim=1)
            key_padding = (pad == 0)  # True = padded key position
            combined_mask = causal_mask.unsqueeze(0) | key_padding.unsqueeze(1)
            # nn.MultiheadAttention expects (batch * num_heads, L, L) for 3-D masks
            combined_mask = combined_mask.repeat_interleave(self.config.num_heads, dim=0)
        else:
            combined_mask = causal_mask

        # Pass through transformer blocks
        x = token_embeddings
        for block in self.transformer_blocks:
            x = block(x, attn_mask=combined_mask)

        # Extract action predictions (only at action positions)
        x_actions = x[:, 2::3]  # Take every third token (action positions)

        # Predict actions and uncertainties
        action_preds = self.action_head(x_actions)
        uncertainty_preds = torch.sigmoid(self.uncertainty_head(x_actions))

        # Apply alignment to get human-aligned values
        aligned_values, preferences, objective_values = self.alignment_layer(x_actions)

        return {
            'action_preds': action_preds,
            'uncertainty_preds': uncertainty_preds,
            'aligned_values': aligned_values,
            'preferences': preferences,
            'objective_values': objective_values,
            'obs_uncertainty': obs_uncertainty,
            'presence_mask': presence_mask
        }

Training with Sparse, Noisy Data

The training methodology was perhaps the most challenging aspect of my research. Standard supervised learning fails when 80-90% of your "labels" (optimal actions) are unknown. Through experimentation, I developed a multi-stage training approach:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseDataTraining:
    """Training procedure for extreme data sparsity scenarios"""

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Multiple loss functions for different objectives
        self.action_loss_fn = nn.MSELoss(reduction='none')
        self.value_loss_fn = nn.HuberLoss()
        self.uncertainty_loss_fn = self._create_uncertainty_loss()

    def _create_uncertainty_loss(self):
        """Creates loss that encourages higher uncertainty for missing data"""
        def uncertainty_loss(pred_uncertainty, presence_mask, obs_uncertainty):
            # We want uncertainty to be high when data is missing
            target_uncertainty = 1 - presence_mask.float()

            # Weight by observation uncertainty (squeeze only the last dim,
            # so batch size 1 is not accidentally collapsed)
            weight = 1 + obs_uncertainty.squeeze(-1)

            loss = F.mse_loss(
                pred_uncertainty,
                target_uncertainty.unsqueeze(-1).expand_as(pred_uncertainty),
                reduction='none'
            )
            weighted_loss = loss * weight.unsqueeze(-1)
            return weighted_loss.mean()

        return uncertainty_loss

    def train_step(self, batch, human_feedback=None):
        observations, actions, returns, timesteps, masks = batch

        # Forward pass
        outputs = self.model(observations, actions, returns, timesteps)

        # Compute losses
        total_loss = 0
        loss_dict = {}

        # 1. Action prediction loss (only where actions are known)
        action_mask = masks['action_mask']
        if action_mask.sum() > 0:  # Only compute if we have some action labels
            action_loss = self.action_loss_fn(
                outputs['action_preds'],
                actions
            )
            action_loss = (action_loss * action_mask.unsqueeze(-1)).sum() / action_mask.sum()
            total_loss += self.config.action_weight * action_loss
            loss_dict['action_loss'] = action_loss.item()

        # 2. Uncertainty calibration loss
        uncertainty_loss = self.uncertainty_loss_fn(
            outputs['uncertainty_preds'],
            outputs['presence_mask'],
            outputs['obs_uncertainty']
        )
        total_loss += self.config.uncertainty_weight * uncertainty_loss
        loss_dict['uncertainty_loss'] = uncertainty_loss.item()

        # 3. Human alignment loss (if feedback provided)
        if human_feedback is not None:
            alignment_loss = self._compute_alignment_loss(
                outputs['aligned_values'],
                outputs['preferences'],
                human_feedback
            )
            total_loss += self.config.alignment_weight * alignment_loss
            loss_dict['alignment_loss'] = alignment_loss.item()

        # 4. Consistency loss (encourage similar predictions for similar contexts)
        consistency_loss = self._compute_consistency_loss(outputs, batch)
        total_loss += self.config.consistency_weight * consistency_loss
        loss_dict['consistency_loss'] = consistency_loss.item()

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        self.optimizer.step()

        return total_loss.item(), loss_dict

    def _compute_alignment_loss(self, aligned_values, preferences, human_feedback):
        """Computes loss based on human feedback"""
        # Human feedback can be:
        # 1. Preference rankings between trajectories
        # 2. Direct ratings of actions
        # 3. Corrections to model predictions

        if 'preference_ranking' in human_feedback:
            # Bradley-Terry model for preference learning
            rankings = human_feedback['preference_ranking']
            # Implement preference learning loss
            loss = self._preference_loss(aligned_values, rankings)
        elif 'action_ratings' in human_feedback:
            # Direct supervision on action quality
            ratings = human_feedback['action_ratings']
            loss = F.mse_loss(aligned_values, ratings)
        else:
            # Default: encourage diverse consideration of objectives by
            # maximizing the entropy of the preference weights
            # (tensors have no .entropy(), so compute it explicitly)
            entropy = -(preferences * (preferences + 1e-8).log()).sum(dim=-1)
            loss = -entropy.mean()

        return loss

    def _compute_consistency_loss(self, outputs, batch):
        """Encourages consistent predictions for similar states"""
        # This is critical for sparse data - helps generalize from limited examples
        observations = batch[0]

        # Perturb only the observed entries; adding noise to the missing-token
        # sentinel would silently turn "missing" into "present" in the encoder
        missing = (observations == self.model.observation_encoder.missing_token)
        noise = torch.randn_like(observations) * 0.01
        perturbed_obs = torch.where(
            missing,
            observations,
            torch.clamp(observations + noise, -10, 10)
        )

        # Get predictions for the perturbed observations and penalize divergence
        _, actions, returns, timesteps, _ = batch
        perturbed_outputs = self.model(perturbed_obs, actions, returns, timesteps)

        # Detach the original predictions so only the perturbed branch
        # is pulled toward them
        return F.mse_loss(
            perturbed_outputs['action_preds'],
            outputs['action_preds'].detach()
        )
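The `_preference_loss` helper referenced in `_compute_alignment_loss` isn't shown in this post; under a Bradley-Terry model of pairwise preferences, one plausible sketch (the function name and interface are my assumption, not the post's code) is:

```python
import torch
import torch.nn.functional as F

def preference_loss(values_a, values_b, labels):
    """Bradley-Terry pairwise preference loss (sketch).

    values_a, values_b: trajectory-level aligned values, shape (batch,)
    labels: 1.0 where a human preferred trajectory a, 0.0 where they preferred b
    """
    # Bradley-Terry: P(a preferred over b) = sigmoid(value_a - value_b)
    logits = values_a - values_b
    return F.binary_cross_entropy_with_logits(logits, labels)

# A preference the model satisfies should incur lower loss than one it violates
good = preference_loss(torch.tensor([2.0]), torch.tensor([0.0]), torch.tensor([1.0]))
bad = preference_loss(torch.tensor([2.0]), torch.tensor([0.0]), torch.tensor([0.0]))
print(good.item() < bad.item())  # True
```

In practice `values_a` and `values_b` would come from summing `aligned_values` over each of the two trajectories being compared.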
