Every time I run this LLM I get a memory error. Please help.
# Standard library
import os

# Third-party: Hugging Face Transformers (pip install transformers)
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from transformers import pipeline  # NOTE(review): `pipeline` is never used below
# Replace with the path to your local folder containing the text files.
# A raw string is required: in "C:\Users\..." the \U would start a unicode
# escape and raise a SyntaxError in Python 3.
folder_path = r"C:\Users\asokw\Downloads\new"
# Function to read and process text files
def read_text_files(folder_path):
    """Return the contents of every .txt file directly inside *folder_path*.

    Parameters:
        folder_path: directory to scan (non-recursive).

    Returns:
        list[str]: one string per ``.txt`` file, in ``os.listdir`` order.

    Raises:
        OSError: if the folder does not exist or a file cannot be read.
        UnicodeDecodeError: if a file is not valid UTF-8.
    """
    all_files = os.listdir(folder_path)
    text_files = [os.path.join(folder_path, f) for f in all_files if f.endswith('.txt')]
    documents = []
    for file_path in text_files:
        # Read each file whole; utf-8 keeps non-ASCII content intact.
        with open(file_path, 'r', encoding='utf-8') as file:
            documents.append(file.read())
    return documents
# Load and preprocess documents
documents = read_text_files(folder_path)

# Initialize RAG tokenizer, retriever, and model.
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base')
# NOTE(review): RagRetriever.from_pretrained does not accept a `passages=`
# kwarg. Indexing your own passages requires a HF Dataset with `title`/`text`
# columns plus precomputed DPR embeddings and a faiss index (see the
# `passages_path`/`index_path` arguments) — confirm the intended setup.
# Loading the full wiki_dpr index is also the usual cause of the reported
# memory error; `use_dummy_dataset=True` avoids it while testing.
retriever = RagRetriever.from_pretrained('facebook/rag-token-base', index_name='exact', passages=documents)
model = RagTokenForGeneration.from_pretrained('facebook/rag-token-base', retriever=retriever)

your_prompt = "What information can be found in these documents?"
inputs = tokenizer(your_prompt, return_tensors="pt")

# generate() performs retrieval internally when a retriever is attached;
# the model has no public `get_retrieval_vector` API and takes no
# `retrieval_logits` argument.
generation_output = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
)
# generate() returns a tensor of token ids of shape (batch, seq_len); it only
# has a `.sequences` attribute when return_dict_in_generate=True is passed.
generated_text = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]

print("Generated text:", generated_text)
# NOTE(review): duplicate imports — these repeat the imports at the top of the
# file (likely paste residue) and can safely be removed.
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from transformers import pipeline
Top comments (0)