DEV Community

Rikin Patel

Explainable Causal Reinforcement Learning for Circular Manufacturing Supply Chains with Inverse Simulation Verification

Introduction: The Broken Feedback Loop

During my research into sustainable manufacturing systems, I encountered a fundamental problem that changed my approach to AI in industrial applications. While exploring reinforcement learning applications for supply chain optimization, I realized that traditional RL models were making decisions that appeared optimal on paper but failed catastrophically in real-world circular manufacturing environments. The issue wasn't the algorithms themselves, but their inability to understand why certain decisions led to specific outcomes in complex, interconnected systems.

One particularly revealing experiment involved training a standard PPO agent on a simulated circular supply chain where materials flowed through production, use, recovery, and remanufacturing stages. The agent learned to maximize short-term profit beautifully, but its strategy involved systematically depleting recovery buffers and creating irreversible material losses—exactly the opposite of what circular manufacturing aims to achieve. This wasn't just a suboptimal policy; it was a fundamental misunderstanding of the system's causal structure.

Through studying causal inference papers and combining them with my hands-on RL experimentation, I discovered that the missing piece was explainable causal reasoning. Circular supply chains aren't just sequential processes—they're complex networks of interdependent causal relationships where actions have delayed, distributed, and sometimes counterintuitive effects. My exploration revealed that without explicit causal modeling, even sophisticated RL agents would inevitably exploit statistical correlations rather than understanding true causal mechanisms.

Technical Background: Beyond Correlation to Causation

The Causal Revolution in Reinforcement Learning

Traditional reinforcement learning operates on the Markov assumption: the next state depends only on the current state and action. While exploring causal inference literature, I found that this assumption breaks down spectacularly in circular supply chains where:

  1. Feedback loops create non-Markovian dependencies
  2. Delayed effects span multiple time steps
  3. Confounding variables create spurious correlations
  4. Interventions have distributed consequences
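To make the first two failure modes concrete, here is a toy simulation (my own illustrative numbers, not data from any real plant): recovery investments pay off only after a fixed delay, so two histories can share the same observable inventory yet diverge afterward — the observable state alone is not Markovian.

```python
from collections import deque

def simulate_recovery_loop(actions, delay=3, initial_inventory=10.0):
    """Toy model: each recovery investment arrives only after `delay` steps,
    so the next inventory level depends on the action history, not just the
    current inventory -- a non-Markovian dependency unless the pipeline of
    in-flight recoveries is folded into the state."""
    pipeline = deque([0.0] * delay, maxlen=delay)  # in-flight recovered material
    inventory = initial_inventory
    trajectory = []
    for invest in actions:
        inventory += pipeline[0] - 1.0   # delayed recovery arrives; constant demand of 1
        pipeline.append(invest)          # today's investment lands `delay` steps later
        trajectory.append(inventory)
    return trajectory

# Two action histories that produce the SAME inventory at step 2
# but different futures, because the in-flight pipelines differ:
a = simulate_recovery_loop([2.0, 0.0, 0.0, 0.0, 0.0])
b = simulate_recovery_loop([0.0, 0.0, 2.0, 0.0, 0.0])
```

An agent observing only the inventory at step 2 sees identical states in both runs, yet the trajectories immediately diverge — exactly the kind of delayed, history-dependent effect the list above describes.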

During my investigation of causal graphical models, I realized that we could represent circular supply chains as structural causal models (SCMs) where each node represents a supply chain component and edges represent causal relationships. This representation fundamentally changed how I approached the problem.

import networkx as nx

class CircularSupplyChainSCM:
    def __init__(self):
        self.graph = nx.DiGraph()
        self.node_values = {}  # current value of each supply chain component

        # Define causal structure for circular manufacturing
        self.graph.add_edges_from([
            ('raw_material_extraction', 'primary_production'),
            ('primary_production', 'product_assembly'),
            ('product_assembly', 'distribution'),
            ('distribution', 'customer_use'),
            ('customer_use', 'product_return'),
            ('product_return', 'disassembly'),
            ('disassembly', 'material_recovery'),
            ('material_recovery', 'secondary_production'),
            ('secondary_production', 'product_assembly'),  # Circular link!
            ('remanufacturing_policy', 'product_return'),
            ('remanufacturing_policy', 'disassembly'),
            ('remanufacturing_policy', 'material_recovery')
        ])

    def intervene(self, node, value):
        """Perform a do-calculus intervention: do(node = value)"""
        # Graph surgery: cut incoming edges to the intervened node
        intervened_graph = self.graph.copy()
        intervened_graph.remove_edges_from(list(intervened_graph.in_edges(node)))

        # Fix the node at the intervention value
        self.node_values[node] = value

        # Propagate effects through the mutilated graph
        return self._propagate_effects(intervened_graph)

    def _propagate_effects(self, graph):
        """Recompute downstream node values in causal order; the structural
        equations themselves are domain-specific and supplied separately."""
        ...
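As a sanity check, the graph surgery behind `intervene` can be exercised directly with networkx. This standalone sketch rebuilds the main material loop (omitting the policy edges for brevity) and confirms that an intervention can only influence descendants in the mutilated graph. Note that the circular link makes the static graph cyclic; in practice the SCM is unrolled over time steps.

```python
import networkx as nx

# The main material loop from the SCM above (policy edges omitted for brevity)
g = nx.DiGraph([
    ('raw_material_extraction', 'primary_production'),
    ('primary_production', 'product_assembly'),
    ('product_assembly', 'distribution'),
    ('distribution', 'customer_use'),
    ('customer_use', 'product_return'),
    ('product_return', 'disassembly'),
    ('disassembly', 'material_recovery'),
    ('material_recovery', 'secondary_production'),
    ('secondary_production', 'product_assembly'),  # circular link
])

def do(graph, node):
    """Return the mutilated graph for do(node): sever its incoming edges."""
    mutilated = graph.copy()
    mutilated.remove_edges_from(list(mutilated.in_edges(node)))
    return mutilated

mutilated = do(g, 'material_recovery')
affected = nx.descendants(mutilated, 'material_recovery')
# Upstream extraction and primary production are unreachable from the
# intervened node, so the do-operator correctly rules them out as effects.
```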

Causal Reinforcement Learning Framework

While learning about causal RL, I discovered that we could extend traditional RL frameworks by incorporating causal discovery and inference. The key insight from my experimentation was that causal models allow us to:

  1. Distinguish correlation from causation in reward signals
  2. Predict the effects of interventions without actually performing them
  3. Reason counterfactually about what would have happened under different actions
  4. Transfer knowledge across similar but different supply chain configurations
import torch
import torch.nn as nn

class CausalAttentionLayer(nn.Module):
    """Attention mechanism constrained by a learned causal structure"""
    def __init__(self, input_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(input_dim, num_heads)
        self.causal_mask = None

    def learn_causal_structure(self, sequences):
        """Learn causal relationships from temporal data"""
        # Use Granger causality or transfer entropy between sequences
        strengths = self._compute_transfer_entropy(sequences)

        # Boolean mask: True entries are *blocked* by MultiheadAttention,
        # so we mask out pairs whose causal strength is negligible
        self.causal_mask = strengths <= 0.1

    def forward(self, x):
        # Restrict attention to causally linked positions
        attn_output, _ = self.attention(x, x, x, attn_mask=self.causal_mask)
        return attn_output

class CausalQNetwork(nn.Module):
    """Dueling Q-network with explicit causal reasoning"""
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.causal_layer = CausalAttentionLayer(state_dim, 4)
        self.value_stream = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state, action):
        # Dueling decomposition: Q(s, a) = V(s) + A(s, a)
        value = self.value_stream(state)
        advantage = self.advantage_stream(torch.cat([state, action], dim=-1))
        return value + advantage

    def compute_counterfactual(self, state, action, alternative_action):
        """What would Q be if we took alternative_action instead?"""
        current_q = self.forward(state, action)
        counterfactual_q = self.forward(state, alternative_action)
        return current_q, counterfactual_q, current_q - counterfactual_q
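The counterfactual query in `compute_counterfactual` is easiest to see in tabular form. The state, actions, and Q-values below are hypothetical illustrations of the pathology from the introduction: the myopic action looks better under a pure Q comparison.

```python
# Toy Q-table; the values are made-up illustrations, not learned estimates.
q_table = {
    ('low_recovery_buffer', 'invest_in_recovery'): 4.2,
    ('low_recovery_buffer', 'maximize_production'): 5.1,  # myopically higher Q
}

def counterfactual_advantage(q, state, action, alternative):
    """Return (Q(s,a), Q(s,a'), Q(s,a) - Q(s,a')) -- the tabular analogue
    of CausalQNetwork.compute_counterfactual."""
    current, counterfactual = q[(state, action)], q[(state, alternative)]
    return current, counterfactual, current - counterfactual

cur, cf, delta = counterfactual_advantage(
    q_table, 'low_recovery_buffer', 'invest_in_recovery', 'maximize_production'
)
# A negative delta says the chosen action looks worse than the alternative
# by pure Q comparison -- the causal machinery later in the post exists
# precisely to catch cases where that comparison is misleading.
```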

Implementation Details: Building Explainable Causal Agents

Inverse Simulation Verification System

One of the most challenging aspects I encountered during my experimentation was verifying that the learned causal models actually reflected reality. Through studying verification methodologies, I developed an inverse simulation approach that works backward from outcomes to validate causal claims.

class InverseSimulationVerifier:
    def __init__(self, forward_model, causal_model):
        self.forward_model = forward_model  # High-fidelity simulator
        self.causal_model = causal_model    # Learned causal model

    def verify_causal_claim(self, cause_node, effect_node,
                          intervention_value, expected_effect):
        """
        Verify if intervening on cause_node with intervention_value
        produces expected_effect on effect_node
        """
        # 1. Run forward simulation with intervention
        forward_result = self.forward_model.simulate_intervention(
            cause_node, intervention_value
        )
        actual_effect = forward_result[effect_node]

        # 2. Query causal model for predicted effect
        predicted_effect = self.causal_model.predict_effect(
            cause_node, effect_node, intervention_value
        )

        # 3. Compute discrepancy
        discrepancy = abs(actual_effect - predicted_effect)

        # 4. Generate explanation if discrepancy is high
        if discrepancy > 0.1:
            explanation = self._generate_counterfactual_explanation(
                cause_node, effect_node,
                actual_effect, predicted_effect
            )
            return False, discrepancy, explanation

        return True, discrepancy, "Causal claim verified"

    def _generate_counterfactual_explanation(self, cause, effect,
                                           actual, predicted):
        """Generate human-readable explanation for discrepancy"""
        # Analyze mediating variables
        mediators = self.causal_model.find_mediators(cause, effect)

        # Check for unobserved confounders
        confounders = self._detect_confounders(cause, effect)

        explanation = f"""
        Discrepancy detected between actual ({actual:.3f}) and
        predicted ({predicted:.3f}) effect of {cause} on {effect}.

        Possible reasons:
        1. Mediating variables: {mediators}
        2. Unobserved confounders: {confounders}
        3. Non-linear interaction effects
        4. Context-dependent causal strength

        Recommended investigation:
        - Perform do-calculus with backdoor adjustment
        - Test for effect modification by context
        - Check for time-varying confounding
        """
        return explanation
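To see the verifier's core comparison in isolation, here is a self-contained toy: a "ground-truth" simulator with diminishing returns versus a linear learned model. The functional forms and the 0.1 tolerance are illustrative assumptions, not the production models.

```python
def simulator_effect(recovery_investment):
    # Ground truth: diminishing returns on recovery investment
    return 0.8 * recovery_investment - 0.2 * recovery_investment ** 2

def causal_model_effect(recovery_investment):
    # Learned model: linear approximation, accurate only near the origin
    return 0.8 * recovery_investment

def verify(intervention, tol=0.1):
    """Compare predicted vs simulated intervention effects, as in
    InverseSimulationVerifier.verify_causal_claim."""
    actual = simulator_effect(intervention)
    predicted = causal_model_effect(intervention)
    discrepancy = abs(actual - predicted)
    return discrepancy <= tol, discrepancy

ok_small, _ = verify(0.5)  # small intervention: linear model still adequate
ok_large, _ = verify(2.0)  # large intervention: the approximation breaks down
```

This is exactly the failure mode the explanation generator is for: the causal claim "recovery investment raises recovery linearly" verifies for small interventions but fails for large ones, pointing at a non-linear interaction effect.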

Causal Discovery from Supply Chain Data

During my research into causal discovery algorithms, I found that combining multiple methods yielded the most reliable results for circular supply chains:

import numpy as np
import pandas as pd
import networkx as nx
from causalnex.structure.notears import from_pandas   # score-based (NOTEARS)
from causallearn.search.ConstraintBased.PC import pc  # constraint-based (PC)
from statsmodels.tsa.stattools import grangercausalitytests

class HybridCausalDiscoverer:
    def __init__(self):
        self.ensemble_models = []

    def discover_from_timeseries(self, data: pd.DataFrame):
        """Discover causal structure from observational data"""

        # Method 1: Constraint-based (PC algorithm, via causal-learn)
        pc_result = pc(data.to_numpy())

        # Method 2: Score-based (NOTEARS, via causalnex)
        notears_graph = from_pandas(data)

        # Method 3: Granger causality for time series
        granger_graph = self._compute_granger_causality(data)

        # Ensemble the results into a consensus graph
        consensus_graph = self._ensemble_graphs(
            [pc_result, notears_graph, granger_graph]
        )

        # Validate against domain knowledge
        return self._apply_domain_constraints(
            consensus_graph, 'circular_manufacturing'
        )

    def _compute_granger_causality(self, data, max_lag=5):
        """Pairwise Granger causality over all variable pairs"""
        n_vars = data.shape[1]
        causality_matrix = np.zeros((n_vars, n_vars))

        for i in range(n_vars):
            for j in range(n_vars):
                if i != j:
                    # grangercausalitytests checks whether the *second*
                    # column Granger-causes the first, so test i -> j
                    # by ordering the columns [j, i]
                    test_result = grangercausalitytests(
                        data.iloc[:, [j, i]], maxlag=max_lag,
                        verbose=False
                    )
                    # Extract the p-value for the best lag
                    p_values = [test_result[lag][0]['ssr_ftest'][1]
                                for lag in range(1, max_lag + 1)]
                    min_p_value = min(p_values)

                    if min_p_value < 0.05:  # significant at the 5% level
                        causality_matrix[i, j] = 1 - min_p_value

        return nx.from_numpy_array(causality_matrix, create_using=nx.DiGraph)
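The Granger step can also be reproduced with plain numpy, which makes the underlying idea explicit: x Granger-causes y if adding lags of x to an autoregression of y meaningfully reduces the residual sum of squares. The synthetic series below (x driving y with a two-step lag) is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, lag = 500, 2
x = rng.normal(size=n)
y = np.zeros(n)
for t in range(lag, n):
    # x drives y with a 2-step lag, plus mild autoregression and noise
    y[t] = 0.5 * y[t - 1] + 0.8 * x[t - lag] + 0.1 * rng.normal()

def residual_ss(target, predictors):
    """Least-squares fit; return the residual sum of squares."""
    beta, *_ = np.linalg.lstsq(predictors, target, rcond=None)
    resid = target - predictors @ beta
    return float(resid @ resid)

t = np.arange(lag, n)
own_lags = np.column_stack([y[t - 1], y[t - 2], np.ones_like(t, dtype=float)])
with_x = np.column_stack([own_lags, x[t - 1], x[t - 2]])

rss_restricted = residual_ss(y[t], own_lags)  # y's own lags only
rss_full = residual_ss(y[t], with_x)          # plus lags of x
# If x Granger-causes y, the full model fits substantially better:
improvement = (rss_restricted - rss_full) / rss_restricted
```

The F-test inside `grangercausalitytests` is essentially a significance test on this same RSS reduction; on the synthetic data the improvement is dramatic because x is the dominant driver of y.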

Real-World Applications: Circular Manufacturing Case Study

Material Flow Optimization with Causal RL

In my experimentation with a real circular manufacturing dataset from an electronics remanufacturer, I implemented a causal RL agent that learned to optimize material recovery while maintaining production targets:

class CircularManufacturingEnv:
    """Environment for circular manufacturing supply chain"""

    def __init__(self):
        self.state_dim = 12  # Inventory levels, recovery rates, demand, etc.
        self.action_dim = 6  # Production rates, recovery investments, etc.
        self.state = None    # populated by the environment's reset logic

        # Causal relationships (learned or provided)
        self.causal_graph = self._initialize_causal_structure()

    def step(self, action):
        # Apply action
        new_state = self._apply_action(action)

        # Compute reward with causal attribution
        reward, attribution = self._compute_causal_reward(
            self.state, action, new_state
        )

        # Update causal model based on observed effects
        self._update_causal_model(self.state, action, new_state)

        info = {
            'causal_attribution': attribution,
            'counterfactuals': self._generate_counterfactuals(action)
        }
        self.state = new_state  # advance the environment
        return new_state, reward, False, info

    def _compute_causal_reward(self, state, action, next_state):
        """Compute reward with causal attribution"""

        base_reward = (
            self._profit_reward(next_state) +
            self._sustainability_reward(next_state) -
            self._volatility_penalty(state, next_state)
        )

        # Causal attribution: which actions caused which outcomes
        attribution = {}
        for outcome in ['profit', 'recovery_rate', 'inventory_cost']:
            attribution[outcome] = self._attribute_to_actions(
                outcome, state, action, next_state
            )

        return base_reward, attribution

class CausalPPOAgent:
    """PPO agent with causal reasoning capabilities"""

    def __init__(self, env, causal_model):
        self.env = env
        self.causal_model = causal_model

        # Policy network with causal attention
        self.policy_net = CausalPolicyNetwork(
            env.state_dim, env.action_dim
        )

        # Value network for baseline
        self.value_net = CausalValueNetwork(env.state_dim)

    def select_action(self, state, explore=True):
        # Get action distribution from policy
        action_dist = self.policy_net(state)

        if explore:
            action = action_dist.sample()
        else:
            action = action_dist.mean

        # Generate causal explanation for action choice
        explanation = self._explain_action(state, action)

        return action, explanation

    def _explain_action(self, state, action):
        """Generate causal explanation for why this action was chosen"""

        # Compute expected causal effects
        effects = self.causal_model.predict_effects(state, action)

        # Find most influential factors
        influential = sorted(
            effects.items(),
            key=lambda x: abs(x[1]),
            reverse=True
        )[:3]

        explanation = "Action selected to optimize:\n"
        for factor, effect in influential:
            if effect > 0:
                explanation += f"  • Increase {factor} by {effect:.2f}\n"
            else:
                explanation += f"  • Decrease {factor} by {abs(effect):.2f}\n"

        # Add counterfactual comparison
        alternative = self._best_alternative_action(state)
        comparison = self._compare_actions(state, action, alternative)
        explanation += f"\nCompared to alternative: {comparison}"

        return explanation

Inverse Verification in Production

One particularly valuable insight from my hands-on implementation was the importance of continuous verification. I developed a system that constantly compares the causal model's predictions against actual outcomes and updates the model when discrepancies are detected:

import time
from datetime import datetime

import numpy as np

class ContinuousCausalVerifier:
    """Continuously verify and update causal models"""

    def __init__(self, production_system, causal_model,
                 update_threshold=0.2, check_interval=3600):
        self.production_system = production_system
        self.causal_model = causal_model
        self.discrepancy_log = []
        self.update_threshold = update_threshold  # discrepancy that forces relearning
        self.check_interval = check_interval      # seconds between verification passes

    def monitor_and_update(self):
        """Continuous monitoring loop"""
        while True:
            # Collect recent production data
            recent_data = self.production_system.get_recent_data(hours=24)

            # Test causal claims
            for claim in self.causal_model.get_testable_claims():
                verified, discrepancy, explanation = \
                    self.verify_claim(claim, recent_data)

                if not verified:
                    self.discrepancy_log.append({
                        'claim': claim,
                        'discrepancy': discrepancy,
                        'explanation': explanation,
                        'timestamp': datetime.now()
                    })

                    # Trigger model update if discrepancy is large
                    if discrepancy > self.update_threshold:
                        self._update_causal_model(claim, recent_data)

            # Generate verification report
            self._generate_verification_report()

            time.sleep(self.check_interval)

    def verify_claim(self, claim, data):
        """Verify a specific causal claim against data"""

        # Extract cause, effect, and expected relationship
        cause, effect, expected = claim

        # Perform independent quasi-experimental tests
        test_results = {
            'difference_in_differences': self._did_test(cause, effect, data),
            'regression_discontinuity': self._rd_test(cause, effect, data),
            'instrumental_variables': self._iv_test(cause, effect, data)
        }

        # Consensus verification: at least two methods must agree
        verified = sum(test_results.values()) >= 2

        # Overall confidence is the mean agreement across tests
        confidence = np.mean(list(test_results.values()))

        return verified, 1 - confidence, test_results
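The `_did_test` helper above is left abstract, but the difference-in-differences idea behind it is small enough to sketch in full. The recovery-rate numbers are hypothetical, chosen only to show the arithmetic.

```python
def diff_in_diff(treated_before, treated_after, control_before, control_after):
    """DiD estimate: (change in treated group) minus (change in control group).
    The control group's change absorbs shared trends, isolating the effect
    attributable to the intervention."""
    return (treated_after - treated_before) - (control_after - control_before)

# Recovery rate (%) before/after introducing a new disassembly policy
# on one production line, with a comparable line as control:
effect = diff_in_diff(
    treated_before=62.0, treated_after=71.0,  # line with the new policy
    control_before=61.0, control_after=64.0,  # comparable line without it
)
# (71 - 62) - (64 - 61) = 9 - 3 = 6 percentage points attributable to the policy
```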

Challenges and Solutions: Lessons from Implementation

Challenge 1: Non-Stationarity in Circular Systems

During my experimentation, I discovered that circular supply chains are inherently non-stationary—the relationships between variables change over time as materials degrade, technologies evolve, and markets shift. Traditional RL algorithms assume stationarity and fail catastrophically in these environments.

Solution: I developed an adaptive causal discovery mechanism that continuously updates the causal graph:

class AdaptiveCausalModel:
    """Causal model that adapts to non-stationarity"""

    def __init__(self, initial_graph, discoverer, adaptation_rate=0.1):
        self.current_graph = initial_graph
        self.discoverer = discoverer  # e.g. a HybridCausalDiscoverer instance
        self.adaptation_rate = adaptation_rate
        self.change_detector = ChangePointDetector()

    def update_based_on_evidence(self, new_data):
        """Update causal model if change is detected"""

        # Detect change points in variable relationships
        change_points = self.change_detector.detect(new_data)

        if change_points:
            # Re-learn causal structure from recent data
            recent_data = self._get_data_since_last_change()
            new_structure = self.discoverer.discover_from_timeseries(
                recent_data
            )

            # Blend old and new structures
            self.current_graph = self._blend_graphs(
                self.current_graph, new_structure,
                alpha=self.adaptation_rate
            )

            # Log the structural change
            self._log_structural_change(
                self.current_graph, new_structure
            )
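The `ChangePointDetector` used above is left abstract; a minimal CUSUM-style sketch (one possible implementation, with illustrative thresholds) accumulates deviations from a reference mean and flags a drift in either direction:

```python
class CusumDetector:
    """Two-sided CUSUM monitor: raises an alarm once cumulative deviation
    from the reference mean exceeds the threshold."""

    def __init__(self, reference_mean, drift=0.0, threshold=5.0):
        self.mu = reference_mean
        self.drift = drift            # slack absorbing small fluctuations
        self.threshold = threshold
        self.pos = self.neg = 0.0     # cumulative upward/downward deviations

    def update(self, x):
        """Feed one observation; return True when an alarm fires."""
        self.pos = max(0.0, self.pos + (x - self.mu - self.drift))
        self.neg = max(0.0, self.neg + (self.mu - x - self.drift))
        return self.pos > self.threshold or self.neg > self.threshold

det1 = CusumDetector(reference_mean=1.0, threshold=3.0)
stable = [det1.update(v) for v in [1.0, 1.1, 0.9, 1.0]]  # fluctuation, no drift

det2 = CusumDetector(reference_mean=1.0, threshold=3.0)
shifted = [det2.update(v) for v in [2.5, 2.5, 2.5]]      # sustained mean shift
```

In the adaptive model's loop, an alarm from a detector like this is what triggers re-learning the causal structure from post-change data.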

Challenge 2: Partial Observability

Circular supply chains often have unobserved variables—material quality degradation, hidden inventory, informal recovery channels. Through my research into causal inference with latent variables, I found that we could use
