StemSplit
Building a Music Practice App with AI Stem Separation (Python + React)

Learning an instrument by ear is one of the most valuable things a musician can do. It's also frustrating when the bass is buried under the full mix.

AI stem separation makes this dramatically easier: isolate the instrument you're learning, slow it down without changing the pitch, and loop the 4-bar phrase you keep missing. Here's how to build a lightweight practice tool that does exactly this.

What You'll Build

A small full-stack app with:

  • Python backend that accepts an audio URL or file upload
  • Stem separation via Demucs or API
  • Time-stretching (slow down without pitch shift)
  • Section looping
  • A minimal React frontend

Setup

# Backend (demucs pulls in PyTorch; requests is only needed for the hosted-API path)
pip install fastapi uvicorn demucs librosa soundfile requests pyrubberband numpy

# pyrubberband shells out to the Rubber Band CLI, so install that too,
# e.g. apt install rubberband-cli (Debian/Ubuntu) or brew install rubber-band (macOS)
# Frontend
npm create vite@latest practice-app -- --template react
cd practice-app && npm install wavesurfer.js axios

Backend: Stem Separation + Audio Processing

main.py

import io
import subprocess
import tempfile
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

app = FastAPI(title="Music Practice API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_methods=["*"],
    allow_headers=["*"],
)

Stem Separation

def separate_stem(
    audio_path: str,
    stem: str,           # "vocals", "drums", "bass", "other"
    output_dir: str,
    use_api: bool = False,
    api_key: str | None = None,
) -> str:
    """
    Separate a single stem from an audio file.
    Returns path to separated stem WAV.
    """
    if use_api and api_key:
        return _separate_via_api(audio_path, stem, output_dir, api_key)
    return _separate_via_demucs(audio_path, stem, output_dir)


def _separate_via_demucs(audio_path: str, stem: str, output_dir: str) -> str:
    song = Path(audio_path).stem
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    subprocess.run(
        [
            "python", "-m", "demucs",
            "-n", "htdemucs_ft",
            "--two-stems", stem,   # Demucs writes <stem>.wav and no_<stem>.wav
            "-o", output_dir,
            audio_path,
        ],
        check=True,
        capture_output=True,
    )

    stem_path = Path(output_dir) / "htdemucs_ft" / song / f"{stem}.wav"
    if not stem_path.exists():
        raise FileNotFoundError(f"Stem not found: {stem_path}")
    return str(stem_path)


def _separate_via_api(
    audio_path: str,
    stem: str,
    output_dir: str,
    api_key: str,
) -> str:
    """Use a hosted separation service (here, the StemSplit API) instead of local Demucs."""
    import time

    import requests

    headers = {"Authorization": f"Bearer {api_key}"}

    with open(audio_path, "rb") as f:
        resp = requests.post(
            "https://api.stemsplit.io/v1/separate",
            headers=headers,
            files={"audio": (Path(audio_path).name, f)},
            # Extra fields must go in `data`: requests silently drops `json`
            # when `files` is present (the multipart body wins).
            data={"stems": "4", "format": "wav"},
            timeout=30,
        )
    resp.raise_for_status()
    job_id = resp.json()["job_id"]

    # Poll until done, but give up after 5 minutes instead of spinning forever
    deadline = time.monotonic() + 300
    while time.monotonic() < deadline:
        job = requests.get(
            f"https://api.stemsplit.io/v1/jobs/{job_id}",
            headers=headers,
            timeout=30,
        ).json()
        if job["status"] == "completed":
            data = requests.get(job["stems"][stem], timeout=60).content
            out_path = Path(output_dir) / f"{stem}.wav"
            out_path.write_bytes(data)
            return str(out_path)
        if job["status"] == "failed":
            raise RuntimeError(job.get("error", "separation failed"))
        time.sleep(3)
    raise TimeoutError(f"Separation job {job_id} did not finish in time")
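Demucs writes each run's output under `<output_dir>/<model>/<song>/<stem>.wav`, which is why the function above rebuilds that path by hand. As a quick illustration, here is a small helper (hypothetical, not part of the app above) that computes the same layout:

```python
from pathlib import Path


def demucs_stem_path(output_dir: str, model: str, audio_path: str, stem: str) -> Path:
    """Where `python -m demucs -n <model> -o <output_dir>` puts a given stem."""
    song = Path(audio_path).stem  # input filename without its extension
    return Path(output_dir) / model / song / f"{stem}.wav"


p = demucs_stem_path("/tmp/out", "htdemucs_ft", "/music/take_five.mp3", "bass")
print(p)  # → /tmp/out/htdemucs_ft/take_five/bass.wav
```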

Time Stretching (Slow Down Without Pitch Shift)

This is the key feature for practice. pyrubberband wraps the high-quality Rubber Band library:

try:
    import pyrubberband as rb
    HAS_RUBBERBAND = True
except ImportError:
    HAS_RUBBERBAND = False


def time_stretch(
    audio_path: str,
    rate: float,                # 0.5 = half speed, 0.75 = 75%, 1.0 = normal
    pitch_shift: float = 0.0,   # semitones, 0 = no pitch change
) -> tuple[np.ndarray, int]:
    """
    Time-stretch audio without changing pitch.

    Args:
        audio_path:   Input audio file
        rate:         Playback rate (0.5 = 50% speed)
        pitch_shift:  Semitones to shift pitch (independent of rate)

    Returns:
        (stretched audio as float32 mono numpy array, sample rate)
    """
    y, sr = librosa.load(audio_path, sr=44100, mono=True)

    if HAS_RUBBERBAND:
        # High-quality time stretching via Rubber Band
        stretched = rb.time_stretch(y, sr, rate)
        if pitch_shift != 0.0:
            stretched = rb.pitch_shift(stretched, sr, pitch_shift)
    else:
        # Fallback: librosa phase vocoder (lower quality but no extra dep)
        stretched = librosa.effects.time_stretch(y, rate=rate)
        if pitch_shift != 0.0:
            stretched = librosa.effects.pitch_shift(stretched, sr=sr, n_steps=pitch_shift)

    return stretched, sr


def time_stretch_endpoint(audio_path: str, rate: float) -> bytes:
    """Stretch audio and return as WAV bytes."""
    y, sr = time_stretch(audio_path, rate=rate)
    buf = io.BytesIO()
    sf.write(buf, y, sr, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf.read()
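One thing worth internalizing before wiring this up: `rate` is a playback-rate multiplier, so the output duration is the input duration divided by the rate. A quick sanity check (the helper is illustrative, not part of the app):

```python
def stretched_duration(duration_sec: float, rate: float) -> float:
    """Expected output length after time-stretching at a playback rate.

    rate < 1 slows playback down, so the audio gets longer;
    rate > 1 speeds it up, so it gets shorter.
    """
    if rate <= 0:
        raise ValueError("rate must be positive")
    return duration_sec / rate


# A 12-second riff at 50% speed takes 24 seconds to play through
print(stretched_duration(12.0, 0.5))   # → 24.0
# At 75% speed it takes 16 seconds
print(stretched_duration(12.0, 0.75))  # → 16.0
```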

Section Loop Extractor

def extract_loop(
    audio_path: str,
    start_sec: float,
    end_sec: float,
    loop_count: int = 4,
    fade_ms: int = 10,
) -> bytes:
    """
    Extract a section from an audio file and loop it N times.
    Adds a short fade to avoid clicks at loop boundaries.

    Returns: WAV bytes
    """
    y, sr = librosa.load(audio_path, sr=44100, mono=True)

    start_sample = int(start_sec * sr)
    end_sample   = int(end_sec * sr)
    if not 0 <= start_sample < end_sample <= len(y):
        raise ValueError(f"Invalid loop bounds: {start_sec}s to {end_sec}s")
    segment = y[start_sample:end_sample].copy()

    # Apply fade in/out to avoid clicks at the loop seam
    fade_samples = min(int(fade_ms / 1000 * sr), len(segment) // 2)
    if fade_samples > 0:
        segment[:fade_samples]  *= np.linspace(0, 1, fade_samples)
        segment[-fade_samples:] *= np.linspace(1, 0, fade_samples)

    # Loop
    looped = np.tile(segment, loop_count)

    buf = io.BytesIO()
    sf.write(buf, looped, sr, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf.read()
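To see the fade-and-tile step in isolation, here it is on a synthetic sine segment (numpy only, no audio files; the values mirror the defaults above):

```python
import numpy as np

sr = 44100
loop_count = 4
fade_samples = int(0.010 * sr)  # 10 ms fade, matching fade_ms above

# One second of A440 standing in for an extracted section
t = np.arange(sr) / sr
segment = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

# Fade in/out so the loop seam doesn't click
segment[:fade_samples]  *= np.linspace(0, 1, fade_samples)
segment[-fade_samples:] *= np.linspace(1, 0, fade_samples)

looped = np.tile(segment, loop_count)

assert len(looped) == loop_count * sr
assert looped[0] == 0.0 and abs(looped[-1]) < 1e-6  # every seam starts/ends at silence
```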

API Endpoints

import os
import uuid

UPLOAD_DIR = "/tmp/practice_app"


@app.post("/separate")
async def api_separate(
    audio: UploadFile = File(...),
    stem: str = Form("bass"),
    speed: float = Form(1.0),
):
    """
    Upload audio → separate stem → return time-stretched WAV.

    stem:  vocals | drums | bass | other
    speed: 0.5 = half speed, 1.0 = normal
    """
    if stem not in {"vocals", "drums", "bass", "other"}:
        raise HTTPException(422, "Invalid stem. Choose: vocals, drums, bass, other")
    if not 0.25 <= speed <= 2.0:
        raise HTTPException(422, "Speed must be between 0.25 and 2.0")

    session_id = uuid.uuid4().hex[:8]
    session_dir = os.path.join(UPLOAD_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # Save upload
    audio_path = os.path.join(session_dir, audio.filename or "audio.mp3")
    data = await audio.read()
    with open(audio_path, "wb") as f:
        f.write(data)

    # Separate stem
    stem_path = separate_stem(audio_path, stem, session_dir)

    # Apply time stretch if needed
    if abs(speed - 1.0) > 0.01:
        y, sr = time_stretch(stem_path, rate=speed)
        buf = io.BytesIO()
        sf.write(buf, y, sr, format="WAV", subtype="PCM_16")
        buf.seek(0)
        return StreamingResponse(buf, media_type="audio/wav")

    return StreamingResponse(open(stem_path, "rb"), media_type="audio/wav")


@app.post("/loop")
async def api_loop(
    audio: UploadFile = File(...),
    stem: str = Form("bass"),
    start_sec: float = Form(0.0),
    end_sec: float = Form(4.0),
    loop_count: int = Form(4),
    speed: float = Form(1.0),
):
    """Extract a section loop from a separated stem."""
    if stem not in {"vocals", "drums", "bass", "other"}:
        raise HTTPException(422, "Invalid stem. Choose: vocals, drums, bass, other")
    if start_sec >= end_sec:
        raise HTTPException(422, "start_sec must be less than end_sec")

    session_id = uuid.uuid4().hex[:8]
    session_dir = os.path.join(UPLOAD_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    audio_path = os.path.join(session_dir, audio.filename or "audio.mp3")
    data = await audio.read()
    with open(audio_path, "wb") as f:
        f.write(data)

    stem_path = separate_stem(audio_path, stem, session_dir)

    # Apply speed before looping
    if abs(speed - 1.0) > 0.01:
        y, sr = time_stretch(stem_path, rate=speed)
        buf = io.BytesIO()
        sf.write(buf, y, sr, format="WAV", subtype="PCM_16")
        buf.seek(0)
        stem_path = os.path.join(session_dir, "stretched.wav")
        with open(stem_path, "wb") as f:
            f.write(buf.getvalue())

    loop_bytes = extract_loop(stem_path, start_sec, end_sec, loop_count)
    return StreamingResponse(io.BytesIO(loop_bytes), media_type="audio/wav")
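Both endpoints need the same parameter checks, and keeping them duplicated invites drift. One way to keep them in sync is a tiny shared validator; this is a sketch with a hypothetical helper name, not code from the app above:

```python
VALID_STEMS = {"vocals", "drums", "bass", "other"}


def validate_request(stem: str, speed: float) -> list[str]:
    """Return a list of human-readable validation errors (empty list = OK)."""
    errors = []
    if stem not in VALID_STEMS:
        errors.append(f"invalid stem {stem!r}; choose one of {sorted(VALID_STEMS)}")
    if not 0.25 <= speed <= 2.0:
        errors.append(f"speed {speed} out of range [0.25, 2.0]")
    return errors


print(validate_request("bass", 0.5))    # → []
print(validate_request("guitar", 3.0))  # two errors
```

Each endpoint can then raise a single `HTTPException(422, "; ".join(errors))` when the list is non-empty.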

Frontend: React + WaveSurfer

// src/App.jsx
import { useState, useRef, useEffect } from "react";
import WaveSurfer from "wavesurfer.js";
import axios from "axios";

const API = "http://localhost:8000";

const STEMS = ["vocals", "drums", "bass", "other"];

export default function PracticeApp() {
  const [file, setFile]         = useState(null);
  const [stem, setStem]         = useState("bass");
  const [speed, setSpeed]       = useState(1.0);
  const [loopStart, setLoopStart] = useState(0);
  const [loopEnd, setLoopEnd]   = useState(4);
  const [loading, setLoading]   = useState(false);
  const [audioUrl, setAudioUrl] = useState(null);

  const waveRef  = useRef(null);
  const waveSurfer = useRef(null);

  useEffect(() => {
    if (waveRef.current && !waveSurfer.current) {
      waveSurfer.current = WaveSurfer.create({
        container: waveRef.current,
        waveColor: "#6366f1",
        progressColor: "#4f46e5",
        height: 80,
        normalize: true,
      });
    }
    return () => {
      // Null the ref so the instance is recreated after cleanup
      // (React 18 StrictMode runs mount/cleanup twice in dev)
      waveSurfer.current?.destroy();
      waveSurfer.current = null;
    };
  }, []);

  useEffect(() => {
    if (audioUrl && waveSurfer.current) {
      waveSurfer.current.load(audioUrl);
    }
  }, [audioUrl]);

  async function handleSeparate() {
    if (!file) return;
    setLoading(true);

    try {
      const form = new FormData();
      form.append("audio", file);
      form.append("stem", stem);
      form.append("speed", speed);

      const resp = await axios.post(`${API}/separate`, form, {
        responseType: "blob",
      });
      setAudioUrl(URL.createObjectURL(resp.data));
    } catch (err) {
      console.error("Separation failed:", err);
    } finally {
      setLoading(false);
    }
  }

  async function handleLoop() {
    if (!file) return;
    setLoading(true);

    try {
      const form = new FormData();
      form.append("audio", file);
      form.append("stem", stem);
      form.append("start_sec", loopStart);
      form.append("end_sec", loopEnd);
      form.append("loop_count", 4);
      form.append("speed", speed);

      const resp = await axios.post(`${API}/loop`, form, {
        responseType: "blob",
      });
      setAudioUrl(URL.createObjectURL(resp.data));
    } catch (err) {
      console.error("Loop extraction failed:", err);
    } finally {
      setLoading(false);
    }
  }

  return (
    <div className="max-w-2xl mx-auto p-6 space-y-6">
      <h1 className="text-2xl font-bold">🎸 Practice Tool</h1>

      {/* File upload */}
      <input
        type="file"
        accept=".mp3,.wav,.flac"
        onChange={e => setFile(e.target.files[0])}
        className="block w-full"
      />

      {/* Stem selector */}
      <div className="flex gap-2">
        {STEMS.map(s => (
          <button
            key={s}
            onClick={() => setStem(s)}
            className={`px-4 py-2 rounded capitalize ${
              stem === s ? "bg-indigo-600 text-white" : "bg-gray-100"
            }`}
          >
            {s}
          </button>
        ))}
      </div>

      {/* Speed slider */}
      <label className="block">
        <span className="text-sm font-medium">Speed: {speed}×</span>
        <input
          type="range"
          min="0.25" max="1.0" step="0.05"
          value={speed}
          onChange={e => setSpeed(Number(e.target.value))}
          className="w-full mt-1"
        />
      </label>

      {/* Loop controls */}
      <div className="flex gap-4">
        <label className="flex-1">
          <span className="text-sm">Loop start (s)</span>
          <input type="number" value={loopStart}
            onChange={e => setLoopStart(Number(e.target.value))}
            className="w-full border rounded p-1 mt-1" />
        </label>
        <label className="flex-1">
          <span className="text-sm">Loop end (s)</span>
          <input type="number" value={loopEnd}
            onChange={e => setLoopEnd(Number(e.target.value))}
            className="w-full border rounded p-1 mt-1" />
        </label>
      </div>

      {/* Action buttons */}
      <div className="flex gap-3">
        <button
          onClick={handleSeparate}
          disabled={!file || loading}
          className="flex-1 bg-indigo-600 text-white rounded py-2 disabled:opacity-50"
        >
          {loading ? "Processing..." : `Isolate ${stem}`}
        </button>
        <button
          onClick={handleLoop}
          disabled={!file || loading}
          className="flex-1 bg-emerald-600 text-white rounded py-2 disabled:opacity-50"
        >
          Extract Loop
        </button>
      </div>

      {/* Waveform */}
      <div ref={waveRef} className="border rounded p-2" />

      {/* Playback */}
      {audioUrl && (
        <div className="flex gap-2">
          <button
            onClick={() => waveSurfer.current?.playPause()}
            className="px-4 py-2 bg-gray-800 text-white rounded"
          >
            Play / Pause
          </button>
          <a href={audioUrl} download={`${stem}_${speed}x.wav`}
            className="px-4 py-2 bg-gray-100 rounded">
            Download
          </a>
        </div>
      )}
    </div>
  );
}

Running the App

# Terminal 1: Backend
uvicorn main:app --host 0.0.0.0 --port 8000

# Terminal 2: Frontend
cd practice-app && npm run dev

Open http://localhost:5173, upload a track, pick the instrument, and hit "Isolate".


Use Cases This Enables

Learning bass lines — isolate the bass stem at 60% speed, loop the 4-bar section you're working on. No more guessing what the bass is playing under the guitar and vocal.

Ear training — isolate one instrument, transcribe it, check against the original. Repeat at increasing speeds.

Vocal practice — isolate the vocal stem, hear the "clean" melody without reverb from the mix, then practice over the instrumental (no_vocals stem).

Drum analysis — slow down the drum stem to 50% to hear exactly what the drummer is doing in complex fills.


Performance Notes

Stem separation takes 30–60 seconds locally (CPU) or ~40 seconds via API. For a practice app, this one-time wait is fine. Cache the separated stems by file hash so you only process each track once:

import hashlib
import shutil
from pathlib import Path


def get_file_hash(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()[:16]


def get_cached_stem(file_hash: str, stem: str, cache_dir: str = "/tmp/stem_cache") -> str | None:
    path = Path(cache_dir) / file_hash / f"{stem}.wav"
    return str(path) if path.exists() else None


def cache_stem(file_hash: str, stem: str, stem_path: str, cache_dir: str = "/tmp/stem_cache") -> None:
    dest = Path(cache_dir) / file_hash
    dest.mkdir(parents=True, exist_ok=True)
    shutil.copy(stem_path, dest / f"{stem}.wav")
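Wiring the cache in is then a check-before-separate, store-after-separate pair. Here is that flow exercised end to end with throwaway files (stdlib only; the helpers are repeated so the snippet runs standalone):

```python
import hashlib
import shutil
import tempfile
from pathlib import Path


def get_file_hash(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()[:16]


def get_cached_stem(file_hash, stem, cache_dir):
    path = Path(cache_dir) / file_hash / f"{stem}.wav"
    return str(path) if path.exists() else None


def cache_stem(file_hash, stem, stem_path, cache_dir):
    dest = Path(cache_dir) / file_hash
    dest.mkdir(parents=True, exist_ok=True)
    shutil.copy(stem_path, dest / f"{stem}.wav")


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = f"{tmp}/stem_cache"
    upload = b"fake mp3 bytes"               # stands in for the uploaded file
    h = get_file_hash(upload)

    # First request: cache miss, so you'd run separate_stem() and store the result
    assert get_cached_stem(h, "bass", cache_dir) is None
    fake_stem = Path(tmp) / "bass.wav"
    fake_stem.write_bytes(b"fake wav bytes")  # stands in for Demucs output
    cache_stem(h, "bass", str(fake_stem), cache_dir)

    # Second request with the same file: cache hit, no separation needed
    assert get_cached_stem(h, "bass", cache_dir) is not None
```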

Hosted Alternative

If you just want to practice rather than build, StemSplit handles separation in the browser — same HTDemucs quality, no local setup. The code in this article is for building your own tool or embedding this into an existing app.


What instrument are you learning? Drop it in the comments — interested to know which stem people use most.
