Large language models (LLMs) have demonstrated exceptional language capabilities. In the context of text classification, when labelled data is unavailable, LLMs are commonly employed using In-Context Learning (ICL). With ICL, the LLM implicitly learns how to classify text by relying on a task instruction and (optionally) a few labelled examples relevant to the task. While this approach may appear flexible and powerful, it can be sensitive to the choice of prompts, the choice of ICL examples, etc., resulting in poor performance. In such scenarios, can we improve the performance of the LLM without manually labelling more data?
In this article, we will be talking about Self-Training LLMs for Text Classification. Self-Training is a semi-supervised learning approach which leverages a model’s own predictions on unlabelled data to build a labelled dataset for training of the model. Concretely, we will use the LLM to predict labels for unlabelled data to construct a training dataset and then fine-tune the LLM on the training data.
Intuitively, the main downside of Self-Training is its inability to correct its own mistakes. Typically, only the model's most confident predictions are considered for inclusion in the labelled dataset. However, "confidence" does not always imply "correctness", and incorrectly labelled samples can end up amplifying the LLM's errors.
To address this, we include a "Label Correction" step. We use DQC-Toolkit, a Python library that facilitates the improvement of Machine Learning models by identifying and mitigating label errors in the training dataset.
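Before diving in, here is a minimal sketch of the loop we will build. The function names below are illustrative placeholders only; the concrete implementations appear step by step later in the article.
# Illustrative outline of the self-training loop (placeholder function names)
def self_train(llm, unlabelled_texts):
    pseudo_labels = predict_with_icl(llm, unlabelled_texts)          # Step 1: label unlabelled data via ICL
    texts, labels = correct_labels(unlabelled_texts, pseudo_labels)  # Step 2: correct / filter noisy labels
    return finetune(llm, texts, labels)                              # Step 3: fine-tune on reliably labelled data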
Preliminaries
Most of the code is based on our previous post. For the purposes of our experiment, we will be using Mistral-7B as our LLM. We will also extend the observations to Llama3–8B at the end of the article.
We begin by installing and loading the required dependencies. We will require the Python version to be ≥ 3.9
!pip install transformers
!pip install bitsandbytes
!pip install accelerate
!pip install huggingface_hub
!pip install peft
!pip install dqc-toolkit
from datasets import load_dataset, Dataset
from typing import List, Union
import numpy as np
import pandas as pd
import torch
import transformers
import wandb
import warnings
transformers.logging.set_verbosity_error()
wandb.init(mode="disabled")
warnings.filterwarnings('ignore')
Dataset
We will be using emotion, a publicly available dataset hosted on Hugging Face. It consists of English-language tweets annotated with one of six emotions as shown below —
[‘sadness’, ‘joy’, ‘love’, ‘anger’, ‘fear’, ‘surprise’]
The dataset has 16,000 training samples and 2,000 validation samples.
We also extend the observations to the MTOP domain dataset towards the end of the article.
from datasets import load_dataset
import pandas as pd
dataset = 'dair-ai/emotion'
dset = load_dataset(dataset, trust_remote_code=True)
train_data = pd.DataFrame(dset['train'])
val_data = pd.DataFrame(dset['validation'])
train_data.head()
Since LLMs cannot comprehend the emotion labels in integer format, we define a mapping of the integer labels to semantic text descriptions and create text labels for downstream consumption.
label_to_text = {0 : 'sadness',
1 : 'joy',
2 : 'love',
3 : 'anger',
4 : 'fear',
5 : 'surprise'}
train_data['label_text'] = train_data['label'].map(label_to_text)
val_data['label_text'] = val_data['label'].map(label_to_text)
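Optionally, a quick look at the class distribution of the training split can be a useful sanity check before proceeding:
# Optional: inspect the class distribution of the training split
print(train_data['label_text'].value_counts())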
Evaluation Metric
For the purpose of benchmarking our experiments, we choose Weighted F1 score as the metric. We also display the classification report and confusion matrix for detailed interpretation.
from sklearn.metrics import (classification_report, confusion_matrix,
                             ConfusionMatrixDisplay, f1_score)
import matplotlib.pyplot as plt

def fetch_performance_metrics(y_true: np.ndarray, y_pred: np.ndarray, exp_name: str,
                              display_report: bool = True, display_confusion_matrix: bool = True,
                              label_list: List[str] = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'],
                              num_labels: int = 6) -> dict:
    """
    Util function to compute F1 score and optionally display the classification report and confusion matrix for a given experiment.

    Args:
        y_true (np.ndarray): Array containing true labels.
        y_pred (np.ndarray): Array containing predicted labels.
        exp_name (str): Name of the experiment (used to save results).
        display_report (bool, optional): Boolean flag indicating whether to display classification report (True) or not (False). Defaults to True.
        display_confusion_matrix (bool, optional): Boolean flag indicating whether to display confusion matrix (True) or not (False). Defaults to True.
        label_list (list, optional): List of labels. Defaults to ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'].
        num_labels (int, optional): Number of unique labels. Defaults to 6.

    Returns:
        dict: A dictionary containing F1 score.
    """
    if display_report:
        print('\nClassification Report:')
        print(classification_report(y_true=y_true, y_pred=y_pred, labels=list(range(num_labels)),
                                    target_names=label_list[:num_labels]))

    if display_confusion_matrix:
        cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
        fig, ax = plt.subplots(figsize=(8, 8))
        display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_list)
        display.plot(ax=ax)
        plt.savefig(exp_name)

    return {'F1-score': f1_score(y_true, y_pred, average='weighted')}
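As a quick illustration of the returned format, here is a toy invocation with dummy labels (not part of the actual experiment):
# Toy example with dummy labels to show the returned metric format
y_true_dummy = np.array([0, 1, 2, 1, 0])
y_pred_dummy = np.array([0, 1, 1, 1, 0])
print(fetch_performance_metrics(y_true_dummy, y_pred_dummy, 'dummy_exp',
                                display_report=False, display_confusion_matrix=False))
# The returned dict contains the weighted F1 score (0.72 for this toy example)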
Alright! Let's begin.
Baseline : LLM with ICL
We will need to log in to the Hugging Face Hub to be able to access the LLM. We do this via Hugging Face's notebook_login.
from huggingface_hub import notebook_login
notebook_login()
Defining the LLM Preliminaries
We define a few LLM Utility functions as we did in the previous post.
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          BitsAndBytesConfig, pipeline)
import datasets

def _generate_predictions(example: datasets.formatting.formatting.LazyBatch,
                          generator: pipeline, text_column: str,
                          max_new_tokens: int = 9, split_token: str = '[/EMOTION]') -> dict:
    """
    Generates predictions using the text generation model for a given example.

    Args:
        example (datasets.formatting.formatting.LazyBatch): Batch of samples from a dataset.
        generator (pipeline): Huggingface pipeline for text generation.
        text_column (str): Name of the column containing the prompts for the text generation model.
        max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 9.
        split_token (str, optional): Token to demarcate the emotion prediction. Defaults to '[/EMOTION]'.

    Returns:
        dict: A dictionary containing the generated predictions.
    """
    predictions = []
    batch_results = generator(example[text_column], max_new_tokens=max_new_tokens, num_return_sequences=1)
    predictions.extend([result[0]["generated_text"] for result in batch_results])
    return {'prediction': predictions}

def infer_LLM(model_name: str, input_ds: Dataset, batch_size: int = 4, max_new_tokens: int = 9,
              text_column: str = 'emotion_prompt', finetuned_model_path: str = None) -> Dataset:
    """
    Util function to run LLM inference

    Args:
        model_name (str): The name or path of the LLM model.
        input_ds (Dataset): Input dataset containing text prompts.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 9.
        text_column (str, optional): Name of the column containing text prompts. Defaults to 'emotion_prompt'.
        finetuned_model_path (str, optional): Path to the fine-tuned model. Defaults to None.

    Returns:
        Dataset: Dataset with generated predictions.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

    if finetuned_model_path is None:
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",
                                                     quantization_config=quantization_config)
    else:
        model = AutoPeftModelForCausalLM.from_pretrained(finetuned_model_path,
                                                         device_map="auto",
                                                         quantization_config=quantization_config)

    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,
                              batch_size=batch_size, truncation=False)
    text_generator.tokenizer.pad_token_id = model.config.eos_token_id
    input_ds = input_ds.map(_generate_predictions, fn_kwargs={'generator': text_generator,
                                                              'text_column': text_column,
                                                              'max_new_tokens': max_new_tokens},
                            batched=True, batch_size=batch_size)
    return input_ds
def build_LLM_prompt(input_ds: Dataset, label_column: str = None, prompt_template: Union[str, None] = None,
                     with_label: bool = False) -> Dataset:
    """Util function to build the LLM prompt from input text data

    Args:
        input_ds (Dataset): Input dataset containing text
        label_column (str, optional): Label column in the data. Applicable if constructing prompts for in-context samples / finetuning the LLM. Defaults to None.
        prompt_template (Union[str, None], optional): Text instruction to prepend to each transformed input text sample. Defaults to None.
        with_label (bool, optional): `True` if the prompts should include labels from the `label_column`. Defaults to False.

    Returns:
        Dataset: Dataset with the constructed prompts.
    """
    if isinstance(input_ds, pd.DataFrame):
        input_ds = Dataset.from_pandas(input_ds)

    if with_label:
        input_ds = input_ds.map(lambda x: {'emotion_prompt': '[UTTERANCE]' + x['text'] + '[/UTTERANCE]' +
                                                             '[EMOTION]' + x[label_column] + '[/EMOTION]'})
    else:
        input_ds = input_ds.map(lambda x: {'emotion_prompt': prompt_template + '[UTTERANCE]' + x['text'] + '[/UTTERANCE]' +
                                                             '[EMOTION]'})
    return input_ds
def _extract_label(sample: datasets.formatting.formatting.LazyRow, label_list: List[str]) -> dict:
    """Util function to extract the emotion from the generated LLM prediction

    Args:
        sample (datasets.formatting.formatting.LazyRow): Sample from a dataset
        label_list (List[str]): List of possible emotions

    Returns:
        dict: Dictionary containing the extracted predicted label
    """
    prompt_length = len(sample['emotion_prompt'])
    generated_answer = sample['prediction'][prompt_length:].split('[/EMOTION]')[0].lower()
    label_matched = False
    predicted_label = None

    for label in label_list:
        if label in generated_answer:
            predicted_label = label
            label_matched = True
            break

    if not label_matched:
        predicted_label = "no_match"

    return {'predicted_label': predicted_label}
def run_llm(val_data: pd.DataFrame, prompt_template: str, model_name: str, emotion_list: List[str], label_mapping: dict,
            label_column: str = 'label', batch_size: int = 4, finetuned_model_path: str = None,
            num_labels: int = 6, compute_metrics: bool = True):
    """Run end-to-end LLM inference (from pre-processing the input data to post-processing the predictions) and return the computed performance metrics on the input validation data

    Args:
        val_data (pd.DataFrame): Validation data with labels
        prompt_template (str): Text instruction to prepend to each transformed input text sample.
        model_name (str): The name or path of the pre-trained LLM.
        emotion_list (List[str]): List of possible emotions
        label_mapping (dict): Dictionary mapping to convert text labels to integers
        label_column (str, optional): Label column in the data. Defaults to 'label'.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        finetuned_model_path (str, optional): Path to the fine-tuned model, if available. Defaults to None.
        num_labels (int, optional): Number of unique labels. Defaults to 6.
        compute_metrics (bool, optional): Boolean flag indicating whether to compute the performance metrics (True) or not (False)

    Returns:
        The predicted labels and, if `compute_metrics` is True, a dictionary containing the F1 score.
    """
    predicted_label_list = []
    val_ds = build_LLM_prompt(val_data, prompt_template=prompt_template)
    val_ds_with_pred = infer_LLM(model_name, val_ds, batch_size, finetuned_model_path=finetuned_model_path)
    predicted_label_list = val_ds_with_pred.map(_extract_label,
                                                fn_kwargs={"label_list": emotion_list[:num_labels]})['predicted_label']
    y_pred = [label_mapping[pred] if pred in label_mapping else num_labels for pred in predicted_label_list]
    y_true = val_data[label_column].astype(int).values.tolist()

    if num_labels not in y_pred:
        # All LLM predictions match a valid emotion from `emotion_list`
        emotion_list.remove('no_match')

    if compute_metrics:
        return y_pred, fetch_performance_metrics(y_true, y_pred, 'mistral_7b', label_list=emotion_list)

    return y_pred
In summary -
build_LLM_prompt transforms the input text into an LLM prompt.
infer_LLM and _generate_predictions instantiate the LLM using 4-bit quantization and run inference with the constructed input prompts.
_extract_label maps the LLM free-text outputs to valid emotion predictions. If the generated text has no matching emotion, the predicted label is set to "no_match".
run_llm invokes build_LLM_prompt and infer_LLM to perform inference and returns the computed performance metrics on the input validation data.
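To make the prompt format concrete, here is a small illustration with a made-up utterance (not part of the dataset):
# Illustration only: build a labelled prompt for a single made-up utterance
demo_df = pd.DataFrame({'text': ['i feel so happy today'], 'label_text': ['joy']})
demo_ds = build_LLM_prompt(demo_df, with_label=True, label_column='label_text')
print(demo_ds['emotion_prompt'][0])
# [UTTERANCE]i feel so happy today[/UTTERANCE][EMOTION]joy[/EMOTION]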
Build the LLM prompt
We select one sample at random for each label and build the prompt prefix to run ICL.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
seed = 43
sample_data = train_data.groupby('label_text').sample(n=1, random_state=seed).reset_index(drop=True)
emotion_list = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
emotion_list_str = ', '.join(emotion_list)
transformed_sample_data = build_LLM_prompt(sample_data, with_label=True, label_column='label_text')
samples_str = '\n'.join(transformed_sample_data['emotion_prompt'])
prompt_template = "<s>[INST] You are a helpful, respectful and honest assistant. Choose one option that best describes the emotion behind the given utterance based on the following comma separated options: " + emotion_list_str + "[/INST] </s>"
Putting it all to work
We are ready to run our LLM now.
text_to_label = {v: k for k, v in label_to_text.items()}
llm_emotion_list = emotion_list + ['no_match']
_, score = run_llm(val_data, prompt_template, model_name, llm_emotion_list, text_to_label,
batch_size=64)
print(score)
The F1-score is 0.442, with a large proportion of the samples ending up in the "no_match" bucket. Can we do better than this? Let's find out.
Our Approach : Self-Training using DQC Toolkit
Self-Training LLMs for Text Classification using DQC Toolkit
As shown in the figure, our proposed self-training approach is comprised of the following three steps -
Generate LLM Predictions for Unlabelled Data
Apply Label Correction using DQC Toolkit
Fine-tune LLM using Reliably Labelled Data
Step 1 — Generate LLM Predictions for Unlabelled Data
We leverage LLM with ICL to generate initial labels for training our model.
predictions = run_llm(train_data, prompt_template, model_name, llm_emotion_list, text_to_label,
batch_size=64, compute_metrics=False)
As mentioned before, many predictions can end up being mapped to “no_match” (when we are unable to extract the emotion prediction from the LLM’s generated answer). We remove such samples from the data.
train_data['llm_predicted_label'] = pd.Series(predictions)
## Only valid label predictions
train_data_with_llm_pred = train_data.loc[train_data['llm_predicted_label'] < len(emotion_list), ].reset_index(drop=True)
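To see how much data survives this filtering (the exact count will vary from run to run), a quick check such as the following can help:
# Optional: count how many samples received a valid (non-"no_match") LLM label
print(f"Valid LLM-labelled samples: {len(train_data_with_llm_pred)} / {len(train_data)}")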
Step 2 — Apply Label Correction using DQC Toolkit
Currently, DQC Toolkit offers CrossValCurate for the curation of text classification datasets (binary / multi-class) using cross-validation-based label prediction. We will leverage this module to acquire better quality labels for our data.
from dqc import CrossValCurate

cvc = CrossValCurate(random_state=seed,
                     calibration_method='calibrate_using_baseline')
train_data_curated = cvc.fit_transform(train_data_with_llm_pred, y_col_name='llm_predicted_label')
CrossValCurate accepts two parameters: random_state (the random seed for reproducibility) and calibration_method (whether/how to calibrate the prediction probabilities of the model being trained for label correction). You can check out all the available hyper-parameters in the documentation here.
The returned object train_data_curated is a Pandas dataframe similar to the input dataframe train_data_with_llm_pred, with the following additional columns -
'label_correctness_score' represents a normalized score quantifying the correctness of llm_predicted_label.
'is_label_correct' is a boolean flag indicating whether the llm_predicted_label is to be considered correct (True) or incorrect (False).
'predicted_label' and 'prediction_probability' represent DQC Toolkit's predicted label for a given sample and the corresponding probability score.
We leverage is_label_correct to identify the reliably labelled samples.
train_data_curated = train_data_curated.loc[train_data_curated['is_label_correct']].reset_index(drop=True)
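Optionally, you can inspect how many reliably labelled samples were retained and peek at the scores assigned by DQC Toolkit, using the columns described above:
# Optional: inspect the curated training data
print(f"Reliably labelled samples retained: {len(train_data_curated)}")
print(train_data_curated[['text', 'llm_predicted_label', 'label_correctness_score',
                          'predicted_label', 'prediction_probability']].head())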
Step 3 — Fine-tune LLM using Reliably Labelled Data
We fine-tune the LLM using train_data_curated with llm_predicted_label as the target variable. First, we map the integer labels to text labels for LLM interpretability.
train_data_curated['llm_predicted_label_text'] = train_data_curated['llm_predicted_label'].map(label_to_text)
Next, we transform the data into instruction prompts for better performance
prompt_template = "<s>[INST] You are a helpful, respectful and honest assistant. Choose one option that best describes the emotion behind the given utterance based on the following comma separated options: " + emotion_list_str + "[/INST] </s>"
label_column = 'llm_predicted_label_text'
train_data_curated_ds = build_LLM_prompt(train_data_curated, with_label=True, label_column=label_column)
train_data_curated_ds = train_data_curated_ds.map(lambda example, prompt_template=prompt_template : {'emotion_prompt' : prompt_template + example['emotion_prompt']})
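Optionally, print one of the constructed prompts to verify that the instruction, utterance and label have been stitched together as expected:
# Optional sanity check: inspect one fully constructed fine-tuning prompt
print(train_data_curated_ds['emotion_prompt'][0])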
Then, we define the LLM fine-tuning function
from peft import get_peft_model, LoraConfig, PeftConfig, PeftModel, prepare_model_for_kbit_training
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, DataCollatorForLanguageModeling,
                          pipeline, Trainer, TrainingArguments)
import bitsandbytes as bnb
import torch.nn as nn

def tokenize(example: datasets.formatting.formatting.LazyRow, tokenizer: AutoTokenizer) -> dict:
    """Util function to tokenize text data

    Args:
        example (datasets.formatting.formatting.LazyRow): Sample containing the text to tokenize.
        tokenizer (AutoTokenizer): Tokenizer object used for tokenization.

    Returns:
        dict: Dictionary containing the tokenized text.
    """
    tokenized = tokenizer(
        example['emotion_prompt'],
        truncation=False
    )
    return {**tokenized}
def finetune_LLM(base_model_name: str, train_ds: Dataset,
                 save_path: str, seed: int, batch_size: int = 64, num_epochs: int = 1):
    """Function to fine-tune an LLM on the given input training data

    Args:
        base_model_name (str): The name or path of the LLM model to be fine-tuned
        train_ds (Dataset): Input dataset containing text prompts.
        save_path (str): Path to save the trained model
        seed (int): Random seed for reproducibility
        batch_size (int, optional): Batch size to use during training. Defaults to 64.
        num_epochs (int, optional): Number of training epochs. Defaults to 1.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                                 quantization_config=bnb_config,
                                                 device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    train_ds = train_ds.map(
        tokenize,
        batched=False,
        fn_kwargs={"tokenizer": tokenizer},
    )
    model = prepare_model_for_kbit_training(model)

    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
    )
    args = TrainingArguments(
        disable_tqdm=False,
        output_dir=save_path,
        warmup_steps=1,
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_epochs,
        learning_rate=2e-4,
        fp16=True,
        optim="paged_adamw_8bit",
        logging_dir="./logs",
        save_strategy="no",
        evaluation_strategy="no",
        seed=seed,
        report_to=None
    )
    model = get_peft_model(model, peft_config)
    model.config.use_cache = False

    trainer = Trainer(
        model=model,
        train_dataset=train_ds.select_columns(['input_ids', 'attention_mask']),
        eval_dataset=None,
        args=args,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()
    trainer.model.save_pretrained(save_path)
    return
Finally, we are ready to fine-tune the model. The number of training epochs is set to 1 and batch size is set to 64.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
finetuned_model_path = "selftrained-mistral-emotion"
finetune_LLM(model_name, train_data_curated_ds, save_path=finetuned_model_path, seed=seed)
The fine-tuned model is stored in your working directory under the folder ‘selftrained-mistral-emotion’
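As an aside, if you ever need to load the saved adapter outside of run_llm (say, for ad-hoc inference), the same AutoPeftModelForCausalLM class used in infer_LLM can load it directly. A minimal sketch:
# Minimal sketch: load the fine-tuned LoRA adapter directly
from peft import AutoPeftModelForCausalLM
finetuned_model = AutoPeftModelForCausalLM.from_pretrained("selftrained-mistral-emotion",
                                                           device_map="auto")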
Test the Self-Trained Model’s Performance
We run the inference with the fine-tuned model using the same function run_llm as we did for the ICL baseline.
text_to_label = {v: k for k, v in label_to_text.items()}
LLM_emotion_list = emotion_list + ['no_match']
_, score = run_llm(val_data, prompt_template, model_name, LLM_emotion_list, text_to_label,
finetuned_model_path=finetuned_model_path, batch_size=64)
print(score)
There's a 29.41% improvement in the F1-score (from 0.442 to 0.572). The number of "no_match" predictions has also reduced drastically. And we didn't have to label any data manually!
The following plot summarizes our results visually —
Performance of Mistral 7B in Text Classification using Emotion dataset with Minimal Labelled Data
Further Experimental Validation
Additional LLM — To verify the reproducibility of our observations with Mistral-7B, we run experiments with Llama3–8B as well.
Additional Dataset — We also include the MTOP domain dataset where LLM ICL is known to perform well in general. This helps us understand if our approach is capable of achieving improvements when LLMs are already doing a reasonable job.
We re-run our experiments with the new LLM and dataset. The code for these experiments can be found here. Following are the results —
The first plot from the left shows Llama3–8B's performance in Text Classification on the Emotion dataset using ICL. The observations are similar to the Mistral-7B experiment. The results with ICL are poor (F1-score of 0.365) and there is a 49.86% improvement in the F1-score after Self-Training using DQC Toolkit (F1-score of 0.547).
With MTOP Domain, both LLMs perform well with ICL. As shown in the second and third plots, ICL with Mistral-7B and Llama3–8B achieves F1-scores of 0.9 and 0.88 respectively. Post Self-Training using DQC Toolkit, Mistral-7B scores 0.916 while Llama3–8B scores 0.938. Essentially, we observe a 1.78% improvement with Mistral-7B and a 6.59% improvement with Llama3–8B.
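For convenience, here is a summary of the numbers reported above (weighted F1):
Dataset       Model        ICL     Self-Trained   Improvement
Emotion       Mistral-7B   0.442   0.572          29.41%
Emotion       Llama3-8B    0.365   0.547          49.86%
MTOP Domain   Mistral-7B   0.900   0.916          1.78%
MTOP Domain   Llama3-8B    0.880   0.938          6.59%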
In a Nutshell
We observe that Self-Training using DQC Toolkit improves the ICL performance of both Mistral-7B and Llama3–8B for both Emotion and MTOP Domain datasets in Text Classification.
Similarity to “Teacher-Student” Learning
Self-Training can be considered a special case of the "Teacher-Student" framework where the Teacher model is an LLM and the Student model is the same LLM. In practice, you would want to explore a Student model that is more cost-effective when it comes to deployment. Similar to what we've seen in this article, we can bootstrap smaller models using LLM ICL predictions to achieve improved performance. We leave this discussion for future posts.
Currently, DQC Toolkit supports text classification (binary / multi-class) problems with various parameter customization options. The plan is to enhance it further by adding more capabilities. Any form of feedback / support will be much appreciated! Following is the link to the repo.
sumanthprabhu / DQC-Toolkit
Quality Checks for Training Data in Machine Learning
DQC Toolkit is a Python library and framework designed with the goal of facilitating the improvement of Machine Learning models by identifying and mitigating label errors in the training dataset. Currently, DQC Toolkit offers CrossValCurate and LLMCurate. CrossValCurate can be used for label error detection / correction in text classification (binary / multi-class) based on cross validation. LLMCurate extends "PEDAL: Enhancing Greedy Decoding with Large Language Models using Diverse Exemplars" to compute LLM-based confidence scores for free-text labels.
Installation
Installation of DQC-toolkit can be done as shown below
pip install dqc-toolkit
Quick Start
CrossValCurate
Assuming your text classification data is stored as a pandas dataframe data, with each sample represented by the text column and its corresponding noisy label represented by the label column, here is how you use CrossValCurate -
from dqc import CrossValCurate
cvc = CrossValCurate()
data_curated = cvc.fit_transform(data[['text'
…
PS - If you found this helpful, it would be great if you could give the repo a shout out.
Thank you for reading
Passionate about Machine Learning? Please feel free to add me on Linkedin