|
import torch |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
class _Response:
    """Lightweight response object exposing the generated text as `.text`."""

    def __init__(self, text):
        self.text = text


class _PipelineLLM:
    """LLM wrapper around a Hugging Face text2text-generation pipeline.

    `complete()` never raises: generation failures are logged and reported
    back to the caller inside the response text.
    """

    # Substituted when the pipeline yields an empty/whitespace-only string.
    _EMPTY_FALLBACK = "I couldn't generate a proper response."

    def __init__(self, generator):
        self._generator = generator

    def complete(self, prompt):
        """Generate a completion for `prompt` and wrap it in a `_Response`."""
        try:
            # Deterministic decoding; cap response length independently of
            # the pipeline-level max_length used at construction time.
            result = self._generator(prompt, max_length=150, do_sample=False)[0]
            text = result["generated_text"].strip()
            if not text:
                text = self._EMPTY_FALLBACK
            return _Response(text)
        except Exception as e:
            print(f"Error generating response: {e}")
            return _Response(f"Error generating response: {str(e)}")


class _DummyLLM:
    """Stand-in returned when no model could be initialized at all."""

    def complete(self, prompt):
        """Return a fixed failure notice regardless of `prompt`."""
        return _Response("Model initialization failed. Please check logs.")


def _build_pipeline_llm(model_name):
    """Load `model_name` and wrap it in a `_PipelineLLM`.

    Raises whatever the tokenizer/model/pipeline loading raises; the caller
    decides how to fall back.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    generator = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
    )
    return _PipelineLLM(generator)


def setup_llm():
    """Set up a more capable LLM for CSV analysis.

    Tries "google/flan-t5-small" first, then falls back to "t5-small", and
    finally to a dummy LLM whose responses just report the failure, so this
    function never raises.

    Returns:
        An object with `complete(prompt)` returning a response whose `.text`
        attribute holds the generated string.
    """
    try:
        return _build_pipeline_llm("google/flan-t5-small")
    except Exception as e:
        print(f"Error setting up FLAN-T5 model: {e}")

    try:
        return _build_pipeline_llm("t5-small")
    except Exception as e2:
        print(f"Error setting up fallback model: {e2}")

    return _DummyLLM()
|
|