DEV Community

Rikin Patel

Explainable Causal Reinforcement Learning for sustainable aquaculture monitoring systems with inverse simulation verification
Introduction: From Observational Confusion to Causal Clarity

My journey into causal reinforcement learning began during a frustrating period of working with traditional aquaculture monitoring systems. While exploring deep reinforcement learning for optimizing fish feeding schedules, I kept hitting the same wall: the models would find spurious correlations that worked in simulation but failed catastrophically in real aquaculture environments. One memorable incident involved a model that "learned" to associate increased water flow with better fish growth—only to discover it was actually correlating with seasonal temperature changes that happened to coincide with flow adjustments.

Through studying cutting-edge papers on causal inference, I realized the fundamental limitation: traditional RL agents learn associations, not causation. They become masters of correlation without understanding the underlying mechanisms. This epiphany led me down a path of integrating causal discovery with reinforcement learning, specifically for the complex, multi-variate environment of sustainable aquaculture.

Technical Background: The Marriage of Causality and Reinforcement Learning

The Core Problem in Aquaculture AI

In my research on aquaculture monitoring systems, I discovered that traditional machine learning approaches suffer from several critical issues:

  1. Non-stationarity: Water quality parameters, fish behavior, and environmental conditions change over time
  2. Confounding variables: Multiple interacting factors create misleading correlations
  3. Delayed effects: Actions like feeding or oxygenation have consequences that manifest hours or days later
  4. Intervention effects: Changing one parameter affects multiple downstream variables

While exploring causal inference literature, I came across the fundamental distinction between the observational distribution P(Y|X) and the interventional distribution P(Y|do(X)). This distinction became the cornerstone of my approach.
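This distinction is easy to see numerically. The sketch below (coefficients are made up for illustration, not taken from farm data) replays the flow/temperature confusion from the introduction: growth depends only on season, yet regressing growth on flow in observational data yields a large slope, while intervening on flow, do(flow), yields none.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Structural model (illustrative coefficients):
#   season -> flow, season -> growth; flow has NO direct effect on growth.
season = rng.normal(size=n)
flow_obs = season + 0.5 * rng.normal(size=n)          # flow := f(season)
growth_obs = 2.0 * season + 0.5 * rng.normal(size=n)  # growth := g(season)

# Observational slope of growth on flow: large, driven purely by confounding.
obs_slope = np.polyfit(flow_obs, growth_obs, 1)[0]    # ~ 1.6

# Interventional data: do(flow) cuts the season -> flow edge,
# so flow is set exogenously and growth is unaffected.
flow_do = rng.normal(size=n)
growth_do = 2.0 * season + 0.5 * rng.normal(size=n)
do_slope = np.polyfit(flow_do, growth_do, 1)[0]       # ~ 0.0

print(f"P(Y|X) slope: {obs_slope:.2f}, P(Y|do(X)) slope: {do_slope:.2f}")
```

An agent trained on the observational slope would crank up water flow expecting growth; the interventional slope shows that lever does nothing in this toy system.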

Causal Reinforcement Learning Framework

During my investigation of causal RL, I found that the most promising approach combines:

  • Causal discovery to learn the structural causal model (SCM) from observational data
  • Causal inference to estimate intervention effects
  • Reinforcement learning to optimize policies over the causal model

One interesting finding from my experimentation with different causal discovery algorithms was that constraint-based methods (such as the PC and FCI algorithms) outperformed score-based methods in aquaculture environments, largely because FCI can account for latent confounders.
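The primitive inside constraint-based methods is a conditional independence (CI) test: an edge X–Y survives only if no conditioning set renders X and Y independent. Here is a minimal NumPy-only sketch of the standard Fisher-z partial-correlation test (appropriate for roughly Gaussian data; variable names are illustrative, not from my pipeline):

```python
import math
import numpy as np

def gaussian_ci_test(x, y, z, alpha=0.05):
    """Fisher-z test of X independent of Y given Z via partial correlation."""
    # Residualize x and y on the conditioning set z (plus an intercept).
    Z = np.column_stack([np.ones(len(x)), z])
    rx = x - Z @ np.linalg.lstsq(Z, x, rcond=None)[0]
    ry = y - Z @ np.linalg.lstsq(Z, y, rcond=None)[0]
    r = float(np.corrcoef(rx, ry)[0, 1])

    # Fisher z-transform; conditioning-set size enters the dof correction.
    k = Z.shape[1] - 1
    z_stat = 0.5 * math.log((1 + r) / (1 - r)) * math.sqrt(len(x) - k - 3)
    p_value = math.erfc(abs(z_stat) / math.sqrt(2))  # two-sided p
    return p_value > alpha  # True => independent given z => remove the edge

# Toy SCM: temperature drives both flow and oxygen; flow and oxygen share
# no direct edge, so they are dependent marginally but not given temperature.
rng = np.random.default_rng(1)
temp = rng.normal(size=5000)
flow = temp + rng.normal(size=5000)
oxygen = -temp + rng.normal(size=5000)
unrelated = rng.normal(size=5000)

print(gaussian_ci_test(flow, oxygen, unrelated))  # dependent: False
print(gaussian_ci_test(flow, oxygen, temp))       # independent given temp
```

A PC-style refinement loop simply runs this test over candidate edges and conditioning sets, deleting edges whenever the test returns True.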

Implementation Details: Building the Causal RL System

Phase 1: Causal Discovery from Aquaculture Sensor Data

My exploration of aquaculture data revealed that we need to handle mixed data types (continuous water parameters, discrete feeding events, binary equipment states) and irregular time series. Here's a simplified implementation of the causal discovery pipeline I developed:

import numpy as np
from causalnex.structure import StructureModel
from causalnex.structure.notears import from_pandas
import pandas as pd

class AquacultureCausalDiscoverer:
    def __init__(self, alpha=0.05, max_lag=24):
        self.alpha = alpha  # Significance level for independence tests
        self.max_lag = max_lag  # Maximum time lag to consider
        self.structure_model = None

    def discover_temporal_causality(self, sensor_data):
        """Discover causal relationships in time-series aquaculture data"""
        # Preprocess and create lagged features
        lagged_data = self._create_lagged_features(sensor_data)

        # Use NOTEARS for non-linear causal discovery
        self.structure_model = from_pandas(
            lagged_data,
            max_iter=100,
            h_tol=1e-8,
            w_threshold=0.3
        )

        # Apply constraint-based refinement
        self._refine_with_constraints(lagged_data)

        return self.structure_model

    def _create_lagged_features(self, data):
        """Create time-lagged versions of features for temporal causal analysis"""
        lagged_df = data.copy()
        for col in data.columns:
            for lag in range(1, self.max_lag + 1):
                lagged_df[f'{col}_lag_{lag}'] = data[col].shift(lag)

        return lagged_df.dropna()

    def _refine_with_constraints(self, data):
        """Refine causal graph using conditional independence tests"""
        # Implementation of PC-algorithm style refinement
        # This removes edges that don't satisfy conditional independence
        pass

Phase 2: Causal Reinforcement Learning Agent

Through studying causal RL papers, I learned that the key innovation is using the causal model to simulate interventions and counterfactuals. This allows the agent to reason about "what if" scenarios without actually intervening in the real system.
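The counterfactual step is worth making concrete before the agent code. For a single linear mechanism, the standard three-step recipe (abduction, action, prediction) that a counterfactual simulator implements collapses to a few lines; the 0.8 coefficient below is purely illustrative:

```python
def counterfactual_oxygen(flow_obs, oxygen_obs, flow_cf, coef=0.8):
    """Counterfactual on a toy linear SCM: oxygen := coef * flow + u.

    1. Abduction:  infer the exogenous noise u from what was observed.
    2. Action:     replace flow with the counterfactual value (do-operator).
    3. Prediction: re-run the mechanism with u held fixed.
    """
    u = oxygen_obs - coef * flow_obs  # abduction
    return coef * flow_cf + u         # action + prediction

# Observed: flow 1.0 produced oxygen 0.5.
# "What would oxygen have been had flow been 2.0?" -> approx 1.3
print(counterfactual_oxygen(1.0, 0.5, 2.0))
```

Because u is held fixed, the answer is specific to the observed situation; that is exactly what distinguishes a counterfactual from a plain intervention.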

import torch
import torch.nn as nn
from causalnex.inference import InferenceEngine

# CounterfactualSimulator and CausalAttention are project-specific helpers;
# CausalPolicyNetwork is defined below.

class CausalRLAgent:
    def __init__(self, causal_model, state_dim, action_dim):
        self.causal_model = causal_model
        self.inference_engine = InferenceEngine(causal_model)
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Policy network that uses causal features
        self.policy_net = CausalPolicyNetwork(
            state_dim,
            action_dim,
            causal_features_dim=64
        )

        # Counterfactual simulator for safe exploration
        self.counterfactual_sim = CounterfactualSimulator(causal_model)

    def select_action(self, state, use_counterfactual=True):
        """Select action using causal reasoning"""
        if use_counterfactual:
            # Generate counterfactual outcomes for different actions
            counterfactuals = []
            for action in self._get_action_space():
                cf_outcome = self.counterfactual_sim.simulate(
                    current_state=state,
                    intervention={'feeding_rate': action}
                )
                counterfactuals.append(cf_outcome)

            # Choose action with best counterfactual outcome
            best_action = self._evaluate_counterfactuals(counterfactuals)
            return best_action
        else:
            # Fall back to standard RL policy
            return self.policy_net(state)

    def update_causal_model(self, new_data):
        """Update causal model with new observational data"""
        # Online causal discovery update
        updated_model = self._online_causal_update(new_data)
        self.causal_model = updated_model
        self.counterfactual_sim.update_model(updated_model)

class CausalPolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, causal_features_dim):
        super().__init__()
        self.causal_encoder = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, causal_features_dim)
        )

        # Causal attention mechanism
        self.causal_attention = CausalAttention(causal_features_dim)

        self.policy_head = nn.Sequential(
            nn.Linear(causal_features_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, state, causal_graph=None):
        # Encode state into causal feature space
        causal_features = self.causal_encoder(state)

        # Apply causal attention if graph is provided
        if causal_graph is not None:
            causal_features = self.causal_attention(
                causal_features,
                causal_graph
            )

        # Generate action probabilities
        action_probs = self.policy_head(causal_features)
        return action_probs

Phase 3: Inverse Simulation Verification

One of my most significant discoveries came while experimenting with verification methods. Traditional validation approaches couldn't catch subtle causal misunderstandings. I developed an inverse simulation approach that works backward from desired outcomes to validate whether the learned causal model supports achieving those outcomes.

class InverseSimulationVerifier:
    def __init__(self, causal_model, target_outcomes):
        self.causal_model = causal_model
        self.target_outcomes = target_outcomes  # e.g., optimal growth conditions

    def verify_policy(self, policy, max_iterations=1000):
        """Verify if policy can achieve target outcomes through inverse simulation"""
        violations = []

        for target in self.target_outcomes:
            # Work backward from target outcome to required interventions
            required_interventions = self._inverse_simulate(target)

            # Check if policy would produce these interventions
            policy_interventions = self._simulate_policy_forward(policy)

            # Compare and identify violations
            if not self._interventions_compatible(
                required_interventions,
                policy_interventions
            ):
                violations.append({
                    'target': target,
                    'required': required_interventions,
                    'actual': policy_interventions
                })

        return violations

    def _inverse_simulate(self, target_outcome):
        """Inverse simulation: Find interventions that lead to target outcome"""
        # Use do-calculus to work backward through causal graph
        interventions = {}

        # Start from target variable and trace back through parents
        target_var = list(target_outcome.keys())[0]
        target_value = target_outcome[target_var]

        # Find minimal intervention set using causal paths
        causal_paths = self._find_causal_paths(target_var)

        for path in causal_paths:
            # For each parent in path, compute required value
            for parent in reversed(path[:-1]):
                required_val = self._compute_required_value(
                    parent,
                    target_var,
                    target_value
                )
                if parent not in interventions:
                    interventions[parent] = required_val

        return interventions

    def _find_causal_paths(self, target_variable):
        """Find all directed paths to target variable in causal graph"""
        # Implementation of path finding in DAG
        paths = []
        # ... path discovery logic
        return paths
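To ground `_inverse_simulate`: when every mechanism along a causal path is linear, working backward from the target is just inverting one mechanism at a time. A toy sketch with made-up coefficients for a flow -> oxygen -> growth chain:

```python
def required_parent_value(target_child_value, coef, offset=0.0):
    """Invert a linear mechanism child := coef * parent + offset to find
    the intervention on the parent that hits a target child value."""
    return (target_child_value - offset) / coef

# Hypothetical chain: flow --(0.8)--> oxygen --(1.5)--> growth.
# Walk backward from a growth target of 3.0:
oxygen_needed = required_parent_value(3.0, coef=1.5)          # ~ 2.0
flow_needed = required_parent_value(oxygen_needed, coef=0.8)  # ~ 2.5
print(oxygen_needed, flow_needed)
```

Real aquaculture mechanisms are nonlinear, so the production verifier solves each inversion numerically, but the backward-chaining structure is the same.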

Real-World Applications: Sustainable Aquaculture Monitoring

Case Study: Norwegian Salmon Farming

During my experimentation with real aquaculture data from Norwegian salmon farms, I applied the causal RL framework to optimize multiple objectives simultaneously:

  1. Fish growth maximization
  2. Feed conversion ratio minimization
  3. Disease outbreak prevention
  4. Environmental impact reduction

The system integrated data from:

  • IoT sensors (temperature, oxygen, pH, salinity)
  • Underwater cameras (fish behavior analysis)
  • Feeding systems (feed distribution patterns)
  • Environmental monitors (currents, weather, water quality)

One interesting finding was that the causal model discovered a non-obvious relationship: moderate increases in water current (previously avoided to reduce energy costs) actually improved oxygen distribution and reduced localized waste accumulation, leading to better overall growth with lower disease incidence.

Implementation Architecture

class SustainableAquacultureMonitor:
    def __init__(self, farm_config):
        self.sensors = self._initialize_sensors(farm_config)
        self.causal_discoverer = AquacultureCausalDiscoverer()
        # The agent and verifier require a causal model and dimensions;
        # here they come from the farm configuration
        self.rl_agent = CausalRLAgent(
            causal_model=farm_config.initial_causal_model,
            state_dim=farm_config.state_dim,
            action_dim=farm_config.action_dim
        )
        self.verifier = InverseSimulationVerifier(
            causal_model=farm_config.initial_causal_model,
            target_outcomes=farm_config.sustainability_targets
        )

        # Multi-objective reward function
        self.reward_fn = MultiObjectiveReward(
            weights={
                'growth_rate': 0.4,
                'feed_efficiency': 0.3,
                'health_score': 0.2,
                'environmental_impact': 0.1
            }
        )

    def run_monitoring_cycle(self):
        """Complete monitoring and optimization cycle"""
        # 1. Collect sensor data
        current_state = self._collect_sensor_data()

        # 2. Update causal model (online learning)
        if self._enough_new_data():
            updated_model = self.causal_discoverer.update_online(
                current_state
            )
            self.rl_agent.update_causal_model(updated_model)

        # 3. Select optimal actions using causal RL
        actions = self.rl_agent.select_action(
            current_state,
            use_counterfactual=True
        )

        # 4. Verify actions won't violate sustainability constraints
        violations = self.verifier.verify_actions(actions)
        if violations:
            actions = self._apply_safety_corrections(actions, violations)

        # 5. Execute actions and observe outcomes
        self._execute_actions(actions)

        # 6. Learn from outcomes
        next_state = self._collect_sensor_data()
        reward = self.reward_fn.calculate(current_state, actions, next_state)
        self.rl_agent.update_policy(current_state, actions, reward, next_state)

        return actions, reward, violations

class MultiObjectiveReward:
    def __init__(self, weights, constraint_penalty=10.0):
        self.weights = weights  # objective name -> weight
        self.constraint_penalty = constraint_penalty  # illustrative default

    def calculate(self, state, actions, next_state):
        """Calculate multi-objective reward for sustainable aquaculture"""
        rewards = {}

        # Growth reward (based on estimated biomass increase); keys must
        # match the weight names passed at construction
        rewards['growth_rate'] = self._calculate_growth_reward(state, next_state)

        # Feed efficiency reward
        rewards['feed_efficiency'] = self._calculate_feed_efficiency(
            actions['feed_amount'],
            rewards['growth_rate']
        )

        # Health reward (based on disease indicators)
        rewards['health_score'] = self._calculate_health_reward(next_state)

        # Environmental impact reward
        rewards['environmental_impact'] = self._calculate_environmental_impact(
            actions,
            next_state
        )

        # Weighted sum with sustainability constraints
        total_reward = sum(
            weight * rewards[objective]
            for objective, weight in self.weights.items()
        )

        # Apply penalty for constraint violations
        if self._detect_constraint_violation(next_state):
            total_reward -= self.constraint_penalty

        return total_reward

Challenges and Solutions: Lessons from the Trenches

Challenge 1: Sparse and Noisy Sensor Data

While exploring real aquaculture datasets, I encountered severe data quality issues. Sensors fail, biofouling corrupts measurements, and communication dropouts create gaps. My solution was a robust causal imputation method:

class CausalImputation:
    def impute_missing_values(self, data, causal_graph):
        """Impute missing values using causal relationships"""
        imputed_data = data.copy()

        for column in data.columns:
            missing_mask = data[column].isna()
            if missing_mask.any():
                # Find causal parents of this variable
                parents = self._get_causal_parents(column, causal_graph)

                imputed_values = None
                if parents:
                    # Use causal relationships for imputation
                    imputed_values = self._causal_impute(
                        column,
                        parents,
                        data
                    )

                if imputed_values is not None:
                    imputed_data.loc[missing_mask, column] = imputed_values
                else:
                    # No parents or no complete cases: fall back to
                    # temporal interpolation
                    imputed_data[column] = imputed_data[column].interpolate()

        return imputed_data

    def _causal_impute(self, target_var, parent_vars, data):
        """Impute using causal relationships with parents"""
        # Train a model on complete cases
        complete_cases = data.dropna(subset=[target_var] + parent_vars)

        if len(complete_cases) > 0:
            X = complete_cases[parent_vars]
            y = complete_cases[target_var]

            model = self._train_causal_model(X, y)

            # Predict missing values
            missing_cases = data[data[target_var].isna()]
            if not missing_cases.empty:
                X_missing = missing_cases[parent_vars]
                predictions = model.predict(X_missing)
                return predictions

        return None

Challenge 2: Non-Stationarity and Concept Drift

Aquaculture environments exhibit seasonal patterns, growth-dependent dynamics, and equipment aging effects. Through studying time-series causal methods, I implemented a change point detection and model adaptation system:

from datetime import datetime

class AdaptiveCausalModel:
    def __init__(self, detection_threshold=0.05):
        self.current_model = None
        self.model_history = []
        self.detection_threshold = detection_threshold

    def detect_concept_drift(self, new_data):
        """Detect changes in causal relationships"""
        if self.current_model is None:
            return True

        # Compare conditional independence relationships
        old_ci = self._extract_ci_relationships(self.current_model)
        new_model = self._learn_from_data(new_data)
        new_ci = self._extract_ci_relationships(new_model)

        # Calculate divergence
        divergence = self._calculate_ci_divergence(old_ci, new_ci)

        return divergence > self.detection_threshold

    def adapt_model(self, new_data, detection_result):
        """Adapt causal model to detected changes"""
        if detection_result:
            # Significant drift detected - learn new model
            new_model = self._learn_from_data(new_data)
            self.model_history.append({
                'model': self.current_model,
                'timestamp': datetime.now(),
                'drift_detected': detection_result
            })
            self.current_model = new_model
        else:
            # Incremental update
            self._online_update(new_data)

        return self.current_model

Challenge 3: Explainability for Aquaculture Operators

One realization from working with fish farm operators was that they distrusted "black box" AI recommendations. My solution was a causal explanation generator:


class CausalExplanationGenerator:
    def generate_explanation(self, state, action, causal_model):
        """Generate human-readable explanations for AI recommendations"""
        explanation = {
            'primary_reason': None,
            'causal_paths': [],
            'counterfactual_comparisons': [],
            'risk_assessment': None
        }

        # Extract causal path to key outcomes
        target_outcomes = ['fish_growth', 'disease_risk', 'feed_efficiency']

        for target in target_outcomes:
            paths = self._find_causal_paths(action, target, causal_model)
            if paths:
                explanation['causal_paths'].extend(paths)

        # Generate counterfactual comparisons
        alternative_actions = self._generate_alternatives(action)
        for alt_action in alternative_actions:
            comparison = self._compare_counterfactuals(
                action,
                alt_action,
                state,
                causal_model
            )
            explanation['counterfactual_comparisons'].append(comparison)

        # Identify primary reason: the strongest causal effect among the
        # discovered paths ('effect_size' is an assumed field on each path)
        if explanation['causal_paths']:
            explanation['primary_reason'] = max(
                explanation['causal_paths'],
                key=lambda p: abs(p.get('effect_size', 0.0))
            )

        return explanation
