import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import LLM_MODEL, CONFIDENCE_THRESHOLD, VECTORSTORE_DIR

import os
import sys
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Redirect the Hugging Face caches to a writable location.
hf_cache = "/tmp/huggingface"
os.environ["HF_HOME"] = hf_cache
os.environ["TRANSFORMERS_CACHE"] = hf_cache
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache
os.makedirs(hf_cache, exist_ok=True)

# Make sibling modules importable regardless of the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)


class BioMistralModel:
    def __init__(self, model_name=LLM_MODEL, device=None):
        logger.info(f"Loading model: {model_name}")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=hf_cache
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                cache_dir=hf_cache
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fall back to a text-generation pipeline if direct loading fails.
            self.pipeline = pipeline(
                "text-generation",
                model=model_name,
                device=0 if self.device == "cuda" else -1,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self.use_pipeline = True
        else:
            self.use_pipeline = False

    def generate_answer(self, query: str) -> str:
        prompt = f"""You are a helpful bioinformatics tutor. Answer clearly and concisely.

Question: {query}
Answer:"""

        try:
            if hasattr(self, 'use_pipeline') and self.use_pipeline:
                result = self.pipeline(
                    prompt,
                    max_new_tokens=256,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id
                )
                full_text = result[0]['generated_text']
            else:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        do_sample=True,
                        top_p=0.9,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Return only the text after the "Answer:" marker, stripping the prompt.
            if "Answer:" in full_text:
                return full_text.split("Answer:", 1)[-1].strip()
            else:
                return full_text.replace(prompt, "").strip()

        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return f"I apologize, but I encountered an error while processing your question: {str(e)}"


class TextFormatter:
    @staticmethod
    def format_text(text: str) -> str:
        """Clean and format text output."""
        if not text or not text.strip():
            return "I don't have an answer for that question. Could you please rephrase or ask something else?"

        # Collapse whitespace, capitalize the first letter, and ensure closing punctuation.
        cleaned = " ".join(text.split())
        cleaned = cleaned[0].upper() + cleaned[1:]
        if cleaned[-1] not in {'.', '!', '?'}:
            cleaned += '.'
        return cleaned


class TutorAgent:
    def __init__(self):
        logger.info("Initializing TutorAgent")
        self.model = BioMistralModel()
        self.formatter = TextFormatter()

        # RAG is optional; the agent falls back to the base model if it cannot be loaded.
        self.rag_agent = None
        try:
            from rag import RAGAgent
            self.rag_agent = RAGAgent(vectorstore_dir=str(VECTORSTORE_DIR))
            logger.info("RAG agent initialized")
        except ImportError as e:
            logger.warning(f"RAG not available: {e}")
        except Exception as e:
            logger.warning(f"Failed to initialize RAG: {e}")

    def process_query(self, query: str) -> str:
        logger.info(f"Processing query: {query}")

        if not query or len(query.strip()) < 2:
            return "Please ask a meaningful question about bioinformatics."

        answer = self.model.generate_answer(query)
        confidence = self.estimate_confidence(answer)

        logger.info(f"Confidence: {confidence:.2f}")

        # On low confidence, try retrieval-augmented generation and keep the longer answer.
        if confidence < CONFIDENCE_THRESHOLD and self.rag_agent:
            logger.info("Low confidence, attempting RAG enhancement")
            try:
                rag_answer = self._enhance_with_rag(query)
                if rag_answer and len(rag_answer) > len(answer):
                    answer = rag_answer
            except Exception as e:
                logger.warning(f"RAG enhancement failed: {e}")

        return self.formatter.format_text(answer)

    def _enhance_with_rag(self, query: str) -> str:
        """Enhance answer using RAG if available"""
        if not self.rag_agent:
            return ""

        try:
            if hasattr(self.rag_agent, 'answer'):
                result = self.rag_agent.answer(query)
                return result.get('answer', '') if isinstance(result, dict) else str(result)
            else:
                return ""
        except Exception as e:
            logger.error(f"RAG error: {e}")
            return ""

    def estimate_confidence(self, answer: str) -> float:
        """Simple confidence estimation"""
        answer = answer.strip()
        if not answer:
            return 0.0

        length = len(answer)
        if length > 150:
            return 0.85
        elif length > 80:
            return 0.7
        elif length > 30:
            return 0.5
        else:
            return 0.3


class BioUser:
    def __init__(self, name="BioUser"):
        self.name = name

    def ask_question(self, question: str, tutor: TutorAgent) -> str:
        return tutor.process_query(question)
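

# Minimal usage sketch (assumption: config.py supplies LLM_MODEL, CONFIDENCE_THRESHOLD,
# and VECTORSTORE_DIR as importable constants; the sample question is illustrative only).
if __name__ == "__main__":
    tutor = TutorAgent()
    user = BioUser(name="demo")
    print(user.ask_question("What is the difference between FASTA and FASTQ formats?", tutor))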