DEV Community

Cover image for Speculative Decoding: Cómo Acelerar LLMs 2.4x Sin Cambiar el Modelo
Abdessamad Ammi
Abdessamad Ammi

Posted on • Originally published at bcloud.consulting

Speculative Decoding: Cómo Acelerar LLMs 2.4x Sin Cambiar el Modelo

Publicado originalmente en bcloud.consulting

TL;DR

• Speculative Decoding acelera inferencia LLM 2-4x
• Sin cambiar modelo, sin reentrenamiento, calidad idéntica
• Usa modelo pequeño para draft, grande para verificación
• Caso real: 312ms → 130ms por token (58% reducción)
• Implementación en 2 días con frameworks existentes


El Cuello de Botella de Inferencia LLM

Generar texto con LLMs grandes es inherentemente lento debido a la naturaleza autoregresiva:

# Inferencia tradicional - LENTA
def traditional_inference(model, prompt, max_tokens=100):
    tokens = tokenize(prompt)

    for _ in range(max_tokens):
        # Cada token requiere forward pass completo
        logits = model.forward(tokens)  # 300ms for 70B model
        next_token = sample(logits)
        tokens.append(next_token)

        if next_token == EOS_TOKEN:
            break

    return detokenize(tokens)

# Tiempo total = num_tokens * latency_per_token
# 20 tokens * 300ms = 6 seconds!
Enter fullscreen mode Exit fullscreen mode

Cómo Funciona Speculative Decoding

La idea brillante: usar un modelo pequeño rápido para "adivinar" los siguientes tokens, luego verificar en batch con el modelo grande.

class SpeculativeDecoder:
    def __init__(self, draft_model, target_model, gamma=5):
        self.draft = draft_model    # Modelo pequeño (7B)
        self.target = target_model  # Modelo grande (70B)
        self.gamma = gamma          # Tokens a especular

    def generate(self, prompt, max_tokens=100):
        tokens = tokenize(prompt)

        while len(tokens) < max_tokens:
            # 1. Draft model genera γ tokens rápidamente
            draft_tokens = self.speculate(tokens)

            # 2. Target model verifica todos en UN solo pass
            verified_tokens = self.verify(tokens, draft_tokens)

            # 3. Añadir tokens verificados
            tokens.extend(verified_tokens)

            if tokens[-1] == EOS_TOKEN:
                break

        return detokenize(tokens)

    def speculate(self, context):
        """Modelo draft genera γ tokens especulativos"""
        spec_tokens = []
        current = context.copy()

        for _ in range(self.gamma):
            # Draft es 10x más rápido (30ms vs 300ms)
            logits = self.draft.forward(current)
            next_token = sample(logits)
            spec_tokens.append(next_token)
            current.append(next_token)

        return spec_tokens

    def verify(self, context, draft_tokens):
        """Target verifica todos los tokens en paralelo"""
        # Preparar secuencias para verificación paralela
        sequences = []
        for i in range(len(draft_tokens) + 1):
            seq = context + draft_tokens[:i]
            sequences.append(seq)

        # UN SOLO forward pass para todas las secuencias
        all_logits = self.target.forward_batch(sequences)

        # Verificar cada token especulativo
        accepted_tokens = []
        for i, draft_token in enumerate(draft_tokens):
            target_logits = all_logits[i]
            target_distribution = softmax(target_logits)

            # Aceptar si draft es suficientemente probable
            if self.accept_token(draft_token, target_distribution):
                accepted_tokens.append(draft_token)
            else:
                # Rechazar y samplear del target
                correct_token = sample(target_logits)
                accepted_tokens.append(correct_token)
                break  # Stop at first rejection

        return accepted_tokens
Enter fullscreen mode Exit fullscreen mode

Implementación Práctica con vLLM

# vLLM tiene soporte nativo para speculative decoding
from vllm import LLM, SamplingParams

# Configurar speculative decoding
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",      # Target model
    speculative_model="meta-llama/Llama-2-7b-hf",  # Draft model
    num_speculative_tokens=5,                # γ parameter
    use_v2_block_manager=True
)

# Usar como normal - speedup automático
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=100
)

# Genera con speculative decoding transparentemente
outputs = llm.generate(
    ["Explain quantum computing"],
    sampling_params
)

# Métricas de performance
print(f"Tokens/second: {outputs[0].metrics.tokens_per_second}")
print(f"Acceptance rate: {outputs[0].metrics.acceptance_rate}")
Enter fullscreen mode Exit fullscreen mode

Optimización de Parámetros

def optimize_gamma(draft_model, target_model, test_prompts):
    """Encuentra γ óptimo para tu setup"""
    results = {}

    for gamma in range(3, 10):
        decoder = SpeculativeDecoder(
            draft_model,
            target_model,
            gamma=gamma
        )

        total_time = 0
        total_tokens = 0
        acceptance_rates = []

        for prompt in test_prompts:
            start = time.time()
            output = decoder.generate(prompt)
            elapsed = time.time() - start

            total_time += elapsed
            total_tokens += len(output)
            acceptance_rates.append(decoder.last_acceptance_rate)

        results[gamma] = {
            'tokens_per_second': total_tokens / total_time,
            'avg_acceptance_rate': np.mean(acceptance_rates)
        }

    # Mejor γ maximiza tokens/second
    optimal_gamma = max(results, key=lambda k: results[k]['tokens_per_second'])
    return optimal_gamma, results

# Típicamente γ=4-6 es óptimo
Enter fullscreen mode Exit fullscreen mode

Benchmarks Reales

# Setup de prueba
benchmark_config = {
    'target_model': 'Llama-70B',
    'draft_model': 'Llama-7B',
    'hardware': 'A100 80GB',
    'batch_size': 1,
    'sequence_length': 512,
    'num_samples': 1000
}

# Resultados medidos
results = {
    'traditional': {
        'latency_per_token_ms': 312,
        'tokens_per_second': 3.2,
        'gpu_utilization': 0.45
    },
    'speculative_gamma_5': {
        'latency_per_token_ms': 130,
        'tokens_per_second': 7.7,
        'gpu_utilization': 0.78,
        'acceptance_rate': 0.73,
        'speedup': 2.4
    }
}

# Análisis de acceptance rate por posición
acceptance_by_position = {
    1: 0.92,  # First token usually accepted
    2: 0.81,
    3: 0.71,
    4: 0.62,
    5: 0.51   # Fifth token 50/50
}
Enter fullscreen mode Exit fullscreen mode

Caso de Estudio: Chatbot Enterprise

Situación inicial:

  • Modelo: Llama 70B
  • Usuarios concurrentes: 500 máximo
  • Latencia P95: 8 segundos
  • Hardware: 8x A100 GPUs

Implementación Speculative Decoding:

# Configuración producción
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs

# Setup con speculative decoding
engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-2-70b-chat-hf",
    speculative_model="meta-llama/Llama-2-7b-chat-hf",
    num_speculative_tokens=5,
    tensor_parallel_size=8,
    max_num_seqs=256,
    max_model_len=2048,
    gpu_memory_utilization=0.90
)

# Servidor async para high throughput
async def serve_chatbot():
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def handle_request(prompt):
        sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=200,
            presence_penalty=0.1
        )

        async for output in engine.generate(prompt, sampling_params):
            yield output.text
Enter fullscreen mode Exit fullscreen mode

Resultados:

  • Usuarios concurrentes: 1200 (+140%)
  • Latencia P95: 3.3 segundos (-59%)
  • Throughput: 2.4x mejora
  • Hardware: Mismo (8x A100)
  • ROI: Evitó compra de 12 GPUs adicionales

Cuándo NO Usar Speculative Decoding

def should_use_speculative_decoding(scenario):
    # No beneficia estos casos:
    if scenario['model_size'] < 10_000_000_000:  # <10B params
        return False, "Model already fast enough"

    if scenario['task'] == 'single_token_classification':
        return False, "No sequential generation"

    if scenario['draft_target_ratio'] > 0.3:
        return False, "Draft too large relative to target"

    if scenario['batch_size'] > 32:
        return False, "Better use continuous batching"

    # Ideal para:
    if scenario['model_size'] > 30_000_000_000:  # >30B params
        if scenario['latency_sensitive']:
            return True, "Perfect use case"

    return True, "Will likely benefit"
Enter fullscreen mode Exit fullscreen mode

Troubleshooting Común

class SpeculativeDecodingDebugger:
    def diagnose(self, metrics):
        issues = []

        if metrics['acceptance_rate'] < 0.5:
            issues.append({
                'problem': 'Low acceptance rate',
                'solutions': [
                    'Use larger draft model',
                    'Reduce gamma',
                    'Fine-tune draft on target distribution'
                ]
            })

        if metrics['speedup'] < 1.5:
            issues.append({
                'problem': 'Low speedup',
                'solutions': [
                    'Check draft model speed',
                    'Optimize batch processing',
                    'Profile GPU utilization'
                ]
            })

        if metrics['memory_usage'] > 0.95:
            issues.append({
                'problem': 'High memory usage',
                'solutions': [
                    'Use quantized draft model',
                    'Reduce batch size',
                    'Enable memory efficient attention'
                ]
            })

        return issues
Enter fullscreen mode Exit fullscreen mode

Conclusiones

→ Speculative Decoding es plug-and-play para LLMs grandes
→ 2-4x speedup sin cambio en calidad
→ ROI inmediato evitando compra de hardware
→ Crítico para latency-sensitive applications
→ Soporte nativo en frameworks modernos (vLLM, TGI)


Artículo Completo

Este es un resumen. Para implementación completa:

👉 Lee el artículo completo

Incluye:

  • Código producción completo
  • Benchmarks detallados
  • Configuración óptima por modelo
  • Integración con serving frameworks

¿Has probado Speculative Decoding? Comparte resultados 👇

Top comments (0)