DEV Community

Alex Retana


Cracking the Medical Coding Challenge: Fine-Tuning BioBERT for ICD-10 Classification (Part 1)

The Problem That Keeps Medical Coders Up at Night

Imagine you're processing disability claims for veterans. Each claim contains dense medical documentation—thousands of characters describing symptoms, diagnoses, and treatment history. Your job? Extract the correct ICD-10 diagnostic codes from this narrative. Miss a code, and a veteran might not receive the benefits they've earned. Add an incorrect code, and you've created compliance issues.

Now imagine doing this hundreds of times per day, under pressure, with 158+ possible diagnosis codes to remember.

This is exactly the type of problem that makes medical coding both critically important and incredibly challenging. And it's the perfect use case for Natural Language Processing (NLP). But here's the catch: training an AI to do this isn't straightforward, especially when you're dealing with limited training data and severe class imbalance.

In this two-part series, I'll walk you through building an automated medical coding system. Part 1 (this article) focuses on fine-tuning BioBERT with advanced techniques to handle real-world constraints. Part 2 will explore AWS Comprehend Medical as an alternative approach and compare the two solutions.

🔗 GitHub Repository

Why This Project Matters: Real-World Use Cases

Before diving into code, let's talk about why automated medical coding matters:

1. Disability Claims Processing

Veterans Affairs (VA) processes millions of disability claims. Each claim requires accurate ICD-10 coding to determine eligibility and compensation levels. Manual coding creates bottlenecks and inconsistencies.

2. Healthcare Revenue Cycle Management

Hospitals lose billions annually due to coding errors. Automated coding assistance can flag potential issues before claims are submitted to insurance companies.

3. Clinical Research

Large-scale medical studies require consistent coding of patient records. Automated extraction enables researchers to identify patient cohorts more efficiently.

4. Compliance and Auditing

Healthcare organizations must ensure coding accuracy for regulatory compliance. AI systems can audit existing codes and identify discrepancies.

The Dataset: MedCodER and Its Challenges

For this project, we're using the MedCodER (Medical Coding with Explanations and Retrievals) dataset, which contains:

  • 500+ clinical documents with full SOAP notes (Subjective, Objective, Assessment, Plan)
  • 158 unique ICD-10-CM codes
  • Supporting evidence annotations showing which text spans support each diagnosis
  • Severe class imbalance: Most codes appear fewer than 10 times

Here's what makes this dataset challenging (and realistic):

# Class distribution snapshot
Total unique codes: 158
Codes with ≥80 samples: 18  # Only 11% have sufficient training data!
Codes with ≥50 samples: 25
Codes with <10 samples: 98  # 62% are extremely rare

This mirrors real-world medical data closely: common conditions like diabetes and hypertension appear frequently, while rare diseases have minimal examples.
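A snapshot like the one above can be produced with pandas `value_counts`. This is a minimal sketch, assuming one row per (document, code) annotation with the code stored in an `ICD10` column (the column name used in the filtering step later); the toy codes and counts are made up, not MedCodER data, and `class_distribution` is a hypothetical helper name.

```python
import pandas as pd


def class_distribution(codes: pd.Series) -> dict:
    """Count how many distinct codes fall into each frequency bucket."""
    freq = codes.value_counts()  # number of samples per code
    return {
        "total_unique_codes": int(freq.size),
        "codes_with_>=80": int((freq >= 80).sum()),
        "codes_with_>=50": int((freq >= 50).sum()),
        "codes_with_<10": int((freq < 10).sum()),
    }


# Toy data (hypothetical): one frequent code, one mid-frequency code, two rare codes
toy = pd.DataFrame({
    "ICD10": ["I10"] * 90 + ["E11.9"] * 55 + ["J45.909"] * 3 + ["K21.9"] * 1
})
summary = class_distribution(toy["ICD10"])
for name, count in summary.items():
    print(f"{name}: {count}")
```

Note that the buckets overlap by design: a code with 90 samples counts toward both the ≥80 and ≥50 rows.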

The Naive Approach (And Why It Fails Spectacularly)

Let's talk about what doesn't work. Your first instinct might be:

  1. Take full 2000+ character clinical documents
  2. Feed them to BioBERT
  3. Train on all 158 classes
  4. Hope for the best

Result: Macro F1 score of 0.023 (2.3%), barely better than random guessing.

Why does this fail?

Problem 1: Signal Dilution
A 2000-character document might contain only 50-100 characters actually describing a specific diagnosis. The rest is noise—patient demographics, vital signs, medication lists, etc.

Problem 2: Insufficient Training Data
With only 500 documents and 158 classes, you have an average of ~3 examples per class. Deep learning models need orders of magnitude more data.

Problem 3: Catastrophic Overfitting
BioBERT has 110 million parameters. Training all of them on tiny datasets causes the model to memorize training examples rather than learn generalizable patterns.

The Solution: A Five-Pronged Strategy

To achieve a 94.4% Macro F1 score (a 4,000% improvement!), we implement five key techniques:

1. Evidence-Focused Training

2. Label Space Optimization

3. Back-Translation Data Augmentation

4. LoRA Parameter-Efficient Fine-Tuning

5. Class-Weighted Loss Function

Let's dive into each one.


Technique 1: Evidence-Focused Training

The Problem: Training on 2000-character documents dilutes the diagnostic signal.

The Solution: Use the supporting evidence annotations to extract focused diagnostic spans (~150-200 characters) with context.

def extract_evidence_text(row):
    """Extract evidence span from full document text"""
    start = int(row['Start'])
    end = int(row['End'])

    # Extract with ±50 character context window
    context_start = max(0, start - 50)
    context_end = min(len(row['medical_record_text']), end + 50)

    return row['medical_record_text'][context_start:context_end]

Why this works: We're giving the model concentrated diagnostic information. Instead of finding a needle in a haystack, we're handing it the needle.
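To build the evidence-focused dataset, the extraction function is applied row by row. A minimal sketch, assuming a DataFrame with the `medical_record_text`, `Start`, `End`, and `ICD10` columns the function above relies on; the context window is exposed as a `window` parameter here, and the one-row `notes` frame and `evidence_text` column name are hypothetical.

```python
import pandas as pd


def extract_evidence_text(row, window=50):
    """Return the annotated evidence span plus `window` characters on each side."""
    text = row['medical_record_text']
    start, end = int(row['Start']), int(row['End'])
    return text[max(0, start - window):min(len(text), end + window)]


# Toy note (hypothetical): filler, one diagnosis sentence, more filler
diagnosis = "Diagnosis: Essential (primary) hypertension."
note = "Vitals stable. " * 10 + diagnosis + " Follow up in 3 months." * 5
span_start = note.index("Diagnosis")

notes = pd.DataFrame([{
    'medical_record_text': note,
    'Start': span_start,
    'End': span_start + len(diagnosis),
    'ICD10': 'I10',
}])

# axis=1 passes each row (as a Series) to the function
notes['evidence_text'] = notes.apply(extract_evidence_text, axis=1)
print(len(note), len(notes.loc[0, 'evidence_text']))
print(notes.loc[0, 'evidence_text'])
```

The 44-character diagnosis sentence comes back with 50 characters of context on each side, while the rest of the note is dropped.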

Example transformation:

Full Document (2,347 chars):

[Long patient history, demographics, vitals, multiple conditions mixed together...]

Evidence Span (189 chars):

"...blood pressure remains elevated at 156/94 despite medication compliance. 
Diagnosis: Essential (primary) hypertension. Will increase lisinopril dose..."

Consequence of skipping this step:

Without evidence extraction, the model struggles to differentiate signal from noise. You'd see F1 scores plateau around 20-30% even with other optimizations.


Technique 2: Label Space Optimization

The Problem: 62% of codes have fewer than 10 training examples—impossible to learn from.

The Solution: Filter to codes with ≥80 examples, reducing from 158 codes to 18 viable classes.

MIN_SAMPLES = 80
code_freq = evidence_focused['ICD10'].value_counts()
frequent_codes = code_freq[code_freq >= MIN_SAMPLES].index.tolist()

evidence_filtered = evidence_focused[
    evidence_focused['ICD10'].isin(frequent_codes)
].reset_index(drop=True)

print(f"Reduced to {len(frequent_codes)} codes")  # 18 codes
print(f"Retained {len(evidence_filtered)} examples")  # ~1,200 examples

Why this works: Machine learning requires sufficient examples to learn patterns. By focusing on codes with adequate representation, we ensure the model can actually learn meaningful relationships.
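The later steps (class weights over `label_id`, a classifier head with `num_labels=18`) expect integer labels, a mapping the snippet above does not show. Here is a minimal sketch under that assumption: `filter_and_encode` and the toy data are hypothetical, while the `label_id` column name matches the one used in the class-weight code.

```python
import pandas as pd


def filter_and_encode(df, min_samples, code_col='ICD10'):
    """Keep codes with at least `min_samples` rows and add an integer `label_id` column."""
    code_freq = df[code_col].value_counts()
    frequent_codes = sorted(code_freq[code_freq >= min_samples].index)

    filtered = df[df[code_col].isin(frequent_codes)].reset_index(drop=True)

    # Sorted, deterministic mapping so a code always gets the same id across runs
    label2id = {code: i for i, code in enumerate(frequent_codes)}
    id2label = {i: code for code, i in label2id.items()}
    filtered['label_id'] = filtered[code_col].map(label2id)
    return filtered, label2id, id2label


toy = pd.DataFrame({'ICD10': ['I10'] * 5 + ['E11.9'] * 4 + ['K21.9'] * 1})
filtered, label2id, id2label = filter_and_encode(toy, min_samples=4)
print(label2id)
print(filtered['label_id'].value_counts().sort_index().to_dict())
```

The same `id2label`/`label2id` dictionaries can be passed to `AutoModelForSequenceClassification.from_pretrained` so that saved models report ICD-10 codes instead of bare indices.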

The trade-off: We sacrifice coverage (18 codes vs. 158) for accuracy. This is acceptable in a hybrid system where:

  • Custom model handles frequent codes (high accuracy)
  • Commercial API handles rare codes (broader coverage, lower accuracy)

Consequence of skipping this step:

Including rare codes creates extreme class imbalance. The model would:

  • Ignore rare classes entirely (predicting only common ones)
  • Waste capacity trying to memorize insufficient examples
  • Achieve poor performance across all classes

Technique 3: Back-Translation Data Augmentation

The Problem: Even after filtering, we only have ~1,200 training examples for 18 classes (~67 examples per class). Still limited.

The Solution: Use back-translation to generate synthetic training data.

from functools import lru_cache

from transformers import MarianMTModel, MarianTokenizer


@lru_cache(maxsize=None)
def load_translator(src, tgt):
    """Load a MarianMT tokenizer/model pair once and reuse it across calls"""
    name = f'Helsinki-NLP/opus-mt-{src}-{tgt}'
    return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)


def translate(text, src, tgt):
    tokenizer, model = load_translator(src, tgt)
    inputs = tokenizer(text, return_tensors='pt', truncation=True)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def back_translate(text, pivot_lang='de'):
    """Translate EN→pivot→EN (default pivot: German) to create a paraphrased version"""
    pivot_text = translate(text, 'en', pivot_lang)  # EN → German
    return translate(pivot_text, pivot_lang, 'en')  # German → EN

Example transformation:

Original:

"Patient reports persistent chest pain radiating to left arm with 
shortness of breath during physical exertion."

After EN→DE→EN:

"Patient experiences continuous chest pain extending to the left arm 
with breathing difficulty during physical activity."

Why this works: The clinical meaning is largely preserved, but the phrasing varies. This teaches the model to recognize diagnoses regardless of how they're worded, which is critical for handling real-world clinical variation.

Best practice: Use multiple pivot languages (German, French, Spanish) for 4x data expansion. In our demo, we use German for 1.2x expansion to save time.
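The split code below calls an `augment_with_back_translation` helper that the article does not show. Here is one hypothetical way to write it: it paraphrases a random fraction of rows through each pivot language and copies each row's label unchanged. The `text` column name and the `pivots`, `fraction`, `seed`, and `back_translate_fn` parameters are assumptions. When no translator is passed, it falls back to the `back_translate` function above; the demo passes a stand-in so it runs without downloading models.

```python
import pandas as pd


def augment_with_back_translation(train_df, pivots=('de',), fraction=0.2,
                                  text_col='text', seed=42, back_translate_fn=None):
    """Return NEW rows holding back-translated paraphrases; labels are copied unchanged."""
    if back_translate_fn is None:
        back_translate_fn = back_translate  # the MarianMT helper defined earlier
    augmented = []
    for i, pivot in enumerate(pivots):
        # Paraphrase a random subset of rows through this pivot language
        subset = train_df.sample(frac=fraction, random_state=seed + i)
        paraphrased = subset.copy()
        paraphrased[text_col] = [back_translate_fn(t, pivot_lang=pivot) for t in subset[text_col]]
        paraphrased['pivot_lang'] = pivot  # provenance column, handy for auditing
        # Drop paraphrases identical to their source: they add nothing new
        augmented.append(paraphrased[paraphrased[text_col] != subset[text_col]])
    if not augmented:
        return train_df.iloc[0:0].copy()
    return pd.concat(augmented, ignore_index=True)


def fake_back_translate(text, pivot_lang='de'):
    """Stand-in for the MarianMT round trip so this demo runs offline."""
    return f"{text} [{pivot_lang}]"


train = pd.DataFrame({'text': [f"note {i}" for i in range(10)], 'label_id': [0, 1] * 5})

light = augment_with_back_translation(train, back_translate_fn=fake_back_translate)
full = augment_with_back_translation(train, pivots=('de', 'fr', 'es'), fraction=1.0,
                                     back_translate_fn=fake_back_translate)
print(len(train) + len(light), len(train) + len(full))  # 12 40
```

With one pivot and `fraction=0.2`, the training set grows 1.2×; with three pivots and `fraction=1.0`, it grows 4×, matching the two settings described above.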

Critical requirement: Keep the validation set 100% original data

import pandas as pd
from sklearn.model_selection import train_test_split

# Split BEFORE augmentation
train_orig, val_orig = train_test_split(original_df, test_size=0.2, random_state=42)

# Augment ONLY training data
train_augmented = augment_with_back_translation(train_orig)
train_final = pd.concat([train_orig, train_augmented], ignore_index=True)

# Validation stays 100% original
val_final = val_orig

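Why this matters: If you augment before splitting, a paraphrase can land in validation while its original sits in training. The model is then scored on near-copies of text it has already seen, and the metrics come out overly optimistic. Validation should also reflect real clinical text, not the artifacts of the translation process.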

Consequence of skipping this step:

Without augmentation, the model has limited exposure to linguistic variation. It might learn to recognize specific phrasings but fail on synonyms or alternative formulations—reducing real-world robustness by 10-15%.


Technique 4: LoRA (Low-Rank Adaptation) Fine-Tuning

The Problem: BioBERT has 110 million parameters. Training all of them on 1,200 examples causes severe overfitting.

The Solution: Use LoRA to train only about 0.1% of the parameters (small low-rank adapters plus the classification head) while keeping the rest frozen.

How LoRA Works

Instead of updating a full weight matrix, LoRA freezes it and learns a low-rank update through two small trainable matrices (applied here to the attention query and value projections):

Traditional: W_new = W_old + ΔW  (update all 768×768 = 589,824 params)
LoRA: W_new = W_old + (alpha/r)·A×B  (update 768×8 + 8×768 = 12,288 params)

Where:

  • A is a 768×8 matrix
  • B is an 8×768 matrix
  • r=8 is the rank (a hyperparameter)
  • alpha (lora_alpha=16 in the config below) scales the update by alpha/r

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

# Load base BioBERT model
base_model = AutoModelForSequenceClassification.from_pretrained(
    'dmis-lab/biobert-v1.1',
    num_labels=18,
    problem_type='single_label_classification'
)

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,  # Rank: controls capacity vs. overfitting trade-off
    lora_alpha=16,  # Scaling factor (typically 2×r)
    lora_dropout=0.1,
    target_modules=["query", "value"],  # Apply to Q/V attention projections
    inference_mode=False
)

# Apply LoRA adapter
model = get_peft_model(base_model, lora_config)

print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")

Output:

Trainable params: 148,488 (0.13%)
Total params: 109,629,456 (100%)

Why this works:

  • Pre-trained knowledge is preserved: BioBERT's medical understanding stays intact
  • Task-specific adaptation: The small LoRA adapters learn to map BioBERT's features to ICD-10 codes
  • Regularization effect: Limited capacity prevents memorization

Choosing the rank (r)

  • r=4: Very lightweight, may underfit complex tasks
  • r=8: Sweet spot for most tasks (used here)
  • r=16: More capacity, risk of overfitting on small datasets
  • r=32+: Approaching full fine-tuning behavior


Image: LoRA diagram (source: Hugging Face PEFT conceptual guide, https://huggingface.co/docs/peft/main/en/conceptual_guides/lora)
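The parameter arithmetic above can be checked with a few lines of NumPy. This is a toy illustration of the low-rank update only, not PEFT's implementation. Random `A` and `B` stand in for trained adapters (PEFT actually initializes one of the two to zeros so training starts from the unchanged model), and `alpha`/`r` follow the config above. The sketch shows that A×B has the full 768×768 shape but rank at most r, and how adapter size grows with the rank choices listed above.

```python
import numpy as np

d, r, alpha = 768, 8, 16
rng = np.random.default_rng(0)

W_old = rng.standard_normal((d, d))      # frozen pretrained weight (toy values)
A = rng.standard_normal((d, r)) * 0.01   # trainable, 768x8
B = rng.standard_normal((r, d)) * 0.01   # trainable, 8x768

delta_W = (alpha / r) * (A @ B)          # scaled low-rank update
W_new = W_old + delta_W                  # same shape as the original weight

print("update shape:", delta_W.shape)
print("update rank:", np.linalg.matrix_rank(delta_W))
print("full update params:", d * d)
print("LoRA params:", A.size + B.size)

# Adapter size per 768x768 matrix for the rank choices discussed above
for rank in (4, 8, 16, 32):
    n = 2 * d * rank
    print(f"r={rank}: {n:,} params ({n / (d * d):.1%} of the full matrix)")
```

Even at r=32, each adapted matrix trains under a tenth of the parameters a full update would.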

Consequence of skipping this step:

Full fine-tuning on this dataset produces F1 scores around 20-30%. The model memorizes training examples and fails to generalize. LoRA's regularization is the difference between failure and success.


Technique 5: Class-Weighted Loss Function

The Problem: Even after filtering, we have imbalance (some codes have 200 examples, others have 80).

The Solution: Use weighted cross-entropy loss that penalizes errors on rare classes more heavily.

import numpy as np
import torch
from torch import nn
from sklearn.utils.class_weight import compute_class_weight
from transformers import Trainer

num_labels = 18  # number of ICD-10 codes kept after filtering

# Compute balanced class weights from the TRAINING labels only
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(num_labels),
    y=train_df['label_id'].to_numpy()
)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)

# Custom Trainer with weighted loss
class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Move weights to the training device (None falls back to unweighted loss)
        self.class_weights = (
            class_weights.to(self.args.device) if class_weights is not None else None
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Weighted cross-entropy: errors on rare classes cost more
        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fct(logits, labels)

        return (loss, outputs) if return_outputs else loss

How balanced weights work:

weight[c] = n_samples / (n_classes × n_samples_in_class[c])

Example:

  • Class A: 200 examples → weight = 1,200/(18×200) = 0.33
  • Class B: 80 examples → weight = 1,200/(18×80) = 0.83

During training, misclassifying Class B incurs 2.5× the penalty of Class A.
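The 2.5× figure does not depend on the total sample count: the ratio of two balanced weights is simply the inverse ratio of their class counts (200 / 80 = 2.5). A minimal NumPy sketch of the same formula that scikit-learn's `'balanced'` option uses, on hypothetical two-class labels (`balanced_weights` is an illustrative name):

```python
import numpy as np


def balanced_weights(labels, n_classes):
    """weight[c] = n_samples / (n_classes * n_samples_in_class[c])"""
    counts = np.bincount(labels, minlength=n_classes)
    return len(labels) / (n_classes * counts)  # a class with zero samples would divide by zero


labels = np.array([0] * 200 + [1] * 80)  # class A = 0 (200 examples), class B = 1 (80 examples)
w = balanced_weights(labels, n_classes=2)
print(w)
print(w[1] / w[0])
```

With two classes instead of 18, the absolute weights change (0.7 and 1.75), but Class B still weighs 2.5× as much as Class A.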

Consequence of skipping this step:

Without weighting, the model optimizes for overall accuracy by focusing on frequent classes. Rare classes get ignored, reducing macro F1 by 5-10%.


Putting It All Together: Training Configuration

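from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./models/biobert-lora-improved',
    eval_strategy='epoch',  # named evaluation_strategy in older transformers releases
    save_strategy='epoch',  # must match eval_strategy for load_best_model_at_end
    learning_rate=2e-4,  # Higher LR for LoRA (10× standard fine-tuning)
    per_device_train_batch_size=16,
    num_train_epochs=15,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='macro_f1',  # key returned by compute_metrics
    fp16=True,  # Mixed precision for faster training (CUDA GPU only)
    warmup_ratio=0.1,
)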

trainer = WeightedTrainer(
    class_weights=class_weights_tensor,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

Key hyperparameters explained:

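  • Learning rate (2e-4): Higher than typical full fine-tuning (2e-5) because only the small adapter matrices and the classification head are trained, which tolerates larger updates
  • Batch size (16): Balances GPU memory against gradient quality
  • Epochs (15): Enough for convergence here; load_best_model_at_end keeps the epoch with the best validation macro F1, which guards against late-epoch overfitting
  • FP16: Mixed precision reduces memory usage and can roughly double training speed on GPUs with Tensor Cores (requires a CUDA GPU)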
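`compute_metrics` is passed to the trainer but not shown. A minimal sketch, assuming scikit-learn: the `macro_f1` key must match `metric_for_best_model='macro_f1'` (the Trainer logs it as `eval_macro_f1`), and the other key names are assumptions chosen to mirror the results table below.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def compute_metrics(eval_pred):
    """Turn (logits, labels) from the Trainer into accuracy plus macro/weighted metrics."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        'accuracy': accuracy_score(labels, preds),
        'macro_f1': f1_score(labels, preds, average='macro', zero_division=0),
        'weighted_f1': f1_score(labels, preds, average='weighted', zero_division=0),
        'macro_precision': precision_score(labels, preds, average='macro', zero_division=0),
        'macro_recall': recall_score(labels, preds, average='macro', zero_division=0),
    }


# Toy check: 3 classes, hand-written logits, one wrong prediction (class 2 predicted as 0)
logits = np.array([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 0.1, 2.0], [2.0, 0.1, 0.1]])
labels = np.array([0, 1, 2, 2])
metrics = compute_metrics((logits, labels))
print(metrics)
```

Macro averaging gives each class equal weight, so a rare class with poor recall pulls `macro_f1` down as much as a frequent one would. That is why it is the better checkpoint-selection metric here.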

Results: From Failure to Success

Performance Metrics

Metric             Score
Accuracy           94.4%
Macro F1           0.944
Weighted F1        0.945
Macro Precision    0.944
Macro Recall       0.950

Comparison to naive approach:

Approach                                               Macro F1   Improvement
Naive (full docs, all 158 classes, full fine-tuning)   0.023      Baseline
Improved (all five techniques, 18 classes)             0.944      +4,000%

Per-Class Performance

The model achieves balanced performance across all 18 classes:

                    precision    recall  f1-score   support

         E11.9          0.95      0.95      0.95        20
         I10            0.93      0.97      0.95        15
         E78.5          0.94      0.94      0.94        18
         ...

    macro avg          0.94      0.95      0.94       240
 weighted avg          0.95      0.94      0.95       240

No class falls below 90% F1—demonstrating that our techniques successfully handle the remaining imbalance.
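The per-class table above follows the layout of scikit-learn's `classification_report`. A sketch of producing it from validation predictions, with a hypothetical three-code `id2label` mapping and toy labels; in practice `y_pred` would be the argmax of `trainer.predict(val_dataset).predictions`.

```python
import numpy as np
from sklearn.metrics import classification_report

id2label = {0: 'E11.9', 1: 'I10', 2: 'E78.5'}   # hypothetical subset of the 18 codes
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2])

label_ids = list(id2label)
names = [id2label[i] for i in label_ids]

report = classification_report(y_true, y_pred, labels=label_ids, target_names=names,
                               digits=2, zero_division=0)
print(report)

# output_dict=True returns the same numbers as a dict, e.g. to find the weakest class
per_class = classification_report(y_true, y_pred, labels=label_ids, target_names=names,
                                  output_dict=True, zero_division=0)
worst = min(per_class[name]['f1-score'] for name in names)
print(f"lowest per-class F1: {worst:.2f}")
```

Checking the minimum per-class F1 programmatically is a simple way to verify a claim like "no class falls below 90%" when the printed report is truncated.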


What We've Learned: Key Takeaways

Do This

  1. Extract focused context: Don't train on full documents when evidence spans are available
  2. Filter aggressively: Better to excel at 18 codes than fail at 158
  3. Augment intelligently: Back-translation preserves semantics while adding variation
  4. Use parameter-efficient methods: LoRA prevents overfitting on small datasets
  5. Weight your loss: Account for remaining class imbalance

Avoid This

  1. Training on full documents: Dilutes diagnostic signals
  2. Including rare classes: Fewer than 10 examples per class is too little to learn from reliably
  3. Mixing augmented data into validation: Creates overly optimistic metrics
  4. Full fine-tuning: Causes catastrophic overfitting on small datasets
  5. Ignoring class imbalance: Model will focus only on frequent classes

Limitations and Future Work

Current Limitations

1. Limited Code Coverage
We only handle 18 out of 158 codes. For production use, you'd need:

  • More training data for rare codes
  • Hierarchical classification (predict ICD chapter first, then specific code)
  • Hybrid approach with commercial APIs

2. Evidence Dependency
Our approach requires supporting evidence annotations. For new data without annotations:

  • Use attention weights to identify key spans
  • Employ named entity recognition (NER) to extract diagnoses
  • Apply the trained model to full documents (with performance degradation)

3. Multi-Label Simplification
We converted multi-label to single-label (one example per code). True multi-label classification would:

  • Predict all relevant codes simultaneously
  • Model code co-occurrence patterns
  • Better reflect real clinical scenarios
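For the multi-label version, each document gets one multi-hot target vector instead of one row per code. A minimal sketch with pandas, assuming a long-format table with hypothetical `doc_id` and `ICD10` columns. In Hugging Face Transformers, setting `problem_type='multi_label_classification'` switches the sequence-classification loss to `BCEWithLogitsLoss`, which expects float targets like these.

```python
import numpy as np
import pandas as pd

# Long format: one row per (document, code) annotation (toy data; doc 2 repeats I10)
annotations = pd.DataFrame({
    'doc_id': [1, 1, 2, 2, 3, 3, 3],
    'ICD10':  ['I10', 'E11.9', 'I10', 'I10', 'E78.5', 'E11.9', 'I10'],
})

# One row per document, one column per code; clip repeated annotations to 0/1
multi_hot = pd.crosstab(annotations['doc_id'], annotations['ICD10']).clip(upper=1)
Y = multi_hot.to_numpy(dtype=np.float32)  # float targets for BCEWithLogitsLoss

print(list(multi_hot.columns))
print(Y)
```

Each document now contributes one training example, and co-occurring codes (such as I10 with E11.9) share a single target vector, so the model sees them together.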

Next Steps

  1. Hierarchical Classification: Leverage ICD-10's tree structure (Chapter → Category → Code)
  2. Full Augmentation: Implement FR and ES translations for 4× data expansion
  3. Ensemble Methods: Combine multiple augmented models with different random seeds
  4. Multi-Label Extension: Train on documents with all codes simultaneously
  5. Transfer Learning: Pre-train on medical entity recognition before ICD-10 classification
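Hierarchical classification needs a coarser label for each code. In ICD-10-CM, the first three characters form the category (E11.9 falls under E11, type 2 diabetes mellitus), so a category label can be derived directly from the code. A minimal sketch with a hypothetical `icd10_category` helper. Chapter assignment needs the official letter ranges (some letters, such as D and H, are split across two chapters), so it is left out here.

```python
import pandas as pd


def icd10_category(code: str) -> str:
    """Return the 3-character ICD-10-CM category, e.g. 'E11.9' -> 'E11'."""
    return code.strip().upper().replace('.', '')[:3]


labels = pd.DataFrame({'ICD10': ['E11.9', 'E11.65', 'I10', 'E78.5', 'e785']})
labels['category'] = labels['ICD10'].map(icd10_category)
print(labels)

# Stage 1 of a two-stage model would predict `category`; stage 2 picks the code within it
print(labels.groupby('category')['ICD10'].nunique().to_dict())
```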

Coming Up in Part 2: AWS Comprehend Medical

In the next article, we'll explore a completely different approach:

  • Zero-shot inference using AWS's pre-trained medical NLP service
  • Entity trait filtering to handle negations, hypotheticals, and family history
  • Multi-label evaluation at the document level
  • Head-to-head comparison with our BioBERT model
  • Hybrid strategy combining both approaches for optimal results

We'll discover that AWS Comprehend Medical achieves 27% macro F1 on all 158 codes (vs. our 94% on 18 codes)—a fascinating trade-off between coverage and accuracy.

Try It Yourself

All code is available in the GitHub repository:

🔗 clinical-nlp-claims-processing

To run this notebook:

# Clone the repository
git clone https://github.com/alexretana/clinical-nlp-claims-processing.git
cd clinical-nlp-claims-processing

# Install dependencies (using uv)
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync

# Launch Jupyter
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
jupyter lab

# Open notebooks/01_BioBERT_Fine-Tuning_NLP.ipynb

Hardware requirements:

  • GPU with 8GB+ VRAM (RTX 3060, V100, A100) for reasonable training times
  • 16GB+ system RAM
  • Training takes ~2-4 hours on GPU, much longer on CPU

Conclusion

Building production-quality medical NLP systems requires more than throwing data at a pre-trained model. By combining:

  • Evidence-focused training
  • Strategic label filtering
  • Back-translation augmentation
  • LoRA parameter-efficient fine-tuning
  • Class-weighted loss

We transformed a failing system (2.3% F1) into one that performs at 94.4% F1—good enough for real-world deployment with human oversight.

The techniques we've covered apply far beyond medical coding:

  • Legal document analysis (case law classification)
  • Scientific literature mining (research topic categorization)
  • Customer support (ticket routing and classification)
  • Content moderation (policy violation detection)

Anywhere you face limited training data and class imbalance, this toolkit will serve you well.

Next time, we'll see how AWS Comprehend Medical tackles the same problem without any training data at all—and explore when each approach makes sense.


What challenges have you faced when training NLP models on limited data? Share your experiences in the comments! And if you found this helpful, follow me for Part 2 where we dive into AWS Comprehend Medical.



Tags: #machinelearning #nlp #healthcare #python #biobert #transformers #medicalcoding #datascience
