Here at Simli, we care about latency more than anything else. That's what we're all about, after all: low-latency video. The problem is that some of the most widely used algorithms in audio machine learning have really, really slow implementations. To be clear, these implementations are usually fine for training the models themselves or for batched inference. But for us at Simli, a couple of milliseconds can mean the difference between a stuttering mess and a smooth video.
Luckily for me (and, by proxy, for you, the reader), this guide does not require much math knowledge. Much smarter people have already figured out how to get the correct answer; we're just making the compute more efficient. If you need more background on what a MelSpectrogram even is, you can read this article. There are multiple ways to calculate the spectrogram, and the right one depends heavily on your application. For the author's convenience, we focus on the mels required to run our internal models.
The common solution: Librosa
You are most likely here after encountering a repo that uses Librosa. It's a pretty handy library, to be honest: there are tons of utilities, easy ways to read audio from disk, and quick access to many commonly needed features such as audio resampling and channel downmixing. In our case, we're interested in one particular feature: melspectrogram calculation. In librosa, getting the melspectrogram is straightforward.
```python
import librosa

# load in any audio to test
sampleAudio, sr = librosa.load("sample.mp3", sr=None)  # sr=None keeps the original sampling rate
spectrogram = librosa.feature.melspectrogram(
    y=sampleAudio,
    sr=sr,
    n_fft=int(0.05 * sr),  # 50ms
    hop_length=int(0.0125 * sr),  # 12.5ms
    win_length=int(0.05 * sr),
)
```
It's straightforward, and it takes around 2ms on average on a GCP g2 VM. However, there are two main issues:
- When working with DL models, you usually need to run the model on a GPU. This means part of your chain runs on the CPU, and you then have to copy the results back to the GPU. For batched inference this is mostly fine, since you would collect as much data as fits on the GPU per transfer anyway. In our case, however, we often work with one frame at a time to reduce waiting and processing time.
- Our total time budget is roughly 33ms per frame. This includes transfer latency from the API server to the ML inference server, the CPU-to-GPU copy, and the models' preprocessing and postprocessing, including the melspectrogram. Every millisecond matters when you're working with such a tight budget. These two milliseconds actually contributed to getting a working live-rendered video stream for Simli (well, it took many optimizations, each worth a millisecond or two).
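For a sense of scale, here is the arithmetic behind that budget as a small Python sketch. It assumes 30 fps (which is what a ~33ms budget implies) and, purely for illustration, 16 kHz audio; neither value is taken from our production config.

```python
FPS = 30              # assumption: a ~33 ms budget implies 30 frames per second
SAMPLE_RATE = 16_000  # assumption: 16 kHz audio, for illustration only
LIBROSA_MS = 2.0      # the average librosa timing quoted above

budget_ms = 1000 / FPS                       # ~33.3 ms per video frame
n_fft = int(0.05 * SAMPLE_RATE)              # 50 ms window -> samples
hop_length = int(0.0125 * SAMPLE_RATE)       # 12.5 ms hop -> samples
samples_per_video_frame = SAMPLE_RATE / FPS  # new audio samples per video frame
mel_frames_per_video_frame = samples_per_video_frame / hop_length
librosa_share = LIBROSA_MS / budget_ms       # fraction of the whole budget spent on mels

print(
    f"budget {budget_ms:.1f} ms | n_fft {n_fft} | hop {hop_length} | "
    f"{mel_frames_per_video_frame:.2f} mel frames per video frame | "
    f"librosa uses {librosa_share:.0%} of the budget"
)
```

Spending roughly 6% of the entire per-frame budget on a single preprocessing step is why shaving it down was worth the effort.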
Looking online for solutions
While looking at how other people have done it (luckily, this is not a problem unique to us), I found this article, which explains how melspectrograms work and provides a reference implementation that, for some reason, took only 1ms (a 50% improvement). That's a good start, but the first problem remains: not everything was on the GPU. We're using PyTorch and have been relying on torch.compile with mode="reduce-overhead" for maximum speed improvements. However, a data transfer like this may tank performance, because the PyTorch compiler won't be able to optimize the function as well. The solution is a bit tedious but relatively easy: rewrite it in torch. The PyTorch team has made sure much of its syntax and functionality is as close to NumPy as possible (with some edge cases that are usually well documented, apart from one that cost me a couple of days, but that's a story for a different blog).
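To give a feel for how close the two APIs are, here is a small sketch (not code from our pipeline; the variable names are made up) that frames a signal with NumPy's `sliding_window_view` and with PyTorch's `Tensor.unfold`, which produce identical frames. It then shows one well-documented edge case: `np.hanning` returns a symmetric window, whereas `torch.hann_window` returns a periodic one by default (`periodic=True`), which is also the kind librosa uses for its STFT.

```python
import numpy as np
import torch

n_fft, hop = 800, 200  # 50 ms window and 12.5 ms hop at 16 kHz
signal_np = np.random.default_rng(0).standard_normal(16_000).astype(np.float32)
signal_t = torch.from_numpy(signal_np)

# Same idea, same result: overlapping frames of n_fft samples, one every hop samples.
frames_np = np.lib.stride_tricks.sliding_window_view(signal_np, n_fft)[::hop]
frames_t = signal_t.unfold(0, n_fft, hop)  # unfold(dimension, size, step)
print(frames_np.shape, tuple(frames_t.shape))
print(np.array_equal(frames_np, frames_t.numpy()))

# Edge case: symmetric vs periodic Hann window.
sym_np = np.hanning(n_fft)          # symmetric: both ends are exactly 0
per_t = torch.hann_window(n_fft)    # periodic by default
# A periodic window of length N equals a symmetric window of length N + 1 minus its last sample.
print(np.allclose(per_t.numpy(), np.hanning(n_fft + 1)[:-1], atol=1e-6))
print(np.allclose(per_t.numpy(), sym_np, atol=1e-6))
```

The window difference is small, but it is enough to break exact parity with librosa if you build the window with NumPy's default.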
The PyTorch rewrite
So there are a few things we need to do to successfully rewrite everything in PyTorch. Computing a melspectrogram can be split into three steps:
- Computing the short-time Fourier transform (STFT)
- Generating the mel scale filter banks
- Applying the filter banks to the STFT magnitudes to produce the melspectrogram
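To make the three steps concrete, here is a minimal NumPy sketch of them. It is our own illustration, not the reference implementation from the article mentioned above, and it assumes librosa-style defaults: a centered STFT with zero padding and a periodic Hann window, the Slaney mel scale with Slaney normalization, and the 50ms window, 12.5ms hop, 80 mels, 55 to 7600 Hz settings used in the PyTorch code. Like that code, it uses STFT magnitudes, whereas `librosa.feature.melspectrogram` defaults to `power=2.0`. The function names are hypothetical.

```python
import numpy as np

# Slaney mel scale constants (librosa's default, htk=False): linear below 1 kHz, log above
F_SP, MIN_LOG_HZ = 200.0 / 3, 1000.0
MIN_LOG_MEL, LOGSTEP = MIN_LOG_HZ / F_SP, np.log(6.4) / 27.0


def hz_to_mel(f):
    f = np.asarray(f, dtype=np.float64)
    log_part = MIN_LOG_MEL + np.log(np.maximum(f, MIN_LOG_HZ) / MIN_LOG_HZ) / LOGSTEP
    return np.where(f >= MIN_LOG_HZ, log_part, f / F_SP)


def mel_to_hz(m):
    m = np.asarray(m, dtype=np.float64)
    return np.where(m >= MIN_LOG_MEL, MIN_LOG_HZ * np.exp(LOGSTEP * (m - MIN_LOG_MEL)), F_SP * m)


def stft_magnitude(y, n_fft, hop_length):
    """Step 1: centered STFT magnitude with zero padding and a periodic Hann window."""
    y = np.pad(y, n_fft // 2, mode="constant")
    window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n_fft) / n_fft)
    frames = np.lib.stride_tricks.sliding_window_view(y, n_fft)[::hop_length]
    return np.abs(np.fft.rfft(frames * window, axis=-1)).T  # (n_fft // 2 + 1, n_frames)


def mel_filterbank(sr, n_fft, n_mels=80, fmin=55.0, fmax=7600.0):
    """Step 2: triangular filters evenly spaced on the mel scale, Slaney-normalized."""
    fft_freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
    edges = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2))
    widths = np.diff(edges)
    ramps = edges[:, None] - fft_freqs[None, :]
    rising = -ramps[:-2] / widths[:-1, None]
    falling = ramps[2:] / widths[1:, None]
    weights = np.maximum(0.0, np.minimum(rising, falling))
    return weights * (2.0 / (edges[2:] - edges[:-2]))[:, None]  # (n_mels, n_fft // 2 + 1)


def melspectrogram_np(y, sr, n_mels=80, fmin=55.0, fmax=7600.0):
    n_fft, hop_length = int(0.05 * sr), int(0.0125 * sr)  # 50 ms window, 12.5 ms hop
    mag = stft_magnitude(y, n_fft, hop_length)
    # Step 3: project every frame's spectrum onto the mel filters with one matmul
    return mel_filterbank(sr, n_fft, n_mels, fmin, fmax) @ mag  # (n_mels, n_frames)


sr = 16_000
t = np.arange(sr) / sr
y = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)  # one second of a 440 Hz tone
mel = melspectrogram_np(y, sr)
print(mel.shape)  # (80, 81)
```

Note that the filter bank only depends on fixed parameters (sampling rate, FFT size, number of mels, frequency range), so it can be computed once and reused on every call.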
There's good news and bad news. The good news is that all the required functionality is readily available in PyTorch or torchaudio. The bad news is that the default behavior is quite different from librosa's, so getting it right takes a lot of configuration and trial and error. I've been through that, and I'm sharing what I learned because I wouldn't wish that hell on my worst enemy. One thing to understand is that this code relies heavily on caching some results for later use. This is done in an initialization function that pregenerates all the static arrays (the mel filter banks, for example, depend on the sampling rate and the number of mels you need, among other fixed parameters). Here's our optimized version of the melspectrogram function using PyTorch:
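```python
import librosa
import torch
import torchaudio


def melspectrogram_torch(
    wav: torch.Tensor,
    sample_rate: int,
    hann_window: torch.Tensor,
    mel_basis: torch.Tensor,
) -> torch.Tensor:
    stftWav = torch.stft(
        wav,
        n_fft=int(sample_rate * 0.05),  # 50ms
        win_length=int(sample_rate * 0.05),
        hop_length=int(sample_rate * 0.0125),  # 12.5ms
        window=hann_window,
        pad_mode="constant",  # zero padding, like librosa's default
        return_complex=True,
    ).abs()  # magnitude, (n_fft // 2 + 1, n_frames); librosa's melspectrogram defaults to power=2.0
    stftWav = stftWav.squeeze()
    # (n_mels, n_freqs) @ (n_freqs, n_frames) -> (n_mels, n_frames)
    mel_stftWav = torch.mm(mel_basis, stftWav)
    return mel_stftWav


# "reduce-overhead" relies on CUDA graphs, so only use it when a GPU is available.
if torch.cuda.is_available():
    melspectrogram_torch = torch.compile(melspectrogram_torch, mode="reduce-overhead")
else:
    melspectrogram_torch = torch.compile(melspectrogram_torch)

device = "cuda" if torch.cuda.is_available() else "cpu"
sampleAudio, sr = librosa.load("sample.mp3", sr=None)
wav = torch.from_numpy(sampleAudio).to(device)  # librosa returns a NumPy array

melspectrogram_torch(
    wav,
    sr,
    torch.hann_window(int(sr * 0.05), device=device, dtype=torch.float32),
    torchaudio.functional.melscale_fbanks(
        n_freqs=int(sr * 0.05) // 2 + 1,
        # this is an example that's related to our own pipeline, check what you need for yours
        f_min=55,
        f_max=7600,
        n_mels=80,
        sample_rate=sr,
        norm="slaney",  # this is the normalization algorithm used by librosa
        mel_scale="slaney",  # librosa's default mel scale (htk=False); torchaudio defaults to "htk"
    ).T.to(device),  # (n_freqs, n_mels) -> (n_mels, n_freqs)
)
```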
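A note on the `mel_scale` argument: torchaudio's `melscale_fbanks` defaults to `mel_scale="htk"`, while librosa's mel filters default to the Slaney scale (`htk=False`), so `norm="slaney"` alone does not reproduce librosa's filter bank. The NumPy sketch below (helper names are ours, for illustration) places 80 filters between 55 Hz and 7600 Hz on each scale and reports how far the triangle edges drift apart.

```python
import numpy as np


def htk_hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)


def htk_mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


# Slaney scale (librosa's default): linear below 1 kHz, logarithmic above.
F_SP, MIN_LOG_HZ = 200.0 / 3, 1000.0
MIN_LOG_MEL, LOGSTEP = MIN_LOG_HZ / F_SP, np.log(6.4) / 27.0


def slaney_hz_to_mel(f):
    f = np.asarray(f, dtype=np.float64)
    log_part = MIN_LOG_MEL + np.log(np.maximum(f, MIN_LOG_HZ) / MIN_LOG_HZ) / LOGSTEP
    return np.where(f >= MIN_LOG_HZ, log_part, f / F_SP)


def slaney_mel_to_hz(m):
    m = np.asarray(m, dtype=np.float64)
    return np.where(m >= MIN_LOG_MEL, MIN_LOG_HZ * np.exp(LOGSTEP * (m - MIN_LOG_MEL)), F_SP * m)


def filter_edges_hz(to_mel, to_hz, n_mels=80, fmin=55.0, fmax=7600.0):
    """The n_mels + 2 triangle edge frequencies, evenly spaced on the given mel scale."""
    return to_hz(np.linspace(to_mel(fmin), to_mel(fmax), n_mels + 2))


edges_htk = filter_edges_hz(htk_hz_to_mel, htk_mel_to_hz)
edges_slaney = filter_edges_hz(slaney_hz_to_mel, slaney_mel_to_hz)
diff = np.abs(edges_htk - edges_slaney)
print(f"largest edge mismatch: {diff.max():.1f} Hz near {edges_slaney[diff.argmax()]:.0f} Hz")
```

Both scales agree at the endpoints by construction; everything in between lands on different FFT bins, which is why a model trained on librosa features can see noticeably different inputs.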
After the initial compilation run, we measured melspectrogram_torch at 350 microseconds on an Nvidia L4 GPU, with the Hann window and the mel filter bank cached rather than rebuilt on every call. The adjusted call looks like this:
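```python
# Build the window and the mel basis once (e.g. in an init function) and reuse them.
hann = torch.hann_window(int(sr * 0.05), device=device, dtype=torch.float32)
melscale = torchaudio.functional.melscale_fbanks(
    n_freqs=int(sr * 0.05) // 2 + 1,
    # this is an example that's related to our own pipeline, check what you need for yours
    f_min=55,
    f_max=7600,
    n_mels=80,
    sample_rate=sr,
    norm="slaney",  # this is the normalization algorithm used by librosa
    mel_scale="slaney",  # librosa's default mel scale; torchaudio defaults to "htk"
).T.to(device)

# Per-frame call: only the audio changes between calls.
melspectrogram_torch(torch.from_numpy(sampleAudio).to(device), sr, hann, melscale)
```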
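We don't include our exact benchmarking harness here. As a hedged sketch of how such a number can be measured fairly, the hypothetical helper below makes a few warm-up calls first (so torch.compile tracing and CUDA-graph capture are excluded, in the spirit of "after the initial compilation run") and optionally synchronizes the device, since CUDA kernels run asynchronously and the clock would otherwise stop before the GPU is done.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(
    fn: Callable,
    *args,
    warmup: int = 5,
    iters: int = 100,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock time of fn(*args) in microseconds.

    Warm-up calls absorb one-time costs (compilation, graph capture, caches).
    Pass a sync function such as torch.cuda.synchronize when timing GPU work,
    so each measurement waits for the queued kernels to finish.
    """
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        timings.append((time.perf_counter() - start) * 1e6)
    return statistics.median(timings)


print(f"sum(range(1000)): {benchmark(sum, range(1000)):.1f} us")
```

On the GPU you would call it as `benchmark(melspectrogram_torch, audio, sr, hann, melscale, sync=torch.cuda.synchronize)`, where `audio` is the input tensor already on the device. The median is used so that an occasional slow call doesn't skew the result.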
This is one part of a series of articles about how we optimized our deployed pretrained models by speeding up their preprocessing and postprocessing steps. You can check https://www.simli.com/demo to see the deployed models and the lowest-latency avatars we provide.