स्पीकर डायराइज़ेशन: एक सम्पूर्ण तकनीकी मार्गदर्शिका
लेखक: Deep Learning शिक्षार्थियों के लिए
स्तर: Beginner से Intermediate
भाषा: हिंदी (Technical terms अंग्रेज़ी में)
विषय-सूची
- परिचय (Introduction)
- पूरे सिस्टम का अवलोकन (System Overview)
- Component-wise Deep Dive
- 3.1 Audio Loading और Preprocessing
- 3.2 SincNet — Feature Extraction
- 3.3 PyanNet — Segmentation Model
- 3.4 Powerset Encoding
- 3.5 Binarization
- 3.6 Speaker Count Estimation
- 3.7 WeSpeakerResNet34 — Speaker Embeddings
- 3.8 VBx Clustering
- 3.9 Label Assignment और Reconstruction
- गणितीय समझ (Mathematical Intuition)
- Deep Learning की सरल व्याख्या
- Line-by-Line Code Explanation
- प्रायोगिक समझ (Practical Insights)
- सामान्य गलतियाँ और समाधान
- निष्कर्ष (Conclusion)
1. परिचय (Introduction)
1.1 Speaker Diarization की औपचारिक परिभाषा
Speaker Diarization एक ऐसी प्रक्रिया है जिसमें एक audio recording को इस प्रकार विभाजित किया जाता है कि हर time segment के साथ यह जानकारी जुड़ी हो कि उस समय कौन-सा वक्ता (speaker) बोल रहा था।
औपचारिक रूप से:
जहाँ एक audio waveform है जिसकी length samples है। Output एक set होता है जिसमें (शुरुआत, अंत, वक्ता-पहचान) के tuples होते हैं।
इसे "Who Spoke When?" problem भी कहते हैं।
1.2 "कौन कब बोल रहा है" — समस्या का Formulation
मान लीजिए आपके पास एक 10 मिनट की meeting recording है जिसमें 3 लोग बोल रहे हैं। Diarization का काम है:
0:00 - 0:45 → SPEAKER_00
0:45 - 1:30 → SPEAKER_01
1:30 - 2:10 → SPEAKER_00
2:10 - 2:50 → SPEAKER_02
...
यह system पहले से नहीं जानता कि कितने speakers हैं या वे कौन हैं। इसे स्वयं यह निर्णय लेना होता है।
1.3 मुख्य उपयोग (Applications)
- Meeting Transcription: Zoom, Teams जैसी meetings में हर वक्ता की बात अलग-अलग transcript करना।
- Podcast Analysis: किस host ने कितनी देर बात की।
- Legal Proceedings: Court recordings में कौन-सा गवाह कब बोला।
- Medical Interviews: Doctor-Patient conversation में किसने क्या कहा।
- Broadcast Media: News anchor vs interview subject का अलगाव।
- Call Center Analytics: Agent और customer की बातचीत का analysis।
1.4 प्रमुख चुनौतियाँ
1. Overlapping Speech (Overlapping बोलना):
जब दो लोग एक साथ बोलते हैं, तो audio signal में दोनों की आवाज़ें मिली होती हैं। इन्हें अलग करना कठिन है।
2. Short Segments:
कई बार एक speaker कुछ ही milliseconds के लिए बोलता है। इतने कम data से embedding निकालना unreliable होता है।
3. Variable Number of Speakers:
System को पहले से नहीं पता कि recording में 2 हैं या 10 लोग।
4. Channel Conditions:
Noise, echo, different microphone qualities — ये सभी embeddings को प्रभावित करते हैं।
5. Speaker Confusion:
कभी-कभी दो अलग speakers की आवाज़ें मिलती-जुलती होती हैं (जैसे twins), जिससे clustering गलत हो जाती है।
2. पूरे सिस्टम का अवलोकन (System Overview)
यह pipeline निम्नलिखित sequential steps में काम करती है:
Audio File (WAV/MP3)
│
▼
┌─────────────────────┐
│ Audio Loading & │
│ Preprocessing │ ← Audio class (16kHz mono)
└─────────────────────┘
│
▼
┌─────────────────────┐
│ Segmentation │ ← PyanNet (SincNet + LSTM)
│ (Frame-level) │ "इस frame में कौन active है?"
└─────────────────────┘
│
▼
┌─────────────────────┐
│ Binarization │ ← Probabilities → Binary masks
└─────────────────────┘
│
▼
┌─────────────────────┐
│ Speaker Count │ ← हर time frame में कितने speakers?
└─────────────────────┘
│
▼
┌─────────────────────┐
│ Speaker Embeddings │ ← WeSpeakerResNet34
│ │ "हर chunk के हर speaker का fingerprint"
└─────────────────────┘
│
▼
┌─────────────────────┐
│ VBx Clustering │ ← AHC → VBx → KMeans (optional)
│ │ "कौन-से fingerprints एक ही person के हैं?"
└─────────────────────┘
│
▼
┌─────────────────────┐
│ Reconstruction │ ← Clustered labels → Timeline
└─────────────────────┘
│
▼
Final Diarization Output
{start, end, speaker_id}
हर Component का Role संक्षेप में:
| Component | Input | Output | Role |
|---|---|---|---|
| Audio | File path | Waveform tensor | Audio load करना |
| Segmentation (PyanNet) | Waveform chunks | Per-frame probabilities | कौन कब active है |
| Binarize | Probabilities | Binary 0/1 masks | Threshold apply करना |
| Speaker Count | Binary masks | Count per frame | Overlap handle करना |
| Embeddings (ResNet) | Waveform + mask | 256-dim vector | Speaker identity |
| VBx Clustering | Embeddings | Cluster labels | Groups बनाना |
| Reconstruct | Labels + segmentation | Final timeline | Output format |
3. Component-wise Deep Dive
3.1 Audio Loading और Preprocessing — Audio Class
Conceptual Explanation
Audio processing का पहला step है raw audio को एक standard format में लाना। अलग-अलग files में:
- Sample rate अलग हो सकती है (8kHz, 22kHz, 44kHz, 48kHz)
- Channels अलग हो सकते हैं (Stereo = 2 channels, Mono = 1 channel)
हमारे model को specifically 16kHz mono audio चाहिए।
Sample Rate क्या होती है?
1 second के audio को represent करने के लिए कितने numerical samples लिए गए हैं। 16kHz का मतलब है हर second में 16,000 samples।
Mono Downmix:
अगर audio stereo (2-channel) है, तो दोनों channels का average लेकर एक channel बनाते हैं:
Resampling:
अगर audio 44kHz पर है और हमें 16kHz चाहिए, तो हर 44,000 samples में से 16,000 को रखना होगा (interpolation के साथ)।
Code Explanation
class Audio:
def __init__(self, sample_rate: int = None, mono: str = None):
self.sample_rate = sample_rate # Target sample rate (16000)
self.mono = mono # "downmix" = stereo को mono में convert करो
def downmix_and_resample(self, waveform, sample_rate, channel=None):
# Step 1: अगर specific channel चाहिए तो extract करो
if channel is not None:
waveform = waveform[channel : channel + 1]
# Step 2: Multi-channel को mono में convert करो
if waveform.shape[0] > 1:
if self.mono == "downmix":
waveform = waveform.mean(dim=0, keepdim=True) # Average of channels
elif self.mono == "random":
ch = random.randint(0, waveform.shape[0] - 1)
waveform = waveform[ch : ch + 1] # Training में data augmentation के लिए
# Step 3: Resample करो अगर ज़रूरी हो
if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
return waveform, sample_rate
crop() method:
यह method audio का एक specific segment (start से end तक) load करता है।
def crop(self, file, segment, mode="raise"):
# Frame offset calculate करो
start_frame = int(round(segment.start * sr)) # seconds → samples
num_frames = int(round(segment.duration * sr))
# Disk से सिर्फ वही हिस्सा load करो
waveform, original_sr = torchaudio.load(
path,
frame_offset=start_frame, # यहाँ से शुरू करो
num_frames=num_frames # इतने ही samples लो
)
क्यों सिर्फ segment load करते हैं?
एक 2-hour recording को पूरा RAM में load करना wasteful होगा। torchaudio.load() का frame_offset और num_frames parameter हमें सीधे disk से chunk load करने देता है।
3.2 SincNet — Feature Extraction
Conceptual Explanation
Raw waveform (numerical samples का sequence) directly neural network में देना inefficient है क्योंकि:
- High dimensionality — 10 seconds = 160,000 samples at 16kHz
- Temporal redundancy — adjacent samples बहुत similar होते हैं
SincNet एक learned filter bank है जो raw waveform से meaningful features extract करता है।
Mathematical Formulation
Standard Convolution vs Sinc Filters:
एक standard 1D convolution में:
जहाँ learnable parameters हैं।
SincNet में filter को constrained किया जाता है:
यह एक bandpass filter है जो frequencies से के बीच के signals को pass करता है।
Sinc function:
महत्व: इस formulation में केवल (lower cutoff) और (upper cutoff) learnable parameters हैं, जो filters को interpretable बनाते हैं। Standard convolution की तुलना में यह phonetically meaningful features सीखता है।
Architecture
Input: Waveform (1, T) — 1 channel, T samples
│
▼ SincFB (80 filters, kernel=251, stride=10)
▼ |x| (Absolute value — phase irrelevant है speech के लिए)
▼ MaxPool1d(3, stride=3)
▼ InstanceNorm1d + LeakyReLU
│
▼ Conv1d(80→60, kernel=5, stride=1)
▼ MaxPool1d(3, stride=3)
▼ InstanceNorm1d + LeakyReLU
│
▼ Conv1d(60→60, kernel=5, stride=1)
▼ MaxPool1d(3, stride=3)
▼ InstanceNorm1d + LeakyReLU
│
Output: (60, F) — 60 features, F frames
Stride calculation:
10-second audio at 16kHz → 160,000 samples
Total stride = 10 × 3 × 3 = 90
Output frames ≈ 160,000 / 90 ≈ 1778 frames
Instance Normalization क्यों?
Batch Normalization statistics को batch के across normalize करता है। Speech के लिए यह problematic है क्योंकि अलग-अलग recordings की amplitude बहुत अलग होती है।
InstanceNorm प्रत्येक sample को independently normalize करता है:
जहाँ और एक single sample के लिए calculate होते हैं।
Code Walk-through
class SincNet(nn.Module):
def __init__(self, sample_rate: int = 16000, stride: int = 1):
super().__init__()
self.wav_norm1d = nn.InstanceNorm1d(1, affine=True) # Input normalization
self.layers = nn.ModuleDict({
"conv": nn.ModuleList([
# Layer 1: SincFB (learnable bandpass filters)
Encoder(ParamSincFB(80, 251, stride=stride, sample_rate=sample_rate,
min_low_hz=50, min_band_hz=50)),
# Layer 2: Standard convolution
nn.Conv1d(80, 60, 5, stride=1),
# Layer 3: Standard convolution
nn.Conv1d(60, 60, 5, stride=1)
]),
"pool": nn.ModuleList([nn.MaxPool1d(3, stride=3) for _ in range(3)]),
"norm": nn.ModuleList([nn.InstanceNorm1d(c, affine=True) for c in [80, 60, 60]])
})
def forward(self, waveforms):
x = self.wav_norm1d(waveforms) # Input normalize करो
for i, (conv, pool, norm) in enumerate(zip(...)):
x = conv(x)
if i == 0:
x = torch.abs(x) # SincFB के बाद absolute value
x = F.leaky_relu(norm(pool(x))) # Pool → Norm → Activate
return x
num_frames और receptive_field methods:
यह methods बताते हैं कि N input samples से कितने output frames बनते हैं, और एक output frame कितने input samples को "देखता" है।
def conv1d_num_frames(num_samples, kernel_size=5, stride=1, padding=0, dilation=1):
return 1 + (num_samples + 2*padding - dilation*(kernel_size-1) - 1) // stride
यह standard convolution output size formula है:
जहाँ = padding, = dilation, = kernel size, = stride।
3.3 PyanNet — Segmentation Model
Role
PyanNet का काम है: प्रत्येक frame के लिए बताना कि कौन-से speakers active हैं।
यह एक sliding window approach में काम करता है — हर बार 10-second chunk process होता है।
Architecture
Input: Waveform chunk (batch, 1, T)
│
▼ SincNet
│ Output: (batch, F, 60) — F frames, 60 features
│
▼ Rearrange: (batch, F, 60) — LSTM input format
│
▼ Bidirectional LSTM (4 layers, hidden=128)
│ Output: (batch, F, 256) — 128×2 because bidirectional
│
▼ Linear(256→128) + LeakyReLU ×2
│
▼ Classifier Linear(128→num_powerset_classes)
│
▼ LogSoftmax (Powerset classification)
│
Output: (batch, F, num_powerset_classes)
LSTM की भूमिका
SincNet frame-level features निकालता है जो local हैं (context-free)।
LSTM temporal context add करता है — "पिछले और अगले frames को देखकर निर्णय लेना।"
Bidirectional LSTM:
- Forward LSTM: बाएँ से दाएँ process करता है
- Backward LSTM: दाएँ से बाएँ process करता है
- दोनों के outputs concatenate होते हैं
यह forward और backward context दोनों provide करता है।
LSTM equations (एक time step के लिए):
जहाँ element-wise multiplication है, sigmoid function है।
Code Walk-through
class PyanNet(nn.Module):
def forward(self, waveforms):
# Step 1: SincNet से features निकालो
outputs = rearrange(
self.sincnet(waveforms),
"batch feature frame -> batch frame feature" # LSTM के लिए shape change
)
# Step 2: LSTM से temporal context add करो
if self.l_cfg["monolithic"]:
outputs, _ = self.lstm(outputs) # Single multi-layer LSTM
else:
for i, layer in enumerate(self.lstm): # Layer-by-layer processing
outputs, _ = layer(outputs)
if i + 1 < self.l_cfg["num_layers"]:
outputs = self.dropout(outputs)
# Step 3: Linear layers
for lin in self.linears:
outputs = F.leaky_relu(lin(outputs))
# Step 4: Final classification
return self.activation_logic(outputs)
def activation_logic(self, x):
spec = self.specifications
if spec.problem == Problem.MONO_LABEL_CLASSIFICATION:
return F.log_softmax(self.classifier(x), dim=-1) # Powerset के लिए
# ...
3.4 Powerset Encoding
Problem: Multiple Speakers का Classification
एक naive approach: हर speaker के लिए एक binary classifier।
Problem: Overlapping speech में यह inefficient है।
Powerset क्या होता है?
अगर 3 speakers (A, B, C) हैं, तो possible states हैं:
-
{}— कोई नहीं बोल रहा -
{A}— सिर्फ A -
{B}— सिर्फ B -
{C}— सिर्फ C -
{A,B}— A और B दोनों -
{A,C}— A और C दोनों -
{B,C}— B और C दोनों
(Maximum 2 simultaneous speakers assume करके — {A,B,C} exclude किया)
Total = classes
यह एक multi-class classification problem बन जाता है बजाय multi-label के।
Mathematical Formulation
Powerset mapping matrix :
Multilabel to Powerset:
जहाँ multilabel indicator है।
Powerset to Multilabel:
या hard version में:
Code Explanation
class Powerset(nn.Module):
def __init__(self, num_classes: int, max_set_size: int):
super().__init__()
self.num_classes = num_classes # 3 speakers
self.max_set_size = max_set_size # max 2 simultaneous
# Mapping matrix register करो (trainable नहीं)
self.register_buffer("mapping", self.build_mapping(), persistent=False)
def build_mapping(self):
# Shape: (num_powerset_classes, num_classes)
mapping = torch.zeros(self.num_powerset_classes, self.num_classes)
powerset_k = 0
for set_size in range(0, self.max_set_size + 1):
for current_set in itertools.combinations(range(self.num_classes), set_size):
mapping[powerset_k, current_set] = 1 # इस set के speakers को 1 set करो
powerset_k += 1
return mapping
def to_multilabel(self, powerset: torch.Tensor, soft: bool = False):
if soft:
powerset_probs = torch.exp(powerset) # log_softmax → probabilities
else:
# Hard: argmax लो, one-hot बनाओ
powerset_probs = F.one_hot(torch.argmax(powerset, dim=-1), ...).float()
# Matrix multiplication: (T, P) × (P, C) = (T, C)
return torch.matmul(powerset_probs, self.mapping)
3.5 Binarization — Binarize और binarize Functions
Role
Segmentation model probabilities output करता है (0 से 1 के बीच)।
हमें binary decision चाहिए: "यह speaker active है या नहीं?"
Hysteresis Thresholding
Simple threshold (probability > 0.5 → active) noisy होती है — छोटे fluctuations की वजह से बार-बार on/off हो सकता है।
Hysteresis (Two-threshold) approach:
-
onsetthreshold (default: 0.5): inactive → active transition -
offsetthreshold (default: same as onset): active → inactive transition
State Machine:
INACTIVE ──(prob > onset)──→ ACTIVE
ACTIVE ──(prob < offset)──→ INACTIVE
onset = offset = 0.5 के साथ यह simple thresholding है।
लेकिन onset > offset रखने से:
- Activate होने के लिए ज़्यादा confidence चाहिए
- Deactivate होने के लिए ज़्यादा confidence चाहिए
- Result: Smoother, less chattery output
Mathematical Formulation
State transition:
Minimum Duration Post-processing
छोटे active/inactive segments को merge या remove किया जाता है:
-
min_duration_on: इससे छोटे active segments delete करो -
min_duration_off: इससे छोटे gaps को fill करो (consecutive active segments को merge)
numpy बनाम Annotation-based Binarization
Code में दो binarize functions हैं:
binarize()(numpy): SlidingWindowFeature को binary SlidingWindowFeature में convert करता है — segment-level use के लिए।Binarize.__call__()(Annotation): Frame-level binary output को timeline annotation में convert करता है।
class Binarize:
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
active = Annotation()
for k, k_scores in enumerate(scores.data.T): # हर speaker के लिए
label = k
start = timestamps[0]
is_active = k_scores[0] > self.onset # Initial state
for t, y in zip(timestamps[1:], k_scores[1:]):
if is_active:
if y < self.offset: # Active → Inactive
region = Segment(start, t)
active[region, track] = label
is_active = False
else:
if y > self.onset: # Inactive → Active
start = t
is_active = True
3.6 Speaker Count Estimation — speaker_count()
Purpose
दो level की जानकारी चाहिए:
- Segmentation: Frame-level में कौन-सा speaker index active है
- Count: उस frame में कितने speakers एक साथ बोल रहे हैं
Count, reconstruction में काम आता है — हम top-N speakers को active रखते हैं।
Process
@staticmethod
def speaker_count(binarized_segmentations, frames, warm_up=(0.1, 0.1)):
# Step 1: Warm-up regions trim करो
trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up)
# Step 2: हर frame में active speakers sum करो
# बinarized_segmentations shape: (num_chunks, num_frames, num_speakers)
# sum over last axis → (num_chunks, num_frames, 1)
count = Inference.aggregate(
np.sum(trimmed, axis=-1, keepdims=True),
frames,
hamming=False,
missing=0.0,
skip_average=False,
)
# Step 3: Round to nearest integer
count.data = np.rint(count.data).astype(np.uint8)
return count
Warm-up trimming क्यों?
Sliding window के edges पर model के predictions less reliable होते हैं क्योंकि उनके पास कम context होता है। इसलिए हर chunk के 10% left और 10% right को trim करते हैं।
3.7 WeSpeakerResNet34 — Speaker Embeddings
Role
यह model किसी भी audio segment को एक fixed-length vector (embedding) में convert करता है जो उस speaker की identity represent करता है।
Desired property:
- Same speaker के दो अलग segments → similar vectors
- Different speakers के segments → dissimilar vectors
Feature Extraction: Filter Bank (FBank)
ResNet raw waveform नहीं लेता, बल्कि Mel Filter Bank features लेता है।
Mel Scale:
Human ear सभी frequencies को equally नहीं सुनता — low frequencies में ज़्यादा sensitive है। Mel scale इसे mimic करती है:
STFT (Short-Time Fourier Transform):
पहले audio को overlapping windows में divide किया जाता है:
जहाँ Hamming window है, hop length है।
Filter Bank:
STFT के power spectrum पर triangular mel filters apply होते हैं:
जहाँ -th mel filter है।
Log Compression:
यह dynamic range compress करता है और human auditory perception को mimic करता है।
Code:
def compute_fbank(self, waveforms):
waveforms = waveforms * (1 << 15) # Normalize to 16-bit range
features = torch.vmap(kaldi.fbank)(
waveforms.unsqueeze(1),
num_mel_bins=80,
window_type="hamming"
)
return features - torch.mean(features, dim=1, keepdim=True) # Mean subtraction
ResNet Architecture
Input: FBank features — shape (batch, time, 80)
Layers:
Input: (B, T, 80)
↓ permute → (B, 80, T) → unsqueeze → (B, 1, 80, T)
Conv2d(1, 32, 3×3) → BN → ReLU
↓
ResBlock × 3 (32 filters)
↓
ResBlock × 4 (64 filters, stride=2)
↓
ResBlock × 6 (128 filters, stride=2)
↓
ResBlock × 3 (256 filters, stride=2)
↓
TSTP Pooling (Time-aware Statistics Pooling)
↓
Linear(dim→256)
↓
Output: (B, 256) — 256-dim embedding
Residual Block
class BasicBlock(nn.Module):
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x))) # Conv → BN → ReLU
out = self.bn2(self.conv2(out)) # Conv → BN (no ReLU yet)
out = out + self.shortcut(x) # Residual connection
return F.relu(out) # अब ReLU
Residual Connection:
जहाँ दो convolutions का output है। यह deep networks में gradient flow को improve करता है।
Statistics Pooling (TSTP)
Variable length temporal sequence को fixed-length vector में convert करना।
Input: (B, C, T) — batch, channels, time
Output: (B, 2C) — mean और std concatenate
Formula:
Weighted version (mask के साथ):
जब segmentation mask दिया जाता है, तो सिर्फ active frames का contribute होता है:
जहाँ mask value है (0 या 1)।
3.8 VBx Clustering
Overview
Clustering का काम है: विभिन्न chunks के embeddings को देखकर यह decide करना कि कौन-सी embeddings एक ही speaker की हैं।
तीन stages हैं:
- AHC (Agglomerative Hierarchical Clustering) — initial coarse clustering
- VBx (Variational Bayes × HMM) — refined probabilistic clustering
- KMeans (optional) — अगर requested speaker count अलग हो
Stage 1: AHC (Agglomerative Hierarchical Clustering)
Algorithm:
- शुरुआत में हर embedding एक अलग cluster होती है
- सबसे similar दो clusters को merge करो
- जब तक threshold न आए, repeat करो
Cosine Distance:
Similar speakers के vectors का cosine similarity high होगा → distance कम।
Linkage Method (Centroid):
जहाँ
cluster centroids हैं।
train_embeddings_normed = train_embeddings / np.linalg.norm(
train_embeddings, axis=1, keepdims=True
) # L2 normalize
dendrogram = linkage(
train_embeddings_normed,
method="centroid",
metric="euclidean"
)
ahc_clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1
Stage 2: PLDA Transform
VBx को directly embeddings पर नहीं apply करते — पहले PLDA (Probabilistic Linear Discriminant Analysis) transform apply करते हैं।
PLDA का उद्देश्य:
Speaker space और channel/noise space को separate करना।
LDA Transform:
जहाँ LDA matrix है, mean है।
L2 Normalization:
PLDA Projection:
जहाँ पहले
dimensions रखे जाते हैं।
def vbx_setup(t_npz, p_npz):
# Transform matrices load करो
m1, m2, lda = t["mean1"], t["mean2"], t["lda"]
mu, tr, psi = p["mu"], p["tr"], p["psi"]
# LDA transform function
xvec_tf = lambda x: np.sqrt(lda.shape[1]) * l2_norm(
(lda.T @ (np.sqrt(lda.shape[0]) * l2_norm(x - m1)).T).T - m2
)
# PLDA projection function
plda_tf = lambda x, lda_dim=lda.shape[1]: ((x - mu) @ tr.T)[:, :lda_dim]
return xvec_tf, plda_tf, psi
Stage 3: VBx Algorithm
VBx एक probabilistic model है जो speaker assignments को soft probabilities के रूप में estimate करता है।
Generative Model Assumptions:
हर observation (frame t का embedding) को इस प्रकार model किया जाता है:
जहाँ:
- — speaker assignment (latent variable)
- — speaker का embedding (भी latent)
- — between-speaker variance (PLDA से)
- — scaling factors
ELBO Optimization:
VBx variational inference के through speaker assignments optimize करता है:
इसे iteratively maximize किया जाता है।
E-step (responsibilities update):
M-step (speaker model update):
जहाँ:
- = frame का speaker के प्रति responsibility
- = speaker का prior weight
Code:
def VBx(X, Phi, Fa=1.0, Fb=1.0, pi=10, gamma=None, maxIters=10, epsilon=1e-4, ...):
D = X.shape[1]
G = -0.5 * (np.sum(X**2, axis=1, keepdims=True) + D * np.log(2 * np.pi))
V = np.sqrt(Phi) # Between-speaker variance
rho = X * V # Scaled observations
for ii in range(maxIters):
# M-step: Speaker models update करो
invL = 1.0 / (1 + Fa/Fb * gamma.sum(axis=0, keepdims=True).T * Phi)
alpha = Fa/Fb * invL * gamma.T.dot(rho)
# E-step: Responsibilities update करो
log_p_ = Fa * (rho.dot(alpha.T) - 0.5 * (invL + alpha**2).dot(Phi) + G)
lpi = np.log(pi + 1e-8)
log_p_x = logsumexp(log_p_ + lpi, axis=-1)
gamma = np.exp(log_p_ + lpi - log_p_x[:, None])
# Prior update
pi = gamma.sum(axis=0)
pi = pi / pi.sum()
# ELBO compute करो
ELBO = log_p_x.sum() + Fb * 0.5 * np.sum(np.log(invL) - invL - alpha**2 + 1)
# Convergence check
if ii > 0 and ELBO - Li[-2][0] < epsilon:
break
return gamma, pi, Li, alpha, invL
Output gamma: Shape (T, K) — हर frame की हर speaker के प्रति responsibility।
Centroids Calculation
VBx के बाद, जिन speakers का prior
है उन्हें रखते हैं:
W = q[:, sp > 1e-7] # Valid speakers की responsibilities
centroids = W.T @ train_embeddings.reshape(-1, dimension) / W.sum(0, keepdims=True).T
यह weighted average है:
Constrained Assignment
Clustering के बाद embeddings को clusters assign करना होता है। Naive argmax problematic है क्योंकि एक ही chunk में एक speaker को multiple segments में assign नहीं करना चाहिए।
Hungarian Algorithm (Linear Sum Assignment):
यह optimal assignment problem solve करता है: हर speaker को exactly एक cluster assign करो जिससे total similarity maximize हो।
Subject to:
bijection होनी चाहिए (injective mapping)।
def constrained_argmax(self, soft_clusters):
# soft_clusters shape: (num_chunks, num_speakers, num_clusters)
for c, cost in enumerate(soft_clusters):
# Hungarian algorithm
speakers, clusters = linear_sum_assignment(cost, maximize=True)
for s, k in zip(speakers, clusters):
hard_clusters[c, s] = k
return hard_clusters
3.9 Label Assignment और Reconstruction
reconstruct() Method
Hard cluster assignments को वापस frame-level segmentation format में convert करना।
def reconstruct(self, segmentations, hard_clusters, count):
# Output: (num_chunks, num_frames, num_clusters)
clustered_segmentations = np.zeros((num_chunks, num_frames, num_clusters))
for c, (cluster, (_, segmentation)) in enumerate(zip(hard_clusters, segmentations)):
for k in np.unique(cluster):
if k == -2: continue # Inactive speakers skip
# इस cluster के सभी speakers के max probability लो
clustered_segmentations[c, :, k] = np.max(
segmentation[:, cluster == k], axis=1
)
return self.to_diarization(clustered_segmentations, count)
to_diarization() Method
Aggregated scores को binary diarization में convert करना, speaker count को constraint के रूप में उपयोग करते हुए।
@staticmethod
def to_diarization(segmentations, count):
# Step 1: Overlap-add aggregation
activations = Inference.aggregate(segmentations, ...)
# Step 2: हर frame में count के according top speakers रखो
sorted_speakers = np.argsort(-activations, axis=-1) # Score के हिसाब से sort
for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)):
for i in range(c.item()): # सिर्फ c speakers रखो
binary[t, speakers[i]] = 1.0
return SlidingWindowFeature(binary, activations.sliding_window)
Final Output Format
return {
"diarization": [
{"start": 0.0, "end": 2.5, "speaker": 0},
{"start": 2.5, "end": 5.1, "speaker": 1},
...
],
"exclusive_diarization": [...], # Overlap के बिना
"speaker_embeddings": centroids # (N_speakers, 256) numpy array
}
4. गणितीय समझ (Mathematical Intuition)
4.1 Overlap-Add Aggregation
Sliding window से आई predictions को aggregate करने की ज़रूरत होती है। हर frame पर multiple chunks contribute करते हैं।
Formula:
जहाँ:
- = chunk में frame की prediction
- = warm-up window (edges पर , center पर 1)
- = Hamming window (smooth aggregation के लिए)
Hamming Window:
यह edges पर 0 और center पर 1 के करीब होती है, जिससे chunk boundaries पर artifacts कम होते हैं।
4.2 Cosine Similarity vs Euclidean Distance
Embedding space में similarity measure करने के दो तरीके:
Cosine Similarity:
L2-normalized vectors पर Euclidean Distance:
जहाँ ।
यानी L2-normalized vectors पर Euclidean distance और Cosine distance equivalent हैं। इसीलिए AHC में L2-normalize करके Euclidean metric use करते हैं।
4.3 Evidence Lower Bound (ELBO)
VBx में exact inference intractable है। Variational inference में हम एक tractable distribution find करते हैं जो true posterior को approximate करे।
ELBO maximize करना = KL divergence minimize करना = better approximation।
5. Deep Learning की सरल व्याख्या
5.1 Neural Networks Embeddings कैसे बनाते हैं
एक neural network एक function है:
Training के दौरान parameters इस प्रकार adjust होते हैं कि:
- Same speaker के inputs → similar outputs (close in vector space)
- Different speaker के inputs → dissimilar outputs (far in vector space)
यह speaker verification loss (जैसे Additive Margin Softmax) से achieve होता है।
5.2 Training vs Inference
Training Phase:
- Labeled data से gradient descent के through parameters सीखे जाते हैं
- Loss function speaker identity की correctness measure करती है
Inference Phase:
- Frozen parameters (no gradient computation)
- Input → Forward pass → Output
-
torch.inference_mode()memory और speed optimize करता है
5.3 Vector Space Representation
256-dimensional embedding space में:
Geometric intuition:
Same speaker के embeddings एक cluster बनाते हैं।
Different speakers के clusters अलग-अलग regions में होते हैं।
यही property clustering को possible बनाती है।
5.4 Sliding Window Inference
पूरी audio एक साथ process करना:
- Memory-intensive (hours की recording)
- Model को fixed-size input चाहिए
Solution:
Audio: ─────────────────────────────────────────
[ chunk 1 ]
[ chunk 2 ]
[ chunk 3 ]
[ chunk 4 ]
हर chunk independently process होता है। Predictions को overlap-add से combine किया जाता है।
6. Line-by-Line Code Explanation
6.1 SpeakerDiarization.__init__()
class SpeakerDiarization(nn.Module):
def __init__(self, ...):
super().__init__()
# 1. Segmentation model बनाओ
self._segmentation_model = PyanNet(
specifications=Specifications(
Problem.MONO_LABEL_CLASSIFICATION, # Powerset classification
Resolution.FRAME, # Frame-level output
10.0, # 10-second chunks
classes=['speaker#1', 'speaker#2', 'speaker#3'], # 3 speakers max
powerset_max_classes=2, # Max 2 simultaneous
permutation_invariant=True # Speaker order arbitrary
)
)
# 2. Embedding model बनाओ
self._embedding = WeSpeakerResNet34()
# 3. PLDA बनाओ
self._plda = PLDA(
hf_hub_download(..., "plda/xvec_transform.npz"),
hf_hub_download(..., "plda/plda.npz")
)
# 4. Clustering engine
self.clustering = VBxClustering(self._plda)
# 5. Inference wrapper (sliding window handle करता है)
self._segmentation = Inference(
self._segmentation_model,
duration=10.0, # 10-sec windows
step=self.segmentation_step * 10.0, # 0.1 × 10 = 1-sec step
skip_aggregation=True, # Raw chunk outputs चाहिए
)
# 6. Pretrained weights load करो
self._segmentation_model.load_state_dict(
tb.state_bridge(load_file(...), """
linear,linears # Key renaming rules
conv1d,layers.conv
.norm1d,.layers.norm
""")
)
tb.state_bridge() क्या करता है?
Pretrained checkpoint के key names और current model के key names अलग हो सकते हैं। यह function mapping rules लेकर keys rename करता है।
6.2 forward() Method — Main Pipeline
@torch.inference_mode()
def forward(self, file, num_speakers=None, min_speakers=None, max_speakers=None):
# Step 1: Speaker count constraints setup करो
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
# Default: min=1, max=∞
# Step 2: Segmentation चलाओ
segmentations = self._segmentation(file)
# Output: SlidingWindowFeature, shape (num_chunks, num_frames, num_powerset_classes)
# Step 3: Binarize करो
binarized_segmentations = binarize(segmentations, initial_state=False)
# Output: SlidingWindowFeature, shape (num_chunks, num_frames, num_speakers)
# Step 4: Speaker count estimate करो
count = self.speaker_count(
binarized_segmentations,
self._segmentation_model.receptive_field,
warm_up=(0.0, 0.0),
)
# Output: SlidingWindowFeature, shape (num_frames, 1)
# Early exit: अगर कोई speaker नहीं
if np.nanmax(count.data) == 0.0:
return
# Step 5: Embeddings निकालो
embeddings = self.get_embeddings(
file,
binarized_segmentations,
exclude_overlap=self.embedding_exclude_overlap
)
# Output: numpy array, shape (num_chunks, num_speakers, 256)
# Step 6: Cluster करो
hard_clusters, _, centroids = self.clustering(
embeddings=embeddings,
segmentations=binarized_segmentations,
num_clusters=num_speakers,
min_clusters=min_speakers,
max_clusters=max_speakers,
)
# Output: hard_clusters shape (num_chunks, num_speakers)
# Step 7: Inactive speakers handle करो
inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
hard_clusters[inactive_speakers] = -2 # -2 = inactive marker
# Step 8: Reconstruct timeline
discrete_diarization = self.reconstruct(segmentations, hard_clusters, count)
# Step 9: Annotation format में convert करो
diarization = self.to_annotation(discrete_diarization)
# Step 10: Labels rename करो (0→SPEAKER_00, 1→SPEAKER_01, ...)
mapping = {label: expected for label, expected in zip(diarization.labels(), self.classes())}
diarization = diarization.rename_labels(mapping=mapping)
return {"diarization": [...], "exclusive_diarization": [...], "speaker_embeddings": centroids}
6.3 get_embeddings() — EmbeddingDataset का उपयोग
def get_embeddings(self, file, binary_segmentations, exclude_overlap=False):
# Step 1: Clean segmentations बनाओ (overlap-free regions)
if exclude_overlap:
# Single speaker वाले frames identify करो
clean_frames = 1.0 * (np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2)
clean_segmentations = SlidingWindowFeature(
binary_segmentations.data * clean_frames,
binary_segmentations.sliding_window,
)
# Step 2: Dataset और DataLoader बनाओ
dataset = EmbeddingDataset(
file=file,
binary_segmentations=binary_segmentations,
clean_segmentations=clean_segmentations,
audio=self._audio,
min_num_frames=min_num_frames,
)
loader = DataLoader(dataset, batch_size=self.embedding_batch_size, pin_memory=True)
# Step 3: Embeddings pre-allocate करो
embeddings = np.zeros((num_chunks, num_speakers, self._embedding.dimension))
# Step 4: Batch-by-batch inference
for waveforms, masks, chunk_idxs, speaker_idxs in tqdm(loader):
waveforms = waveforms.to(device) # GPU पर move करो
masks = masks.to(device)
# Forward pass
batch_embeddings = self._embedding(waveforms, masks) # (B, 256)
# सही position पर store करो
embeddings[chunk_idxs.numpy(), speaker_idxs.numpy()] = batch_embeddings.cpu().numpy()
return embeddings
6.4 EmbeddingDataset.__getitem__()
def __getitem__(self, idx):
chunk_idx, speaker_idx = self.indices[idx]
# RAM से directly slice करो (disk I/O avoid)
start_sample = round(self.sliding_window[chunk_idx].start * self.sample_rate)
end_sample = start_sample + self.window_size
waveform = self.waveform[:, start_sample:end_sample] # (1, window_size)
# Last chunk के लिए padding
if waveform.shape[1] < self.window_size:
waveform = F.pad(waveform, (0, self.window_size - waveform.shape[1]))
mask = self.seg_data[chunk_idx, :, speaker_idx] # Segmentation mask
clean_mask = self.clean_data[chunk_idx, :, speaker_idx] # Overlap-free mask
# Prefer clean mask (overlap-free), otherwise use regular mask
used_mask = clean_mask if np.sum(clean_mask) > self.min_num_frames else mask
return (
waveform.squeeze(0), # (window_size,)
torch.from_numpy(used_mask), # (num_frames,)
chunk_idx,
speaker_idx,
)
Design Decision:
पूरी audio को पहले RAM में load करते हैं (self.waveform), फिर __getitem__ में RAM से slice करते हैं। यह disk I/O को बार-बार करने से बचाता है और DataLoader workers के साथ efficient है।
7. प्रायोगिक समझ (Practical Insights)
7.1 Libraries का Role
| Library | Role |
|---|---|
torchaudio |
Audio loading, resampling, FBank computation |
einops |
Tensor reshaping (readable notation) |
scipy |
AHC clustering, Voronoi, signal processing |
sklearn |
KMeans clustering |
pyannote.core |
Speech timeline data structures |
huggingface_hub |
Pretrained model weights download |
safetensors |
Secure, fast model weight format |
torch_state_bridge |
State dict key mapping |
asteroid_filterbanks |
SincFB implementation |
7.2 Memory Efficiency
Problem: 1-hour audio at 16kHz = 16000 × 3600 = 57.6M samples × 4 bytes = ~230 MB सिर्फ waveform के लिए।
Solutions implemented:
-
Lazy Segmentation Processing:
Inference.slide()streaming batch processing करता है - EmbeddingDataset: पूरी audio एक बार load, फिर in-memory slicing
-
pin_memory=Truein DataLoader: CPU-to-GPU transfer तेज़ होती है -
@torch.inference_mode(): Gradient graph नहीं बनता → कम memory
7.3 GPU vs CPU
device = next(self.parameters()).device # Model जहाँ है, वहाँ data भेजो
waveforms = waveforms.to(device)
self.conversion.to(device)
Inference GPU पर fast है, लेकिन clustering CPU पर (numpy)। इसीलिए embeddings CPU पर वापस आते हैं:
batch_embeddings = batch_embeddings.cpu().numpy()
7.4 Sliding Window Parameters का Impact
| Parameter | Default | Effect |
|---|---|---|
duration |
10s | ज़्यादा → better context, ज़्यादा memory |
step |
1s (0.1×10) | कम → smoother output, ज़्यादा computation |
warm_up |
(0.1, 0.1) | ज़्यादा → edges ignore, कम overlap artifacts |
7.5 segmentation_step = 0.1 का Meaning
step = self.segmentation_step * segmentation_duration
= 0.1 × 10.0 = 1.0 seconds
हर chunk 10 seconds का है। Consecutive chunks के बीच 1 second का step है।
यानी हर chunk अपने पड़ोसी chunk के साथ 9 seconds overlap share करता है।
यह high overlap दो कारणों से ज़रूरी है:
- Smooth aggregation
- Boundary regions में accurate predictions
8. सामान्य गलतियाँ और समाधान
8.1 Sample Rate Mismatch
गलती:
# 8kHz audio को directly SincNet में दे दिया
model(audio_8khz) # SincNet केवल 16kHz support करता है!
समाधान:
# Audio class हमेशा resample करती है
audio = Audio(sample_rate=16000, mono="downmix")
waveform, sr = audio(file) # Automatic resampling
Symptom: Poor segmentation accuracy, या NotImplementedError.
8.2 Wrong Tensor Shape
गलती:
# Waveform shape (16000,) है, model को (1, 16000) चाहिए
model(waveform) # Error या wrong output
समाधान:
waveform = waveform.unsqueeze(0) # (16000,) → (1, 16000) — channel dimension add
# या
model(waveform[None]) # (1, 16000) → (1, 1, 16000) with batch dimension
8.3 GPU-CPU Tensor Mixing
गलती:
embeddings_cpu = embeddings.numpy() # Error! GPU tensor को numpy नहीं बना सकते
समाधान:
embeddings_cpu = embeddings.cpu().numpy() # पहले CPU पर, फिर numpy
8.4 Inactive Speaker Handling
गलती:
# `-2` cluster को valid speaker मान लेना
for k in np.unique(hard_clusters):
process_speaker(k) # -2 को भी process करेगा!
समाधान:
for k in np.unique(hard_clusters):
if k == -2: continue # Inactive marker skip करो
process_speaker(k)
8.5 Short Audio Files
समस्या: 10-second से छोटी audio पर Inference.slide() कोई chunk नहीं बनाता।
Code में handling:
has_last_chunk = (num_samples < window_size) or ...
if has_last_chunk:
last_chunk = waveform[:, num_chunks * step_size:]
last_pad = window_size - last_window_size
last_chunk = F.pad(last_chunk, (0, last_pad)) # Zero padding
Recommendation: Audio कम से कम 1-2 seconds की होनी चाहिए।
8.6 Memory Error for Long Files
समस्या: बहुत लंबी recordings पर EmbeddingDataset बनाते समय OOM error।
Debug approach:
# Check करो कितने embedding pairs हैं
num_chunks = binary_segmentations.data.shape[0]
num_speakers = binary_segmentations.data.shape[2]
total_samples = num_chunks * num_speakers # यह बड़ा हो सकता है!
समाधान:
embedding_batch_size कम करो, या audio को segments में process करो।
9. निष्कर्ष (Conclusion)
9.1 पूरे Pipeline का सार
Raw Audio
→ Audio Loading (16kHz mono)
→ SincNet (learned filterbank features)
→ PyanNet/LSTM (frame-level speaker activity)
→ Powerset Binarization (binary segment masks)
→ Speaker Count (per-frame active speaker count)
→ WeSpeakerResNet34 (256-dim speaker embeddings)
→ PLDA Transform (speaker-discriminant space)
→ AHC Initial Clustering
→ VBx Refinement (probabilistic, EM-based)
→ Constrained Assignment (Hungarian algorithm)
→ Timeline Reconstruction
→ Final Annotation: {start, end, speaker_id}
9.2 Key Design Decisions की Summary
| Decision | Reason |
|---|---|
| Powerset encoding | Overlap को single classification problem में handle |
| SincNet (vs MFCC) | End-to-end learnable, phonetically interpretable |
| Bidirectional LSTM | Left और right context दोनों |
| Weighted statistics pooling | Silence/overlap frames को ignore |
| VBx over pure AHC | Soft assignments, automatic speaker count |
| Hungarian assignment | Chunk-level consistency |
| Overlap-add aggregation | Smooth predictions across chunk boundaries |
9.3 संभावित सुधार (Future Scope)
1. End-to-end Training:
वर्तमान में segmentation और embedding models अलग-अलग trained हैं। Joint training से performance improve हो सकती है।
2. Streaming Inference:
Real-time applications के लिए online (non-buffered) processing।
3. Multi-channel Audio:
Multiple microphones की spatial information से better separation।
4. Speaker Adaptation:
अगर किसी speaker के reference recordings हों, तो उन्हें incorporate करके accuracy बढ़ाना।
5. Neural Clustering:
VBx को neural end-to-end clustering से replace करना।
9.4 इस ब्लॉग के बाद क्या पढ़ें?
- Pyannote Audio paper (Bredin et al.) — segmentation model की original paper
- WeSpeaker (Wang et al.) — embedding model
- VBx paper (Diez et al.) — clustering algorithm
- PLDA (Prince & Elder) — probabilistic discriminant analysis
- End-to-End Neural Diarization (Fujita et al.) — alternative approach
यह ब्लॉग उस codebase पर आधारित है जो pyannote-community/speaker-diarization-community-1 और shethjenil/speaker-diarization HuggingFace repositories के pretrained models उपयोग करता है।
Top comments (0)