# drrobot9's picture
# push updated backend changes and auto start building
# 6584be3 verified
import logging
import os
import sys

# Hugging Face cache location.
# The Hugging Face libraries resolve their cache paths when they are first
# imported, so these variables have to be exported BEFORE `transformers` is
# imported below. Set any later, they only take effect for calls that pass
# `cache_dir=hf_cache` explicitly.
hf_cache = "/tmp/huggingface"
os.environ["HF_HOME"] = hf_cache
os.environ["TRANSFORMERS_CACHE"] = hf_cache
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache
os.makedirs(hf_cache, exist_ok=True)

# These imports intentionally follow the environment setup above.
import torch  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline  # noqa: E402
from config import LLM_MODEL, CONFIDENCE_THRESHOLD, VECTORSTORE_DIR  # noqa: E402

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Put this file's directory on sys.path so sibling modules that are imported
# lazily later (e.g. `rag`) can be found regardless of the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)
# Load BioMistral once
class BioMistralModel:
    """Text-generation wrapper around the configured causal language model.

    The constructor first tries to load the tokenizer and model directly
    through ``AutoTokenizer`` / ``AutoModelForCausalLM``. If that raises, it
    falls back to a ``transformers`` text-generation ``pipeline`` and records
    the choice in ``use_pipeline`` so :meth:`generate_answer` takes the
    matching code path.

    Attributes:
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``, unless
            the caller passed an explicit ``device``.
        use_pipeline: ``True`` when the pipeline fallback is active.
        tokenizer / model: set when direct loading succeeds.
        pipeline: set only on the fallback path.
    """

    def __init__(self, model_name=LLM_MODEL, device=None):
        """Load ``model_name``, preferring direct loading over the pipeline.

        Args:
            model_name: Model identifier handed to ``from_pretrained``.
            device: Optional device override; auto-detected when ``None``.

        Raises:
            Exception: whatever the pipeline fallback raises if it fails too;
                only the first (direct) loading attempt is guarded.
        """
        logger.info(f"Loading model: {model_name}")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=hf_cache
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                # Half precision on GPU, full precision on CPU.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                cache_dir=hf_cache
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fallback to pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=model_name,
                device=0 if self.device == "cuda" else -1,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self.use_pipeline = True
        else:
            self.use_pipeline = False

    @staticmethod
    def _extract_answer(full_text: str, prompt: str) -> str:
        """Return only the model's completion from ``full_text``.

        Both generation paths return the prompt followed by the completion.
        Cutting at the first ``"Answer:"`` is not enough: a question that
        itself contains ``"Answer:"`` would leave part of the prompt in the
        reply. The exact prompt prefix is therefore stripped first; the older
        heuristics remain as a fallback for when the decoded text does not
        reproduce the prompt verbatim.
        """
        if full_text.startswith(prompt):
            return full_text[len(prompt):].strip()
        if "Answer:" in full_text:
            return full_text.split("Answer:", 1)[-1].strip()
        return full_text.replace(prompt, "").strip()

    def generate_answer(self, query: str) -> str:
        """Generate an answer to ``query``.

        Args:
            query: The user's question; it is embedded in a fixed tutor prompt.

        Returns:
            The generated answer with the prompt removed. If generation
            raises, the error is logged and an apology string containing the
            error text is returned instead of propagating the exception.
        """
        prompt = (
            "You are a helpful bioinformatics tutor. Answer clearly and concisely.\n"
            f"Question: {query}\n"
            "Answer:"
        )
        try:
            if getattr(self, "use_pipeline", False):
                # Use pipeline fallback
                result = self.pipeline(
                    prompt,
                    max_new_tokens=256,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id
                )
                full_text = result[0]['generated_text']
            else:
                # Use model directly. Inputs go to the device the model is
                # actually on, which can differ from ``self.device`` (e.g. a
                # caller-supplied "cuda:0" does not match the "cuda" checks in
                # __init__, so the model stays on CPU).
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        do_sample=True,
                        top_p=0.9,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract only the answer part
            return self._extract_answer(full_text, prompt)
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return f"I apologize, but I encountered an error while processing your question: {str(e)}"
# Formatting utility
class TextFormatter:
    """Stateless helper that tidies answer text before it is returned."""

    # Returned whenever there is no usable text to format.
    _FALLBACK = "I don't have an answer for that question. Could you please rephrase or ask something else?"
    # Characters accepted as a sentence terminator.
    _TERMINATORS = ('.', '!', '?')

    @staticmethod
    def format_text(text: str) -> str:
        """Clean and format text output.

        Collapses every run of whitespace to a single space, upper-cases the
        first character and appends ``"."`` when the text does not already end
        with ``.``, ``!`` or ``?``.

        Args:
            text: Raw answer text; may be empty or ``None``.

        Returns:
            The formatted text, or a fixed fallback message when ``text`` is
            empty, ``None`` or consists only of whitespace.
        """
        # Normalise first so whitespace-only input is treated as "no answer"
        # instead of producing an empty string.
        cleaned = " ".join(text.split()) if text else ""
        if not cleaned:
            return TextFormatter._FALLBACK

        cleaned = cleaned[0].upper() + cleaned[1:]
        # Ensure it ends with punctuation
        if not cleaned.endswith(TextFormatter._TERMINATORS):
            cleaned += '.'
        return cleaned
# Tutor Agent
class TutorAgent:
    """Answers questions with the language model, optionally backed by RAG.

    A draft answer always comes from the model. When its estimated confidence
    is below ``CONFIDENCE_THRESHOLD`` and a RAG agent is available, the RAG
    answer replaces the draft if it is longer.
    """

    def __init__(self):
        """Create the model and formatter, then try to attach a RAG agent."""
        logger.info("Initializing TutorAgent")
        self.model = BioMistralModel()
        self.formatter = TextFormatter()

        # RAG is optional: if importing or constructing it fails, the failure
        # is logged and ``rag_agent`` stays None.
        self.rag_agent = None
        try:
            from rag import RAGAgent

            self.rag_agent = RAGAgent(vectorstore_dir=str(VECTORSTORE_DIR))
            logger.info("RAG agent initialized")
        except ImportError as import_err:
            logger.warning(f"RAG not available: {import_err}")
        except Exception as init_err:
            logger.warning(f"Failed to initialize RAG: {init_err}")

    def process_query(self, query: str) -> str:
        """Answer ``query`` and return the formatted text.

        Queries that are empty or shorter than two characters after stripping
        get a fixed prompt-for-input message instead of a model call.
        """
        logger.info(f"Processing query: {query}")
        if not query or len(query.strip()) < 2:
            return "Please ask a meaningful question about bioinformatics."

        draft = self.model.generate_answer(query)
        score = self.estimate_confidence(draft)
        logger.info(f"Confidence: {score:.2f}")

        # Confident enough, or nothing to enhance with: return the draft.
        if not (score < CONFIDENCE_THRESHOLD and self.rag_agent):
            return self.formatter.format_text(draft)

        logger.info("Low confidence, attempting RAG enhancement")
        try:
            candidate = self._enhance_with_rag(query)
            # Only a longer RAG answer replaces the draft.
            if candidate and len(candidate) > len(draft):
                draft = candidate
        except Exception as rag_err:
            logger.warning(f"RAG enhancement failed: {rag_err}")
        return self.formatter.format_text(draft)

    def _enhance_with_rag(self, query: str) -> str:
        """Ask the RAG agent for an answer; return ``""`` when unavailable.

        A dict result yields its ``'answer'`` entry (default ``''``); any
        other result is converted with ``str``. Errors are logged and turned
        into an empty string.
        """
        if not self.rag_agent:
            return ""
        try:
            # The agent is only usable if it exposes an ``answer`` method.
            if not hasattr(self.rag_agent, 'answer'):
                return ""
            outcome = self.rag_agent.answer(query)
            if isinstance(outcome, dict):
                return outcome.get('answer', '')
            return str(outcome)
        except Exception as rag_err:
            logger.error(f"RAG error: {rag_err}")
            return ""

    def estimate_confidence(self, answer: str) -> float:
        """Length-based confidence heuristic.

        Returns 0.0 for blank text, otherwise 0.85 / 0.7 / 0.5 for stripped
        lengths above 150 / 80 / 30 characters, and 0.3 for anything shorter.
        """
        stripped = answer.strip()
        if not stripped:
            return 0.0

        size = len(stripped)
        # (exclusive lower bound, score), checked from longest to shortest.
        for lower_bound, score in ((150, 0.85), (80, 0.7), (30, 0.5)):
            if size > lower_bound:
                return score
        return 0.3
# User class
class BioUser:
    """A named user who can put questions to a tutor."""

    def __init__(self, name="BioUser"):
        """Store the user's name (defaults to ``"BioUser"``)."""
        self.name = name

    def ask_question(self, question: str, tutor: TutorAgent) -> str:
        """Pass ``question`` to ``tutor.process_query`` and return its reply."""
        reply = tutor.process_query(question)
        return reply