DEV Community

galp76
galp76

Posted on

De Keras a Bare-Metal C++: Construyendo un motor de inferencia dentro de un Arduino Uno (Parte 3)

En la Parte 2 de esta serie, logramos entrenar exitosamente una red neuronal recurrente (LSTM) en Python, capaz de predecir la falla de una turbina aeroespacial con un error de apenas 10 vuelos.

El modelo era un éxito en la nube. Pero en el mundo del mantenimiento industrial y el Edge Computing, no siempre tenemos servidores disponibles. Quería llevar a MAJN (mi red neuronal) al mundo físico. Así que tomé una decisión de hardware deliberadamente extrema para forzar mis habilidades en C++: desplegaría el modelo dentro de un humilde Arduino Uno.

Aquí chocamos de frente con la física del hardware: nuestro modelo en Keras tenía 14,665 parámetros en formato de punto flotante de 32 bits (float32), eso equivale a unos 60 Kilobytes. El cerebro del Arduino Uno (el microcontrolador ATmega328P) tiene 32 KB de memoria Flash (disco duro) y apenas 2 KB de memoria RAM.

En este artículo, documentaré paso a paso cómo logré meter esta red neuronal en el microcontrolador desarrollando un motor de inferencia LSTM en C++ puro desde cero.

La "Dieta" Matemática: Cuantización Lineal a 8-bits (Quantization)

En Python, nuestro modelo operaba con números decimales de 32 bits (float32), pero el microcontrolador del Arduino Uno carece de una Unidad de Punto Flotante (FPU) por hardware. Obligarlo a procesar decimales mediante emulación por software destruiría su rendimiento y saturaría sus 2KB de RAM al instante. La solución técnica fue aplicar Cuantización (Quantization): una transformación matemática donde escalamos los más de 14,000 pesos decimales para convertirlos en números enteros de un solo byte (int8_t). Esto reduciría el peso de la red a 14.6 KB, perfecto para la Flash de 32 KB.

Aquí está el script en Python que extrae los pesos de Keras, busca el valor absoluto máximo de toda la red, y calcula un Factor de Escala:

# Extracción y Cuantización en Python
pesos_keras = modelo.get_weights()
todos_los_pesos = np.concatenate([p.flatten() for p in pesos_keras])
max_abs = np.max(np.abs(todos_los_pesos))

# Buscamos encajar el peso más grande en el límite de un int8_t (127)
factor_escala = 126.0 / max_abs

# Multiplicamos la matriz, redondeamos y convertimos a entero de 1 byte
matriz_cuantizada = np.round(matriz_peso * factor_escala).astype(np.int8)
Enter fullscreen mode Exit fullscreen mode

Con esto ya tenemos las proporciones matemáticas exactas, pero comprimidas en números enteros minúsculos.

Siguiente reto: el muro de la Arquitectura Harvard (PROGMEM)

Las computadoras modernas usan la arquitectura Von Neumann (datos y programas comparten la RAM). El ATmega328P usa Arquitectura Harvard (la Flash y la RAM están separadas físicamente).

Si en C++ estándar declaras un arreglo global enorme (int8_t pesos[14665]), el compilador lo guarda en la Flash, pero al encender la placa, intenta copiarlo todo a la RAM para trabajar. Con 2 KB de RAM, esto provoca un Stack Overflow instantáneo y el Arduino muere.

La solución en Bare-Metal AVR C++ es usar la directiva PROGMEM. Esto "ancla" el arreglo a la Flash de forma permanente.
Pero esto genera un nuevo problema: en C++, un puntero normal no puede leer la memoria Flash. Para extraer nuestros pesos en tiempo de ejecución, tuvimos que usar aritmética de punteros y llamadas a bajo nivel con pgm_read_byte_near():

// El arreglo vive exclusivamente en los 32KB de Flash
const int8_t matriz_pesos_0[2800] PROGMEM = {12, -45, 89, ...};

// Para leerlo usando un índice (offset)
int8_t peso = (int8_t)pgm_read_byte_near(matriz_pesos_0 + indice);
Enter fullscreen mode Exit fullscreen mode

Álgebra lineal sin decimales: evitando el desbordamiento

Nuestro Arduino no tiene Unidad de Punto Flotante (FPU) ni librerías como NumPy. Tuvimos que escribir el producto punto (multiplicación de matrices) usando bucles for anidados.

Aquí el peligro era matemático: Si multiplicas un sensor cuantizado (ej. 100) por un peso cuantizado (ej. 50), el resultado es 5000. ¡Eso no cabe en los 8 bits de un int8_t! La memoria "daría la vuelta" arrojando basura negativa.
Para evitarlo, la acumulación en el código de C++ forzosamente se hizo usando un tipo de dato más grande (int32_t):

int32_t acumulador = 0;
// Multiplicación en 32 bits para evitar overflow de memoria
acumulador += (int32_t)sensor * (int32_t)peso;
Enter fullscreen mode Exit fullscreen mode

Aproximando las funciones de activación (Hard Tanh & Hard Sigmoid)

Una celda LSTM requiere funciones trigonométricas continuas (Sigmoide y Tangente Hiperbólica). Ejecutar la función exp() en un microcontrolador de 8 bits destruye el rendimiento por los ciclos de reloj que requiere.

La solución en el mundo del TinyML es usar "Hard Activations" (Aproximaciones lineales). Reemplazamos la compleja curva sigmoide por funciones de clipping (topes) usando solo sumas y divisiones sobre nuestra variable ESCALA constante:

int32_t hard_sigmoid_8bit(int32_t x) {
  // Aproximación de la Sigmoide adaptada a números enteros escalados
  int32_t sig = (x / 2) + (ESCALA / 2);
  if (sig > ESCALA) return ESCALA;
  if (sig < 0) return 0;
  return sig;
}
Enter fullscreen mode Exit fullscreen mode

El Motor LSTM en C++ y la Integración Final

Con todas las piezas listas, programamos el bucle del LSTM. Creé dos arreglos en la memoria RAM (h_estado y c_estado) para conservar el contexto temporal entre ciclos, y programé las cuatro compuertas (Forget, Input, Cell, Output).

Para que no quede ninguna duda de cómo se ve una Red Neuronal Recurrente escrita "a bajo nivel", aquí está el código completo que corrió dentro del Arduino Uno.

En este bloque pueden observar las aproximaciones a las funciones de activación (hard_tanh y hard_sigmoid), los arreglos globales para mantener la memoria temporal de los ciclos de la turbina (h_estado y c_estado), y el cálculo algebraico de las cuatro compuertas matemáticas del LSTM leyendo los pesos directamente desde la memoria Flash con PROGMEM:

#include "majn_weights.h"

const int PIN_ALARMA = 13;
const int32_t ESCALA = (int32_t)FACTOR_ESCALA;

// --- MEMORIA DEL LSTM (El "Estado" que viaja en el tiempo) ---
int32_t h_estado[50]; // Hidden State (Estado Oculto)
int32_t c_estado[50]; // Cell State (Memoria a largo plazo)

// --- FUNCIONES DE ACTIVACIÓN CUANTIZADAS ---
int32_t hard_tanh_8bit(int32_t x) {
  if (x > ESCALA) return ESCALA;
  if (x < -ESCALA) return -ESCALA;
  return x;
}

int32_t hard_sigmoid_8bit(int32_t x) {
  int32_t sig = (x / 2) + (ESCALA / 2);
  if (sig > ESCALA) return ESCALA;
  if (sig < 0) return 0;
  return sig;
}

// Función que limpia la "mente" de la red para un motor nuevo
void resetear_memoria_turbina() {
  for(int i = 0; i < 50; i++) {
    h_estado[i] = 0;
    c_estado[i] = 0;
  }
}

void setup() {
  Serial.begin(9600);
  pinMode(PIN_ALARMA, OUTPUT);
  digitalWrite(PIN_ALARMA, LOW);

  resetear_memoria_turbina();
  while(Serial.available()) Serial.read(); 
}

void loop() {
  if (Serial.available() > 0) {
    // Protocolo de comunicación con el "Gemelo Digital" en Python
    char comando = Serial.read();

    if (comando == 'R') {
      // Python pide resetear porque empezó un motor nuevo
      resetear_memoria_turbina();
      digitalWrite(PIN_ALARMA, LOW);
      Serial.println("RESET_OK");
    } 
    else if (comando == 'D') {
      // Python envía datos: 14 sensores cuantizados a 8-bits
      while(Serial.available() < 14) { /* Esperamos recepción */ }

      int8_t sensores[14];
      Serial.readBytes((char*)sensores, 14);

      // =======================================================
      // EL CORAZÓN DEL LSTM (Matemática Bare-Metal)
      // =======================================================
      int32_t h_nuevo[50]; 
      int32_t c_nuevo[50];

      for(int u = 0; u < 50; u++) { 
        // 1. Extraemos los Bias (Sesgos) de la Flash
        int32_t comp_i = (int8_t)pgm_read_byte_near(matriz_pesos_2 + u);       
        int32_t comp_f = (int8_t)pgm_read_byte_near(matriz_pesos_2 + 50 + u);  
        int32_t comp_c = (int8_t)pgm_read_byte_near(matriz_pesos_2 + 100 + u); 
        int32_t comp_o = (int8_t)pgm_read_byte_near(matriz_pesos_2 + 150 + u); 

        // 2. Multiplicamos Entradas x Pesos (W * x)
        int32_t sum_in_i = 0, sum_in_f = 0, sum_in_c = 0, sum_in_o = 0;
        for(int x = 0; x < 14; x++) {
          int32_t entrada = sensores[x];
          sum_in_i += entrada * (int8_t)pgm_read_byte_near(matriz_pesos_0 + (x * 200) + u);
          sum_in_f += entrada * (int8_t)pgm_read_byte_near(matriz_pesos_0 + (x * 200) + 50 + u);
          sum_in_c += entrada * (int8_t)pgm_read_byte_near(matriz_pesos_0 + (x * 200) + 100 + u);
          sum_in_o += entrada * (int8_t)pgm_read_byte_near(matriz_pesos_0 + (x * 200) + 150 + u);
        }
        comp_i += sum_in_i / 64;
        comp_f += sum_in_f / 64;
        comp_c += sum_in_c / 64;
        comp_o += sum_in_o / 64;

        // 3. Multiplicamos Estado Anterior x Pesos Recurrentes (U * h)
        int32_t sum_h_i = 0, sum_h_f = 0, sum_h_c = 0, sum_h_o = 0;
        for(int hu = 0; hu < 50; hu++) {
          int32_t h_pasado = h_estado[hu];
          sum_h_i += h_pasado * (int8_t)pgm_read_byte_near(matriz_pesos_1 + (hu * 200) + u);
          sum_h_f += h_pasado * (int8_t)pgm_read_byte_near(matriz_pesos_1 + (hu * 200) + 50 + u);
          sum_h_c += h_pasado * (int8_t)pgm_read_byte_near(matriz_pesos_1 + (hu * 200) + 100 + u);
          sum_h_o += h_pasado * (int8_t)pgm_read_byte_near(matriz_pesos_1 + (hu * 200) + 150 + u);
        }
        comp_i += sum_h_i / ESCALA;
        comp_f += sum_h_f / ESCALA;
        comp_c += sum_h_c / ESCALA;
        comp_o += sum_h_o / ESCALA;

        // 4. Pasamos por las Funciones de Activación
        int32_t act_i = hard_sigmoid_8bit(comp_i);
        int32_t act_f = hard_sigmoid_8bit(comp_f);
        int32_t act_c = hard_tanh_8bit(comp_c);
        int32_t act_o = hard_sigmoid_8bit(comp_o);

        // 5. Ecuaciones finales de actualización de la celda LSTM
        c_nuevo[u] = (act_f * c_estado[u]) / ESCALA + (act_i * act_c) / ESCALA;
        h_nuevo[u] = (act_o * hard_tanh_8bit(c_nuevo[u])) / ESCALA;
      }

      // Guardamos la memoria para el siguiente ciclo
      for(int i = 0; i < 50; i++) {
        h_estado[i] = h_nuevo[i];
        c_estado[i] = c_nuevo[i];
      }

      // =======================================================
      // CAPAS DENSAS Y RUL PREDICTIVO
      // =======================================================
      int32_t capa_densa[32];
      for(int d = 0; d < 32; d++) {
        capa_densa[d] = (int8_t)pgm_read_byte_near(matriz_pesos_4 + d);
        int32_t sum_densa = 0;
        for(int u = 0; u < 50; u++) {
          sum_densa += h_estado[u] * (int8_t)pgm_read_byte_near(matriz_pesos_3 + (u * 32) + d);
        }
        capa_densa[d] += sum_densa / ESCALA;
        if (capa_densa[d] < 0) capa_densa[d] = 0; // Activación ReLU
      }

      int32_t salida_final = (int8_t)pgm_read_byte_near(matriz_pesos_6 + 0);
      int32_t sum_salida = 0;
      for(int d = 0; d < 32; d++) {
        sum_salida += capa_densa[d] * (int8_t)pgm_read_byte_near(matriz_pesos_5 + d);
      }
      salida_final += sum_salida / ESCALA;

      // Descuantización para obtener ciclos reales
      float rul_predicho = (float)salida_final / FACTOR_ESCALA;
      rul_predicho = abs(rul_predicho); 
      if (rul_predicho > 125) rul_predicho = 125; 

      // Lógica de Alarma Física Industrial
      if (rul_predicho < 15.0) {
         digitalWrite(PIN_ALARMA, HIGH); // Peligro Inminente
         Serial.print("CRITICO,");
      } else {
         digitalWrite(PIN_ALARMA, LOW);
         Serial.print("OK,");
      }
      Serial.println(rul_predicho);
    }
  }
}
Enter fullscreen mode Exit fullscreen mode

Para que este motor de inferencia en el Arduino pudiera "ver" los datos de la turbina, necesitaba un programa en la computadora que actuara como un Gemelo Digital.

Escribí este script en Python que utiliza la librería pyserial. Su trabajo es cargar los datos históricos de la NASA, normalizarlos usando nuestro escalador guardado en la Parte 1, cuantizarlos a números enteros multiplicándolos por 64, y enviarlos byte a byte a través del cable USB.

El código inyecta vuelo tras vuelo, y se queda esperando a que el Arduino procese la matemática y devuelva el diagnóstico:

import serial
import pandas as pd
import joblib
import time

# --- CONFIGURACIÓN ---
PUERTO = '/dev/ttyACM0'  # Reemplazar por COM3 en Windows
BAUDIOS = 9600

print("1. Cargando datos de la NASA y el Escalador...")
# (Aquí omitimos el código de carga de Pandas para mantenerlo conciso)
df = cargar_y_limpiar_datos('train_FD001.txt')
scaler = joblib.load('escalador_sensores_turbina.pkl')
df[columnas_sensores] = scaler.transform(df[columnas_sensores])

# Seleccionamos el Motor 1 para la prueba
motor_1 = df[df['id_motor'] == 1]

print("2. Iniciando enlace Serial con MAJN (Arduino)...")
conexion = serial.Serial(PUERTO, BAUDIOS, timeout=2)
time.sleep(2) # Pausa obligatoria para que el Arduino despierte

# Le indicamos al Arduino que limpie su memoria RAM (Nuevo Motor)
conexion.write(b'R')
conexion.readline() # Esperamos el "RESET_OK"

print("3. Inyectando datos de vuelo en tiempo real...")

for index, row in motor_1.iterrows():
    ciclo = int(row['ciclo'])
    sensores_float = row[columnas_sensores].values

    # --- LA CUANTIZACIÓN (El puente entre Python y C++) ---
    # Multiplicamos los floats (0.0 a 1.0) por 64 para enviarlos como enteros
    sensores_int = [int(val * 64) for val in sensores_float]

    # Empaquetamos: La letra 'D' (Data) seguida de los 14 bytes de los sensores
    trama_bytes = b'D' + bytes([max(min(s, 127), -128) & 0xFF for s in sensores_int])

    # Inyectamos los datos por el cable USB
    conexion.write(trama_bytes)

    # Leemos la predicción que calculó el Arduino
    respuesta = conexion.readline().decode('utf-8').strip()

    if respuesta:
        estado, rul = respuesta.split(',')

        if estado == "OK":
            print(f"✈️ Vuelo {ciclo:03d} | Estado: OK | Vida útil: {float(rul):05.2f} ciclos")
        else:
            print(f"🔥 Vuelo {ciclo:03d} | Estado: CRÍTICO | Vida útil: {float(rul):05.2f} ciclos")
            print(">>> ALERTA ROJA: EL LED DEL ARDUINO SE ENCENDIÓ. PARADA DE PLANTA <<<")
            break # Detenemos la simulación para salvar el motor

    time.sleep(0.1) # Simulamos el tiempo entre un vuelo y otro

conexion.close()
Enter fullscreen mode Exit fullscreen mode

La prueba de estrés industrial (100 motores):

Hice pasar los 20,631 vuelos de los 100 motores del Test Set por el cable USB a máxima velocidad. Aquí los resultados del Arduino:
0% de Falsos Negativos: El sistema detectó la falla inminente en los 100 motores sin que se le escapara uno solo.
36.3 vuelos de anticipación: En promedio, el Arduino encendió la alarma 36 vuelos antes de la muerte real del motor. El margen logístico perfecto para programar una parada de planta y solicitar repuestos sin caer en urgencias.

Conclusión

Este proyecto, que comenzó como una curiosidad con Google Gemini, me llevó desde la limpieza de datos en Pandas hasta el Bare-Metal en C++. Me demostró que para hacer Mantenimiento Predictivo; con ingeniería, matemáticas y entendiendo los límites del hardware, podemos llevar la Inteligencia Artificial a la base misma de la industria.

Top comments (0)