alok kumar

Fine-tuning LLM Using Masking

Fine-tuning BART (Bidirectional and Auto-Regressive Transformers), a large language model (LLM), with Masked Language Modeling (MLM) means training the model on a specific dataset in which some tokens are randomly masked and the model learns to predict them. BART is a sequence-to-sequence model that combines the strengths of BERT (which uses MLM) and GPT (which is auto-regressive).

Below, I'll walk you through the steps and provide code to fine-tune BART using Masked Language Modeling.

Steps to Fine-Tune BART with MLM

1. Import Necessary Libraries: We'll use the transformers library from Hugging Face, which provides pre-trained models and tokenizers.

2. Load a Pre-trained BART Model and Tokenizer: We'll load a pre-trained BART model and its tokenizer.

3. Prepare the Dataset: We'll create or load a dataset, tokenize it, and apply the masking. The masked sequence becomes the input and the original tokens become the targets.

4. Set Up the Training Arguments: Define training parameters such as the learning rate, batch size, and number of epochs.

5. Fine-Tune the Model: Use the Hugging Face Trainer API to fine-tune the model.

6. Evaluate the Model: After training, evaluate the model on a validation dataset (a sketch of this step follows the main script below).

Code Example
Here is example code to fine-tune BART using MLM:
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch

# Load the tokenizer and the model
model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Load a sample dataset and drop wikitext's many empty lines
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)

# Tokenize the input text and randomly mask ~15% of the tokens

def preprocess_function(examples):
    # Fixed-length padding keeps every example the same length so batches stack cleanly
    inputs = tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length", return_tensors="pt")

    # Apply masking: pick ~15% of positions, but never special tokens or padding
    labels = inputs["input_ids"].clone()
    probability_matrix = torch.full(labels.shape, 0.15)
    special_tokens_mask = torch.tensor(
        [tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True) for ids in labels.tolist()],
        dtype=torch.bool,
    )
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    mask_matrix = torch.bernoulli(probability_matrix).bool()
    labels[~mask_matrix] = -100  # unmasked positions are ignored in the loss
    inputs["input_ids"][mask_matrix] = tokenizer.mask_token_id

    inputs["labels"] = labels
    return inputs
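
# Optional sanity check: preview the masking on a single example before processing
# the full dataset (the sentence below is just an illustration)
sample = preprocess_function({"text": ["The quick brown fox jumps over the lazy dog."]})
ids = sample["input_ids"][0]
print(tokenizer.decode(ids[ids != tokenizer.pad_token_id]))  # some tokens should now appear as <mask>
print(sample["labels"][0][:20])                              # unmasked positions carry the label -100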

# Apply the preprocessing function to the dataset
processed_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["text"])

# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    save_steps=10_000,
    save_total_limit=2,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine-tuned-bart-mlm")
tokenizer.save_pretrained("./fine-tuned-bart-mlm")
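
Step 6 above calls for evaluation, which the script does not include. A minimal sketch, assuming you reuse the trainer and preprocess_function defined above and preprocess the wikitext validation split the same way:

# Evaluate on the validation split using the same masking preprocessing
val_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
val_dataset = val_dataset.filter(lambda example: len(example["text"].strip()) > 0)
processed_val = val_dataset.map(preprocess_function, batched=True, remove_columns=["text"])

metrics = trainer.evaluate(eval_dataset=processed_val)
print(metrics)  # reports eval_loss, among other things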

Explanation

Tokenizer and Model:

- BartTokenizer: tokenizes the text into input IDs that the model can process.
- BartForConditionalGeneration: the BART model used for conditional generation tasks such as summarization and translation.

Dataset Loading:

- We load the "wikitext-2-raw-v1" dataset from the datasets library, which contains raw text data, and filter out its many empty lines.

Preprocessing:

- The preprocess_function tokenizes the text and creates input IDs.
- We create a random mask over the input tokens (15% masking probability, skipping special and padding tokens) and replace the selected tokens with the mask token (<mask>).
- The labels tensor sets every unmasked position to -100, which tells the model to ignore those positions during the loss computation (see the short demonstration after this list).

Training Arguments:

- output_dir: directory in which to save model checkpoints.
- per_device_train_batch_size: batch size for training.
- num_train_epochs: number of training epochs.
- save_steps and save_total_limit: control model checkpointing.

Trainer:

- We use the Hugging Face Trainer class to manage the training loop, including data loading, model updates, and checkpoint saving.

Fine-Tuning:

- trainer.train() trains the model on the processed dataset using the defined training arguments.

Model Saving:

- After training, the fine-tuned model and tokenizer are saved for future use.
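
To make the -100 convention concrete, here is a tiny standalone sketch (independent of BART) showing that PyTorch's cross-entropy loss skips positions labelled -100; the shapes and labels are made up purely for illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)                # 1 sequence, 4 positions, vocabulary of 10
labels = torch.tensor([[3, -100, 7, -100]])   # only positions 0 and 2 contribute to the loss

loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=-100)
print(loss)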

Summary

The code above demonstrates how to fine-tune a BART model with a Masked Language Modeling objective. This approach is useful when you want the model to better understand and predict masked tokens, which matters for tasks such as text completion and text infilling, or as continued pre-training before transferring to other NLP tasks.
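
As a quick illustration of using the fine-tuned checkpoint, the sketch below reloads it and asks the model to fill in a masked span; the example sentence and generation settings are only illustrative, and the quality of the completion depends on how much training the model has seen:

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("./fine-tuned-bart-mlm")
model = BartForConditionalGeneration.from_pretrained("./fine-tuned-bart-mlm")

text = "The Eiffel Tower is located in <mask>."
inputs = tokenizer(text, return_tensors="pt")

# BART regenerates the full sequence, including its guess for the masked span
output_ids = model.generate(inputs["input_ids"], max_length=20, num_beams=4)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))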
