rivapereira123 committed
Commit b80977d · verified · 1 Parent(s): a5319f5

Update app.py

Files changed (1)
  1. app.py +12 -5
app.py CHANGED
@@ -12,6 +12,9 @@ import time
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 
+from transformers import AutoModelForSeq2SeqLM  # ✅ Needed for T5 and FLAN models
+
+
 # Suppress warnings for cleaner output
 warnings.filterwarnings("ignore")
 
@@ -24,7 +27,7 @@ import faiss
 import torch
 from transformers import (
     AutoTokenizer,
-    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
     BitsAndBytesConfig,
     pipeline
 )
@@ -311,18 +314,22 @@ class OptimizedGazaRAGSystem:
 
 
     def _initialize_llm(self):
-        """Initialize FLAN-T5 model (CPU-friendly)"""
+        """Load flan-t5-base for CPU fallback"""
        model_name = "google/flan-t5-base"
        try:
            logger.info(f"🔄 Loading fallback CPU model: {model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-           self.llm = AutoModelForCausalLM.from_pretrained(model_name)
-           self.generation_pipeline = pipeline("text2text-generation", model=self.llm,tokenizer=self.tokenizer,return_full_text=False)
+           self.llm = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+           self.generation_pipeline = pipeline(
+               "text2text-generation",  # ✅ correct pipeline for T5
+               model=self.llm,
+               tokenizer=self.tokenizer,
+               return_full_text=False
+           )
            logger.info("✅ FLAN-T5 model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Error loading FLAN-T5 model: {e}")
            self.llm = None
-           self.generation_pipeline = None
 
 
 
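
For reference, a minimal self-contained sketch of the fallback path this commit sets up: loading google/flan-t5-base with the seq2seq model class and wrapping it in a text2text-generation pipeline. The prompt and max_new_tokens value below are illustrative, not taken from app.py.

# Standalone sketch of the commit's CPU fallback (assumes transformers and torch are installed;
# the prompt is made up for demonstration).
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)  # seq2seq (encoder-decoder) class needed for T5/FLAN checkpoints

generator = pipeline(
    "text2text-generation",  # pipeline task for T5-family encoder-decoder models
    model=model,
    tokenizer=tokenizer,
)

result = generator("Summarize: FLAN-T5 is an instruction-tuned encoder-decoder model.", max_new_tokens=64)
print(result[0]["generated_text"])  # text2text-generation returns only the generated output

One design note: return_full_text, which the commit keeps in the pipeline call, is (as far as I know) a parameter of the text-generation pipeline rather than text2text-generation, so the sketch omits it; the text2text task already returns only the model's output rather than prompt plus output.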