# ChatCSV/models/llm_setup.py
"""Set up a small seq2seq LLM (FLAN-T5-small, with T5-small and dummy fallbacks) for ChatCSV."""

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
def setup_llm():
    """Set up a more capable LLM for CSV analysis."""
    try:
        # Try to load FLAN-T5-small, which is better at instruction following
        # while still being relatively small (~300MB).
        model_name = "google/flan-t5-small"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=512,
        )

        # Wrapper class that matches the expected LLM interface:
        # .complete(prompt) returns an object exposing a .text attribute.
        class FlanT5LLM:
            def complete(self, prompt):
                class Response:
                    def __init__(self, text):
                        self.text = text

                try:
                    # FLAN-T5 is an encoder-decoder model, so its output does not
                    # echo the prompt; nothing needs to be stripped from it.
                    result = generator(prompt, max_length=150, do_sample=False)[0]
                    response_text = result["generated_text"].strip()
                    if not response_text:
                        response_text = "I couldn't generate a proper response."
                    return Response(response_text)
                except Exception as e:
                    print(f"Error generating response: {e}")
                    return Response(f"Error generating response: {str(e)}")

        return FlanT5LLM()
    except Exception as e:
        print(f"Error setting up FLAN-T5 model: {e}")
        # Fall back to a simpler model if FLAN-T5 fails.
        try:
            # Try T5-small as a fallback.
            model_name = "t5-small"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            generator = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=512,
            )

            class T5LLM:
                def complete(self, prompt):
                    class Response:
                        def __init__(self, text):
                            self.text = text

                    try:
                        result = generator(prompt, max_length=150, do_sample=False)[0]
                        return Response(result["generated_text"].strip())
                    except Exception as e:
                        return Response(f"Error: {str(e)}")

            return T5LLM()
        except Exception as e2:
            print(f"Error setting up fallback model: {e2}")

            # Last resort - a dummy LLM that always reports the failure.
            class DummyLLM:
                def complete(self, prompt):
                    class Response:
                        def __init__(self, text):
                            self.text = text

                    return Response("Model initialization failed. Please check logs.")

            return DummyLLM()
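
# Usage sketch (not part of the original file; the prompt below is an assumption
# about how ChatCSV's query layer phrases questions). Every wrapper returned by
# setup_llm() exposes .complete(prompt) -> object with .text, so callers can
# treat the primary, fallback, and dummy models interchangeably.
if __name__ == "__main__":
    llm = setup_llm()
    response = llm.complete(
        "The CSV has columns: name, age, city. "
        "Question: Which column would hold a person's location?"
    )
    print(response.text)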