Chamin09 committed
Commit 4fbcbff · verified · 1 Parent(s): 19eecbc

Create llm_setup.py

Files changed (1):
  1. models/llm_setup.py +87 -0
models/llm_setup.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+
+ def setup_llm():
+     """Set up a more capable LLM for CSV analysis."""
+     try:
+         # Try to load FLAN-T5-small, which is better for instruction following
+         # while still being relatively small (~300MB)
+         model_name = "google/flan-t5-small"
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+         generator = pipeline(
+             "text2text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_length=512
+         )
+
+         # Create a wrapper class that matches the expected interface
+         class FlanT5LLM:
+             def complete(self, prompt):
+                 class Response:
+                     def __init__(self, text):
+                         self.text = text
+
+                 try:
+                     # For FLAN-T5, we don't need to strip the prompt from the output
+                     result = generator(prompt, max_length=150, do_sample=False)[0]
+                     response_text = result["generated_text"].strip()
+
+                     if not response_text:
+                         response_text = "I couldn't generate a proper response."
+
+                     return Response(response_text)
+                 except Exception as e:
+                     print(f"Error generating response: {e}")
+                     return Response(f"Error generating response: {str(e)}")
+
+         return FlanT5LLM()
+
+     except Exception as e:
+         print(f"Error setting up FLAN-T5 model: {e}")
+
+         # Fallback to a simpler model if FLAN-T5 fails
+         try:
+             # Try T5-small as a fallback
+             model_name = "t5-small"
+
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+             model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+             generator = pipeline(
+                 "text2text-generation",
+                 model=model,
+                 tokenizer=tokenizer,
+                 max_length=512
+             )
+
+             class T5LLM:
+                 def complete(self, prompt):
+                     class Response:
+                         def __init__(self, text):
+                             self.text = text
+
+                     try:
+                         result = generator(prompt, max_length=150, do_sample=False)[0]
+                         return Response(result["generated_text"].strip())
+                     except Exception as e:
+                         return Response(f"Error: {str(e)}")
+
+             return T5LLM()
+
+         except Exception as e2:
+             print(f"Error setting up fallback model: {e2}")
+
+             # Last resort - dummy LLM
+             class DummyLLM:
+                 def complete(self, prompt):
+                     class Response:
+                         def __init__(self, text):
+                             self.text = text
+
+                     return Response("Model initialization failed. Please check logs.")
+
+             return DummyLLM()
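
For reference, a minimal usage sketch of the interface this module exposes (not part of the commit). All three wrappers (FlanT5LLM, T5LLM, DummyLLM) share the same shape: setup_llm() returns an object whose complete(prompt) method yields a response object with a .text attribute, so callers never need to know which model actually loaded. The import path below assumes the repository root is on sys.path so that models is importable as a package; the example prompt is hypothetical.

    # Usage sketch — assumes the repo root is on sys.path.
    from models.llm_setup import setup_llm

    llm = setup_llm()  # FlanT5LLM, T5LLM, or DummyLLM, depending on what loaded

    # Hypothetical CSV-analysis prompt; FLAN-T5 responds best to
    # instruction-style phrasing like this.
    prompt = (
        "The CSV has columns: name, age, city. "
        "Question: which column is numeric? Answer:"
    )

    response = llm.complete(prompt)  # every wrapper returns an object with .text
    print(response.text)

Because even the dummy fallback honors the same complete() -> .text contract, downstream code that consumes setup_llm() degrades gracefully instead of raising when model downloads fail.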