Olivier
LoReFT and pyreft for surgical fine-tuning

Here’s my attempt to understand and summarise the paper "ReFT: Representation Finetuning for Language Models":

https://arxiv.org/abs/2404.03592
https://github.com/stanfordnlp/pyreft

The paper proposes Representation Finetuning (ReFT) as a more parameter-efficient alternative to PEFTs like adapters and LoRA for adapting large language models to downstream tasks. Key ideas:

  • ReFT methods train task-specific interventions that modify hidden representations of a frozen base model, leveraging the insight from interpretability work that representations encode rich semantics.

  • The Low-rank Linear Subspace ReFT (LoReFT) variant uses 10-50x fewer parameters than leading PEFTs by intervening on representations in a learned low-dimensional subspace.

  • LoReFT provides state-of-the-art efficiency-performance tradeoffs on commonsense reasoning, arithmetic, instruction following, and language understanding tasks. When instruction-tuning Llama-2 7B, it nearly matches GPT-3.5 while training only 0.004% extra parameters.

  • The authors argue that editing representations may be more powerful than modifying weights, as PEFTs do. That said, LoReFT struggled more on arithmetic reasoning, possibly because those tasks require long generated outputs.
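To make the low-rank subspace idea concrete, here is a minimal NumPy sketch of the LoReFT edit from the paper, Φ(h) = h + Rᵀ(Wh + b − Rh), where R is a learned projection with orthonormal rows and W, b produce the target values in that subspace. The sizes and random values below are purely illustrative:

```python
import numpy as np

def loreft_edit(h, R, W, b):
    """Apply a LoReFT intervention to a hidden state h of shape (d,).

    R: (r, d) projection with orthonormal rows (the learned subspace)
    W: (r, d) linear map and b: (r,) bias giving the target subspace values.
    Only R, W, b are trained; the base model stays frozen.
    """
    return h + R.T @ (W @ h + b - R @ h)

d, r = 8, 2                      # toy sizes; real models use d ~ 4096
rng = np.random.default_rng(0)
h = rng.standard_normal(d)

# Build R with orthonormal rows via a QR decomposition
Q, _ = np.linalg.qr(rng.standard_normal((d, r)))
R = Q.T                          # (r, d), R @ R.T == identity
W = rng.standard_normal((r, d))
b = rng.standard_normal(r)

h_new = loreft_edit(h, R, W, b)

# Inside the subspace, the edited state hits the target exactly...
assert np.allclose(R @ h_new, W @ h + b)
# ...while the orthogonal complement of the subspace is untouched.
assert np.allclose(h_new - R.T @ (R @ h_new), h - R.T @ (R @ h))
```

This is why the method is so parameter-efficient: a rank-r edit of a d-dimensional hidden state needs only the 2rd + r entries of R, W, and b.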

To make ReFT easy to use, the authors also introduce pyreft, a library that enables:

  • Fine-tuning any HuggingFace pretrained LM with ReFT
  • Configuring ReFT hyperparameters via config files
  • Easily sharing fine-tuned models on HuggingFace
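Because the trained interventions are tiny, sharing them is just a matter of saving the intervention weights. The pyreft README sketches it roughly like this, assuming a trained `reft_model`; the directory and repo name here are placeholders:

```python
# Sketch of sharing a trained ReFT intervention, following the pyreft
# README; "your_username/reft_llama_demo" is a placeholder repo name.
reft_model.set_device("cpu")  # move interventions off GPU before saving
reft_model.save(
    save_directory="./reft_to_share",
    save_to_hf_hub=True,
    hf_repo_name="your_username/reft_llama_demo",
)
```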

Here is a schematic example showing basic usage of pyreft to fine-tune a Llama-7B model (data loading and training arguments are elided):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from pyreft import (
    LoreftIntervention,
    ReftConfig,
    ReftTrainerForCausalLM,
    get_reft_model,
)

# Load the pretrained LM (kept frozen) and its tokenizer
model = AutoModelForCausalLM.from_pretrained("llama-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("llama-7b-hf")

# Configure a rank-1 LoReFT intervention on the residual stream
# (block output) at layer 15
reft_config = ReftConfig(
    representations={
        "layer": 15,
        "component": "block_output",
        "low_rank_dimension": 1,
        "intervention": LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=1
        ),
    }
)

# Wrap the frozen model; only the intervention parameters are trainable
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()

# Load training data
data_module = make_supervised_data_module(...)

# Configure training
trainer = ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=TrainingArguments(...),
    **data_module,
)

# Train only the intervention
trainer.train()
```

This fine-tunes a Llama-7B model by training a rank-1 LoReFT intervention at layer 15, modifying only 0.00006% of the model parameters. Once trained, the reft_model can be used for inference on downstream tasks.
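At inference time, the intervention must be told which token positions to edit. A sketch following the pyreft README, assuming a trained `reft_model` and its `tokenizer`, and intervening on the last prompt token (the setup its examples train with):

```python
# Inference sketch following the pyreft README (assumes a trained
# reft_model and tokenizer; device and prompt are illustrative).
prompt = tokenizer("Tell me about ReFT.", return_tensors="pt").to("cuda")
last_position = prompt["input_ids"].shape[-1] - 1  # edit the final prompt token

_, response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[last_position]]])},
    intervene_on_prompt=True,
    max_new_tokens=128,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(response[0], skip_special_tokens=True))
```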

By establishing ReFT as a promising new paradigm for LM adaptation, this work points to representation editing as a powerful lever for both model efficiency and interpretability. The pyreft library enables researchers and practitioners to easily experiment with and build on these ideas.
