DEV Community

Prashant Nigam


The Magic of LoRA Fine-Tuning with MLX (Part 4)

This is where the magic happens! In this part, we will dive deep into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model efficiently on Apple Silicon.

Understanding LoRA: The Game-Changing Technique

Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That's exactly what LoRA (Low-Rank Adaptation) does for language models.

The Traditional Fine-Tuning Problem

Traditional fine-tuning updates all 1.7 billion parameters of our model. This means:

  • ❌ Massive memory requirements
  • ❌ Slow training
  • ❌ Risk of "catastrophic forgetting" (losing general knowledge)
  • ❌ Large model files
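To put "massive memory requirements" into rough numbers, here is a back-of-the-envelope estimate for full fine-tuning. It is a sketch under stated assumptions: 16-bit weights and 16-bit gradients (2 bytes each per parameter) and an Adam-style optimizer that keeps two fp32 moment buffers per trainable parameter (8 bytes). Activations, framework overhead, and any fp32 master copy of the weights are ignored, so treat the result as a lower bound. The function name is hypothetical.

```python
def full_finetune_memory_gb(num_params: int,
                            weight_bytes: int = 2,
                            grad_bytes: int = 2,
                            optimizer_bytes: int = 8) -> float:
    """Rough lower bound on memory (GB) for full fine-tuning.

    Assumes every parameter is trainable, so each one needs its weight,
    a gradient, and optimizer state. Activations are ignored.
    """
    bytes_per_param = weight_bytes + grad_bytes + optimizer_bytes
    return num_params * bytes_per_param / 1e9


# 1.7 billion parameters, as in the model used in this series
print(f"{full_finetune_memory_gb(1_700_000_000):.1f} GB")  # 20.4 GB
```

With LoRA, gradients and optimizer state are only needed for the small adapter matrices, so the last two terms shrink to almost nothing.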

The LoRA Solution

LoRA keeps the original model frozen and adds small, trainable low-rank "adapter" matrices alongside selected weight matrices. Only these adapters learn the new behavior:

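  • ✅ Much lower memory usage: gradients and optimizer state are needed only for the adapters
  • ✅ Faster training
  • ✅ Less risk of catastrophic forgetting, since the base weights stay frozen
  • ✅ Small adapter files (megabytes rather than gigabytes)
  • ✅ Adapters can be combined or switched out easily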
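A quick way to see why adapters are so small: for a weight matrix of shape d_out × d_in, LoRA with rank r trains r × (d_in + d_out) parameters instead of d_out × d_in. The sketch below assumes a hidden size of 2048 and LoRA applied to two square projections per layer (for example the attention query and value projections) in 16 layers at rank 16. These dimensions are illustrative assumptions, not values read from the model's config, and the helper name is hypothetical.

```python
def lora_params_per_matrix(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix.

    A has shape (rank, d_in) and B has shape (d_out, rank).
    """
    return rank * d_in + d_out * rank


# Illustrative assumptions: hidden size 2048, two square projections
# per layer, LoRA on 16 layers with rank 16
hidden = 2048
per_matrix = lora_params_per_matrix(hidden, hidden, rank=16)
total_lora = per_matrix * 2 * 16
full_matrix = hidden * hidden

print(per_matrix, full_matrix)  # 65536 4194304
print(total_lora)               # 2097152
print(f"{total_lora / 1.7e9:.3%} of 1.7B parameters")
```

Under these assumptions each adapted matrix trains about 1.6% of the parameters the full matrix has, and the whole adapter is roughly two million parameters.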

How LoRA Works Under the Hood

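Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off. Concretely, for a frozen weight matrix W of size d × k, LoRA learns two much smaller matrices, B (d × r) and A (r × k), where the rank r is far smaller than d and k. The adapted layer computes Wx + (α/r)·BAx, so only A and B are trained, and removing them restores the original model exactly.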
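In code, a LoRA-adapted layer computes W x + (alpha / r) * B (A x) while W stays frozen. The sketch below shows that forward pass in plain Python with lists instead of arrays; it illustrates the technique and is not MLX's implementation. The helper names (matvec, lora_forward) are hypothetical. Following the usual LoRA initialization, B starts at zero, so before any training the adapted layer returns exactly the base layer's output.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector,
                 alpha: float, rank: int) -> Vector:
    """Frozen base output W @ x plus the scaled low-rank update B @ (A @ x)."""
    base = matvec(W, x)               # original, frozen weights
    update = matvec(B, matvec(A, x))  # low-rank path: k -> r -> d
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


# Toy sizes: d = k = 3, rank r = 1
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0]]
A = [[0.5, 0.5, 0.5]]           # shape (r, k)
B_init = [[0.0], [0.0], [0.0]]  # shape (d, r), zero at initialization
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_init, x, alpha=2.0, rank=1))     # [1.0, 2.0, 3.0]

B_trained = [[1.0], [0.0], [0.0]]  # pretend training changed B
print(lora_forward(W, A, B_trained, x, alpha=2.0, rank=1))  # [7.0, 2.0, 3.0]
```

Because the update is a separate additive term, "snapping the attachment off" is just dropping A and B; W was never modified.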

MLX: Apple's Secret Weapon for AI

MLX is Apple's open-source array framework for machine learning, designed specifically for Apple Silicon. It is what makes efficient fine-tuning practical on our local machine.

Why MLX Is Good for Local AI

  1. Unified Memory Architecture: M-series chips share one pool of memory between CPU and GPU, so arrays don't have to be copied back and forth between them
  2. Optimized Computation: GPU kernels built on Metal and tuned for Apple Silicon
  3. Memory Efficiency: Lazy evaluation means arrays are only materialized when their values are needed, which helps keep memory use down
  4. Python Integration: A familiar NumPy-like Python API that is easy to use while staying fast

Setting Up Our Fine-Tuning Pipeline

Let's build our fine-tuning system step by step, explaining each component as we go.

Step 1: Configuration and Setup

First, let's create a comprehensive configuration system:

touch fine_tuning_config.py

# Create fine_tuning_config.py
import os
from pathlib import Path
import mlx.core as mx

class FineTuningConfig:
    """Centralized configuration for fine-tuning"""

    def __init__(self):
        # Model configuration
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # Data paths
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

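        # LoRA parameters
        self.lora_layers = 16  # Number of transformer layers to add LoRA to (mlx_lm's --num-layers)
        self.lora_rank = 16    # The 'r' in LoRA: higher = more capacity but slower
        self.lora_alpha = 32   # Scaling factor for LoRA adapters
        # Note: mlx_lm reads rank and scale from a YAML config file (-c), not from CLI flags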

        # Training parameters
        self.batch_size = 2           # Batch size (reduce if out of memory)
        self.learning_rate = 5e-5     # Learning rate
        self.max_iters = 1000         # Maximum training iterations
        self.steps_per_report = 10    # How often to print progress
        self.steps_per_eval = 200     # How often to run validation
        self.save_every = 400         # How often to save checkpoints

        # Hardware optimization
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        # Create directories
        Path(self.adapter_path).mkdir(parents=True, exist_ok=True)

    def print_config(self):
        """Print current configuration"""
        print("🔧 Fine-tuning Configuration:")
        print(f"  Base model: {self.base_model}")
        print(f"  GPU available: {self.use_gpu}")
        print(f"  LoRA rank: {self.lora_rank}")
        print(f"  LoRA layers: {self.lora_layers}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Learning rate: {self.learning_rate}")
        print(f"  Max iterations: {self.max_iters}")
        print(f"  Adapter path: {self.adapter_path}")

# Create and test config
if __name__ == "__main__":
    config = FineTuningConfig()
    config.print_config()
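One caveat about the config above: lora_rank and lora_alpha are not command-line flags of mlx_lm's LoRA trainer. In recent mlx_lm releases, LoRA hyperparameters are read from a YAML file passed with -c (or --config), under a lora_parameters key with rank, scale, and dropout fields, following the example config that ships with mlx_lm. The hypothetical helper below writes such a file from our settings. It assumes mlx_lm multiplies the adapter output by scale directly, so it stores the LoRA factor alpha / rank as scale. Check both assumptions against the mlx_lm version you have installed.

```python
from pathlib import Path


def write_lora_yaml(path: str, rank: int, alpha: float,
                    dropout: float = 0.0) -> str:
    """Write a minimal mlx_lm-style YAML config holding LoRA parameters.

    Assumption: mlx_lm multiplies the adapter output by `scale` directly,
    so the LoRA factor alpha / rank is stored as `scale`.
    """
    scale = alpha / rank
    text = (
        "lora_parameters:\n"
        f"  rank: {rank}\n"
        f"  scale: {scale}\n"
        f"  dropout: {dropout}\n"
    )
    Path(path).write_text(text)
    return text


# Values from FineTuningConfig: rank 16, alpha 32 -> scale 2.0
print(write_lora_yaml("lora_config.yaml", rank=16, alpha=32))
```

Under these assumptions, adding "-c", "lora_config.yaml" to the training command would make the rank and alpha settings actually take effect.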

Step 2: Memory and Performance Monitoring

Before we start fine-tuning, let's create tools to monitor our system:

touch monitoring.py

# Create monitoring.py
import time
import mlx.core as mx
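from typing import Dict, Optional
import psutil

class PerformanceMonitor:
    """Monitor memory usage and training performance.

    Note: the mx.metal counters only report MLX memory allocated by this
    process, so they will not reflect training launched in a subprocess.
    """

    def __init__(self):
        self.start_time = time.time()
        self.metrics = []

    def log_memory_usage(self, step: int, loss: Optional[float] = None):
        """Log current memory and performance metrics"""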

        # GPU memory (if available)
        gpu_memory = {}
        if mx.metal.is_available():
            gpu_memory = {
                'active_mb': mx.metal.get_active_memory() / 1e6,
                'peak_mb': mx.metal.get_peak_memory() / 1e6
            }

        # System memory
        system_memory = psutil.virtual_memory()

        # Training metrics
        elapsed = time.time() - self.start_time

        metrics = {
            'step': step,
            'elapsed_seconds': elapsed,
            'loss': loss,
            'gpu_active_mb': gpu_memory.get('active_mb', 0),
            'gpu_peak_mb': gpu_memory.get('peak_mb', 0),
            'system_memory_percent': system_memory.percent,
            'system_memory_available_gb': system_memory.available / 1e9
        }

        self.metrics.append(metrics)

        if step % 50 == 0:  # Print every 50 steps
            self.print_status(metrics)

        return metrics

    def print_status(self, metrics: Dict):
        """Print current training status"""

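        # loss may be None when only memory is being logged
        loss = metrics['loss']
        loss_str = f"{loss:.4f}" if loss is not None else "n/a"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_str} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")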

    def get_training_summary(self):
        """Get summary of training run"""

        if not self.metrics:
            return {}

        peak_gpu = max(m['gpu_peak_mb'] for m in self.metrics)
        total_time = self.metrics[-1]['elapsed_seconds']
        final_loss = self.metrics[-1]['loss']

        return {
            'total_training_time': total_time,
            'peak_gpu_memory_mb': peak_gpu,
            'final_loss': final_loss,
            'steps_completed': self.metrics[-1]['step']  # last logged step, not number of log entries
        }
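One limitation worth knowing: the mx.metal memory counters only describe MLX allocations in the process that calls them. Because the training in this article runs in a separate mlx_lm subprocess, a PerformanceMonitor in the parent process would see very little of that memory. An alternative is to read the metrics mlx_lm prints itself. The sketch below parses one report line; the made-up sample line mirrors the format printed by recent mlx_lm versions, but the exact wording is an assumption, so the patterns are deliberately loose and missing fields come back as None. The function names are hypothetical.

```python
import re
from typing import Dict, Optional

# Assumed report format (made-up sample values), e.g.
# "Iter 10: Train loss 2.345, Learning Rate 5.000e-05, It/sec 1.234,
#  Tokens/sec 456.789, Trained Tokens 1234, Peak mem 5.678 GB"
_ITER = re.compile(r"Iter\s+(\d+)")
_TRAIN_LOSS = re.compile(r"Train loss\s+([0-9.]+)")
_VAL_LOSS = re.compile(r"Val loss\s+([0-9.]+)")
_PEAK_MEM = re.compile(r"Peak mem\s+([0-9.]+)\s*GB")


def _first_float(pattern: re.Pattern, line: str) -> Optional[float]:
    """Return the first captured number for pattern, or None if absent."""
    match = pattern.search(line)
    return float(match.group(1)) if match else None


def parse_report_line(line: str) -> Dict[str, Optional[float]]:
    """Extract step, losses and peak memory (GB) from one mlx_lm report line."""
    step = _ITER.search(line)
    return {
        "step": int(step.group(1)) if step else None,
        "train_loss": _first_float(_TRAIN_LOSS, line),
        "val_loss": _first_float(_VAL_LOSS, line),
        "peak_mem_gb": _first_float(_PEAK_MEM, line),
    }


sample = ("Iter 10: Train loss 2.345, Learning Rate 5.000e-05, It/sec 1.234, "
          "Tokens/sec 456.789, Trained Tokens 1234, Peak mem 5.678 GB")
print(parse_report_line(sample))
```

Each parsed dictionary could then be stored much like the entries PerformanceMonitor keeps in self.metrics.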

Step 3: The Fine-Tuning Engine

Now let's create our main fine-tuning script. It validates the data, then runs the MLX-LM LoRA trainer as a subprocess:

touch fine_tune_model.py

# Create fine_tune_model.py
import subprocess
import sys
import time
import json
import os
from pathlib import Path
from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor

class MLXFineTuner:
    """Fine-tune models using MLX with LoRA"""

    def __init__(self, config: FineTuningConfig):
        self.config = config
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Validate that training data exists and is properly formatted"""

        print("📊 Validating training data...")

        # mlx_lm expects both train.jsonl and valid.jsonl in the data directory
        for path in (self.config.train_data_path, self.config.valid_data_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Data file not found: {path}")

        # Count training examples (non-empty lines)
        train_count = 0
        with open(self.config.train_data_path, 'r') as f:
            for line in f:
                if line.strip():
                    train_count += 1

        print(f"✅ Found {train_count} training examples")

        # Validate format using the first line
        with open(self.config.train_data_path, 'r') as f:
            first_line = f.readline()
            try:
                example = json.loads(first_line)
            except json.JSONDecodeError as e:
                raise ValueError("Training data must be valid JSONL format") from e
            if 'text' not in example:
                raise ValueError("Training data must have 'text' field")
            print("✅ Data format validated")

        return train_count

    def build_training_command(self):
        """Build the MLX-LM training command"""

        # train.jsonl and valid.jsonl live in the same directory
        data_dir = str(Path(self.config.train_data_path).parent)

        cmd = [
            sys.executable, "-m", "mlx_lm", "lora",  # same interpreter/venv as this script
            "--model", self.config.base_model,
            "--train",
            "--data", data_dir,
            "--num-layers", str(self.config.lora_layers),  # --lora-layers in older mlx_lm releases
            "--batch-size", str(self.config.batch_size),
            "--iters", str(self.config.max_iters),
            "--learning-rate", str(self.config.learning_rate),
            "--steps-per-report", str(self.config.steps_per_report),
            "--steps-per-eval", str(self.config.steps_per_eval),
            "--adapter-path", self.config.adapter_path,
            "--save-every", str(self.config.save_every),
            "--max-seq-length", str(self.config.max_sequence_length),
        ]
        # LoRA rank/alpha are not CLI flags; mlx_lm reads them from a YAML
        # file passed with -c, so lora_rank/lora_alpha are not applied here

        return cmd

    def run_fine_tuning(self):
        """Execute the fine-tuning process"""

        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        # Validate everything is ready
        train_count = self.validate_data()
        self.config.print_config()

        # Build command
        cmd = self.build_training_command()
        print(f"\n📝 Command: {' '.join(cmd)}")

        # Start training
        start_time = time.time()

        print(f"\n🏃 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📚 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        try:
            # Stream mlx_lm's output live (stderr merged into stdout) so progress
            # is visible during training, and keep a copy for parsing afterwards
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                env={**os.environ, "PYTHONUNBUFFERED": "1"},  # avoid block-buffered output
            )
            output_lines = []
            for line in process.stdout:
                print(line, end="")
                output_lines.append(line)
            returncode = process.wait()
            output = "".join(output_lines)
            if returncode != 0:
                raise subprocess.CalledProcessError(returncode, cmd, output=output)

            training_time = time.time() - start_time

            print("\n" + "=" * 60)
            print("🎉 Fine-tuning completed successfully!")
            print(f"⏱️  Total training time: {training_time:.1f} seconds")
            print(f"💾 Adapters saved to: {self.config.adapter_path}")

            # Save training metadata
            metadata = {
                'model_name': self.config.base_model,
                'training_time_seconds': training_time,
                'training_examples': train_count,
                'lora_rank': self.config.lora_rank,
                'lora_layers': self.config.lora_layers,
                'batch_size': self.config.batch_size,
                'learning_rate': self.config.learning_rate,
                'max_iters': self.config.max_iters,
                'timestamp': time.time(),
                'command_used': ' '.join(cmd)
            }

            metadata_path = f"{self.config.adapter_path}/training_metadata.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            print(f"📊 Training metadata saved to: {metadata_path}")

            # Parse and display training output
            self.parse_training_output(output)

            return True, metadata

        except subprocess.CalledProcessError as e:
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {e.returncode}")
            print("See the training output above for details.")
            return False, None

    def parse_training_output(self, output: str):
        """Parse and display key information from training output"""

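        print("\n📈 Training Progress Summary:")
        print("-" * 40)

        lines = output.split('\n')

        # mlx_lm reports lines such as "Iter 10: Train loss 2.345, ..." and
        # "Iter 200: Val loss 2.123, ..."; match them case-insensitively
        for line in lines:
            lowered = line.lower()
            if 'train loss' in lowered or 'val loss' in lowered:
                print(f"  {line.strip()}")

        # The last "Train loss" line holds the final training loss
        for line in reversed(lines):
            if 'Train loss' in line:
                final_loss = line.split('Train loss')[-1].split(',')[0].strip()
                print(f"\n🎯 Final training loss: {final_loss}")
                break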

    def verify_training_output(self):
        """Verify that training produced the expected files"""

        print("\n🔍 Verifying training output...")

        adapter_path = Path(self.config.adapter_path)

        # Check for adapter files
        adapter_files = list(adapter_path.glob("*.safetensors")) + list(adapter_path.glob("*.npz"))
        if adapter_files:
            print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")
        else:
            print("❌ No adapter files found")
            return False

        # Check for configuration
        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")

            # mlx_lm saves its training arguments here; LoRA settings sit under
            # "lora_parameters" rather than PEFT-style "r"/"lora_alpha" keys
            with open(config_file, 'r') as f:
                config_data = json.load(f)
            lora_params = config_data.get('lora_parameters', {})
            print(f"   LoRA rank: {lora_params.get('rank', 'unknown')}")
            print(f"   LoRA scale: {lora_params.get('scale', 'unknown')}")
            print(f"   LoRA layers: {config_data.get('num_layers', 'unknown')}")
        else:
            print("⚠️  No adapter config found")

        # Calculate total size (includes intermediate checkpoints and metadata)
        total_size = sum(f.stat().st_size for f in adapter_path.rglob('*') if f.is_file())
        print(f"📁 Total adapter size: {total_size / 1e6:.1f} MB")

        return True

def main():
    """Main fine-tuning execution"""

    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    # Create configuration
    config = FineTuningConfig()

    # Create fine-tuner
    fine_tuner = MLXFineTuner(config)

    # Run fine-tuning
    success, metadata = fine_tuner.run_fine_tuning()

    if success:
        # Verify output
        fine_tuner.verify_training_output()

        print("\n✨ Fine-tuning pipeline completed successfully!")
        print("\n🎯 Next steps:")
        print("  1. Test your fine-tuned model")
        print("  2. Run evaluation to measure performance")
        print("  3. Build your application interface")

        return metadata
    else:
        print("\n💥 Fine-tuning failed. Please check the error messages above.")
        return None

if __name__ == "__main__":
    metadata = main()
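Before launching a run, it helps to sanity-check what max_iters and batch_size mean for your dataset. Each iteration processes one batch, so the run sees max_iters × batch_size examples in total. The hypothetical helper below turns that into approximate epochs and lists the multiples of steps_per_eval and save_every within the run, which is roughly when validation and checkpoint saves happen. It uses the defaults from FineTuningConfig; the 500-example training set is a made-up figure, so substitute the count that validate_data() prints.

```python
from typing import Dict


def training_schedule(num_examples: int, batch_size: int, max_iters: int,
                      steps_per_eval: int, save_every: int) -> Dict[str, object]:
    """Summarize how much data a run sees and when periodic events fall."""
    examples_seen = max_iters * batch_size
    return {
        "examples_seen": examples_seen,
        "epochs": examples_seen / num_examples,
        "eval_iters": list(range(steps_per_eval, max_iters + 1, steps_per_eval)),
        "save_iters": list(range(save_every, max_iters + 1, save_every)),
    }


# Defaults from FineTuningConfig; 500 examples is a made-up dataset size
plan = training_schedule(num_examples=500, batch_size=2, max_iters=1000,
                         steps_per_eval=200, save_every=400)
print(plan)
```

With 500 examples, the defaults amount to four passes over the data. With a much smaller dataset, the same settings mean many more epochs, which raises the risk of overfitting, so max_iters is the first knob to revisit.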
