This is where the magic happens! In this part, we will deep dive into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model with incredible efficiency on Apple Silicon.
Understanding LoRA: The Game-Changing Technique
Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That's exactly what LoRA (Low-Rank Adaptation) does for language models.
The Traditional Fine-Tuning Problem
Traditional fine-tuning updates all 1.7 billion parameters of our model. This means:
- ❌ Massive memory requirements
- ❌ Slow training
- ❌ Risk of "catastrophic forgetting" (losing general knowledge)
- ❌ Large model files
The LoRA Solution
LoRA adds small "adapter" layers that learn new behaviors while keeping the original model frozen:
- ✅ Minimal memory usage
- ✅ Fast training
- ✅ Preserves general knowledge
- ✅ Tiny adapter file size
- ✅ Can be combined or switched out easily
How LoRA Works Under the Hood
Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off.
MLX: Apple's Secret Weapon for AI
MLX is Apple's machine learning framework designed specifically for Apple Silicon. It's what makes our local fine-tuning possible and incredibly fast.
Why MLX is good for Local AI
- Unified Memory Architecture: M-series chips share memory between CPU and GPU, eliminating data transfer bottlenecks
- Optimized Computation: Hand-tuned for Apple Silicon's specific capabilities
- Memory Efficiency: Intelligent memory management for maximum model sizes
- Python Integration: Easy to use while being incredibly fast
Setting Up Our Fine-Tuning Pipeline
Let us build our fine-tuning system step by step, understanding each component.
Step 1: Configuration and Setup
First, let's create a comprehensive configuration system:
touch fine_tuning_config.py
# Create fine_tuning_config.py
import os
from pathlib import Path
import mlx.core as mx
class FineTuningConfig:
    """Centralized configuration for fine-tuning.

    Holds the model identifier, data locations, LoRA hyper-parameters and
    training-loop settings as plain instance attributes so the rest of the
    pipeline reads them from one place.

    Note: constructing an instance has a side effect -- the adapter output
    directory is created if it does not already exist.
    """

    def __init__(self):
        # Model configuration
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # Data paths
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

        # LoRA parameters
        self.lora_layers = 16  # Number of transformer layers to add LoRA to
        self.lora_rank = 16  # The 'r' in LoRA - higher = more capacity but slower
        self.lora_alpha = 32  # Scaling factor for LoRA adapters

        # Training parameters
        self.batch_size = 2  # Batch size (reduce if out of memory)
        self.learning_rate = 5e-5  # Learning rate
        self.max_iters = 1000  # Maximum training iterations
        self.steps_per_report = 10  # How often to print progress
        self.steps_per_eval = 200  # How often to run validation
        self.save_every = 400  # How often to save checkpoints

        # Hardware optimization
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        # Create directories up front so later writes into adapter_path
        # cannot fail on a missing folder.
        Path(self.adapter_path).mkdir(parents=True, exist_ok=True)

    def print_config(self):
        """Print the main settings of the current configuration to stdout."""
        print("🔧 Fine-tuning Configuration:")
        print(f" Base model: {self.base_model}")
        print(f" GPU available: {self.use_gpu}")
        print(f" LoRA rank: {self.lora_rank}")
        print(f" LoRA layers: {self.lora_layers}")
        print(f" Batch size: {self.batch_size}")
        print(f" Learning rate: {self.learning_rate}")
        print(f" Max iterations: {self.max_iters}")
        print(f" Adapter path: {self.adapter_path}")
# Smoke test: running this module directly builds the default configuration
# and prints it.
if __name__ == "__main__":
    config = FineTuningConfig()
    config.print_config()
Step 2: Memory and Performance Monitoring
Before we start fine-tuning, let's create tools to monitor our system:
touch monitoring.py
# Create monitoring.py
import time
import mlx.core as mx
from typing import Dict, List
import psutil
class PerformanceMonitor:
    """Monitor memory usage and training performance.

    Every call to :meth:`log_memory_usage` appends one metrics dictionary to
    ``self.metrics``; :meth:`get_training_summary` condenses that history.
    """

    def __init__(self):
        # Reference point for all 'elapsed_seconds' values.
        self.start_time = time.time()
        self.metrics = []

    def log_memory_usage(self, step: int, loss: float = None):
        """Log current memory and performance metrics.

        Args:
            step: Training step the sample belongs to.
            loss: Loss at this step, or None when no loss is available.

        Returns:
            The metrics dictionary that was recorded for this step.
        """
        # GPU memory (if available); stays empty otherwise so the
        # .get(..., 0) lookups below report zero.
        gpu_memory = {}
        if mx.metal.is_available():
            gpu_memory = {
                'active_mb': mx.metal.get_active_memory() / 1e6,
                'peak_mb': mx.metal.get_peak_memory() / 1e6
            }

        # System memory
        system_memory = psutil.virtual_memory()

        # Training metrics
        elapsed = time.time() - self.start_time
        metrics = {
            'step': step,
            'elapsed_seconds': elapsed,
            'loss': loss,
            'gpu_active_mb': gpu_memory.get('active_mb', 0),
            'gpu_peak_mb': gpu_memory.get('peak_mb', 0),
            'system_memory_percent': system_memory.percent,
            'system_memory_available_gb': system_memory.available / 1e9
        }
        self.metrics.append(metrics)

        if step % 50 == 0:  # Print every 50 steps
            self.print_status(metrics)
        return metrics

    def print_status(self, metrics: Dict):
        """Print a one-line training status for a metrics dictionary.

        A missing loss (None) is shown as ``n/a``; formatting None with
        ``:.4f`` would raise TypeError, and ``log_memory_usage`` allows the
        loss to be omitted.
        """
        loss = metrics['loss']
        loss_text = "n/a" if loss is None else f"{loss:.4f}"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_text} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")

    def get_training_summary(self):
        """Get summary of training run.

        Returns:
            An empty dict when nothing has been logged; otherwise a dict with
            the elapsed time and loss of the last logged entry, the highest
            recorded GPU peak, and the number of logged entries.
        """
        if not self.metrics:
            return {}

        peak_gpu = max(m['gpu_peak_mb'] for m in self.metrics)
        total_time = self.metrics[-1]['elapsed_seconds']
        final_loss = self.metrics[-1]['loss']
        return {
            'total_training_time': total_time,
            'peak_gpu_memory_mb': peak_gpu,
            'final_loss': final_loss,
            'steps_completed': len(self.metrics)
        }
Step 3: The Fine-Tuning Engine
Now let's create our main fine-tuning script using MLX-LM:
touch fine_tune_model.py
# Create fine_tune_model.py
import subprocess
import time
import json
import os
from pathlib import Path
from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor
class MLXFineTuner:
    """Fine-tune models using MLX with LoRA.

    Drives the ``mlx_lm lora`` command line: validates the training data,
    builds and runs the command, writes run metadata next to the adapters and
    checks that adapter files were produced.
    """

    def __init__(self, config: "FineTuningConfig"):
        self.config = config
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Validate that training data exists and is properly formatted.

        Blank lines are ignored both when counting examples and when picking
        the record whose format is checked.

        Returns:
            Number of non-blank lines in the training file.

        Raises:
            FileNotFoundError: The training file does not exist.
            ValueError: The first record is not valid JSON (this includes a
                file with no records) or has no 'text' field.
        """
        print("🔍 Validating training data...")

        if not os.path.exists(self.config.train_data_path):
            raise FileNotFoundError(f"Training data not found: {self.config.train_data_path}")

        # Single pass: count examples and remember the first real record.
        train_count = 0
        first_record = None
        with open(self.config.train_data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                if first_record is None:
                    first_record = line
                train_count += 1
        print(f"✅ Found {train_count} training examples")

        # Validate format
        try:
            example = json.loads(first_record or "")
        except json.JSONDecodeError as err:
            raise ValueError("Training data must be valid JSONL format") from err
        if not isinstance(example, dict) or 'text' not in example:
            raise ValueError("Training data must have 'text' field")
        print("✅ Data format validated")

        return train_count

    def build_training_command(self):
        """Build the MLX-LM training command.

        Returns:
            The argument list for ``subprocess``. The ``--data`` directory is
            the folder containing ``config.train_data_path`` so the command
            always trains on the file that ``validate_data`` checked.
        """
        data_dir = str(Path(self.config.train_data_path).parent)
        cmd = [
            "python3", "-m", "mlx_lm", "lora",
            "--model", self.config.base_model,
            "--train",
            "--data", data_dir,  # Directory containing train.jsonl
            "--batch-size", str(self.config.batch_size),
            "--iters", str(self.config.max_iters),
            "--learning-rate", str(self.config.learning_rate),
            "--steps-per-report", str(self.config.steps_per_report),
            "--steps-per-eval", str(self.config.steps_per_eval),
            "--adapter-path", self.config.adapter_path,
            "--save-every", str(self.config.save_every)
        ]
        return cmd

    def _stream_command(self, cmd):
        """Run cmd, echoing its output live; return (returncode, output)."""
        captured = []
        # Ask a Python child not to buffer stdout, otherwise progress lines
        # may only arrive once the pipe buffer fills up.
        env = {**os.environ, "PYTHONUNBUFFERED": "1"}
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # keep errors interleaved with progress
            text=True,
            bufsize=1,
            env=env,
        ) as proc:
            for line in proc.stdout:
                print(line, end="", flush=True)
                captured.append(line)
        return proc.returncode, "".join(captured)

    def run_fine_tuning(self):
        """Execute the fine-tuning process.

        Output of the training command is printed while it runs, so the
        progress reports promised to the user are actually visible.

        Returns:
            ``(True, metadata)`` on success, where metadata is the dict that
            was also written to ``training_metadata.json``; ``(False, None)``
            when the training command exits with a non-zero status.

        Raises:
            FileNotFoundError, ValueError: Propagated from ``validate_data``.
        """
        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        # Validate everything is ready
        train_count = self.validate_data()
        self.config.print_config()

        # Build command
        cmd = self.build_training_command()
        print(f"\n📋 Command: {' '.join(cmd)}")

        # Start training
        start_time = time.time()
        print(f"\n🕐 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📊 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        returncode, output = self._stream_command(cmd)
        training_time = time.time() - start_time

        if returncode != 0:
            # The command's own output (stdout and stderr) was already echoed.
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {returncode}")
            return False, None

        print("\n" + "=" * 60)
        print("🎉 Fine-tuning completed successfully!")
        print(f"⏱️ Total training time: {training_time:.1f} seconds")
        print(f"💾 Adapters saved to: {self.config.adapter_path}")

        # Save training metadata
        metadata = {
            'model_name': self.config.base_model,
            'training_time_seconds': training_time,
            'training_examples': train_count,
            'lora_rank': self.config.lora_rank,
            'lora_layers': self.config.lora_layers,
            'batch_size': self.config.batch_size,
            'learning_rate': self.config.learning_rate,
            'max_iters': self.config.max_iters,
            'timestamp': time.time(),
            'command_used': ' '.join(cmd)
        }
        metadata_path = f"{self.config.adapter_path}/training_metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"📄 Training metadata saved to: {metadata_path}")

        # Parse and display training output
        self.parse_training_output(output)
        return True, metadata

    def parse_training_output(self, output: str):
        """Parse and display key information from training output.

        Matching is case-insensitive so both ``Loss: 1.2`` style lines and
        ``Iter 10: Train loss 1.2, ...`` / ``Iter 10: Val loss 1.1, ...``
        style lines are recognised.
        """
        import re  # local import: only this method needs it

        print("\n📈 Training Progress Summary:")
        print("-" * 40)
        lines = output.splitlines()

        # Look for key training metrics
        for line in lines:
            lowered = line.lower()
            if 'loss' in lowered or 'validation' in lowered:
                print(f" {line.strip()}")

        # Look for final metrics: the last training (not validation) loss.
        train_loss = re.compile(r"(?:train\s+loss|loss:)\s*([^\s,]+)", re.IGNORECASE)
        is_validation = re.compile(r"\bval", re.IGNORECASE)
        for line in reversed(lines):
            if is_validation.search(line):
                continue
            found = train_loss.search(line)
            if found:
                print(f"\n🎯 Final training loss: {found.group(1)}")
                break

    def verify_training_output(self):
        """Verify that training produced the expected files.

        Returns:
            False when the adapter directory holds no ``*.safetensors`` or
            ``*.npz`` file, True otherwise. A missing ``adapter_config.json``
            only produces a warning.
        """
        print("\n🔍 Verifying training output...")
        adapter_path = Path(self.config.adapter_path)

        # Check for adapter files
        adapter_files = list(adapter_path.glob("*.safetensors")) + list(adapter_path.glob("*.npz"))
        if adapter_files:
            print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")
        else:
            print("❌ No adapter files found")
            return False

        # Check for configuration
        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")
            # Display config contents
            with open(config_file, 'r') as f:
                config_data = json.load(f)
            # Top-level 'r' / 'lora_alpha' are checked first; otherwise fall
            # back to a nested 'lora_parameters' mapping.
            # NOTE(review): the nested layout is what MLX-LM is expected to
            # write -- confirm against the installed version.
            lora_params = config_data.get('lora_parameters')
            if not isinstance(lora_params, dict):
                lora_params = {}
            rank = config_data.get('r', lora_params.get('rank', 'unknown'))
            alpha = config_data.get('lora_alpha', lora_params.get('alpha', 'unknown'))
            print(f" LoRA rank: {rank}")
            print(f" LoRA alpha: {alpha}")
        else:
            print("⚠️ No adapter config found")

        # Calculate total size
        total_size = sum(f.stat().st_size for f in adapter_path.rglob('*') if f.is_file())
        print(f"📦 Total adapter size: {total_size / 1e6:.1f} MB")
        return True
def main():
    """Main fine-tuning execution.

    Builds the default configuration, runs fine-tuning and, on success,
    verifies the produced adapter files.

    Returns:
        The training metadata dict on success, or None when fine-tuning
        failed.
    """
    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    # Create configuration
    config = FineTuningConfig()

    # Create fine-tuner
    fine_tuner = MLXFineTuner(config)

    # Run fine-tuning
    success, metadata = fine_tuner.run_fine_tuning()
    if success:
        # Verify output
        fine_tuner.verify_training_output()
        print("\n✨ Fine-tuning pipeline completed successfully!")
        print("\n🎯 Next steps:")
        print(" 1. Test your fine-tuned model")
        print(" 2. Run evaluation to measure performance")
        print(" 3. Build your application interface")
        return metadata
    else:
        print("\n💥 Fine-tuning failed. Please check the error messages above.")
        return None
# Script entry point: run the pipeline and keep its result (the metadata dict,
# or None on failure) in a module-level name.
if __name__ == "__main__":
    metadata = main()
Top comments (0)