# AI_agent.py
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from src.data_ingestion import DataIngestion
from src.RAG import RetrievalModule
from transformers import pipeline
import sqlite3
import time
from src.logger import logger
# Load LLaMA 2 via llama.cpp for lightweight local chat
llm = LlamaCpp(
    model_path="/Users/nitin/Downloads/llama-2-7b-chat.Q4_0.gguf",  # Update this to your local GGUF path!
    n_ctx=512,       # Small context window to fit comfortably in 8 GB of RAM
    n_threads=4,     # Runs quickly on an M3 Pro
    temperature=0.7,
    max_tokens=150,
    verbose=True
)
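# Note: the GGUF file above is assumed to be downloaded locally beforehand
# (e.g. from a quantized Llama-2-7B-Chat GGUF repo on Hugging Face); the
# path is machine-specific.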
# Module instances
data_ingestion = DataIngestion()
retrieval_module = RetrievalModule()

# Conversation memory for follow-up chat
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Database Setup
conn = sqlite3.connect("research_data.db")
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS papers (
        query TEXT,
        retrieved_papers TEXT,
        summary TEXT,
        evaluation TEXT
    )
    """
)
conn.commit()
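# Each processed query appends one row; there is no primary key or timestamp,
# so re-running the same query simply adds another row.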
# Tools (plain functions, no agent framework needed)
def retrieve_relevant_papers(topic: str) -> str:
    """Fetch papers for a topic and return the most relevant sections."""
    titles, abstracts = data_ingestion.fetch_papers(topic)
    if not abstracts:
        logger.warning(f"No papers retrieved for topic: {topic}")
        return "Could not retrieve papers."
    retrieval_module.build_vector_store(abstracts)
    relevant_sections = retrieval_module.retrieve_relevant(topic)
    logger.info(f"Retrieved {len(relevant_sections)} relevant sections for {topic}")
    return "\n".join(relevant_sections)
def summarize_text(text: str) -> str:
    """Summarize text using DistilBART."""
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")
    text = text[:500]  # Truncate input to keep the summarizer fast and within its limits
    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    logger.info("Generated summary for retrieved papers")
    return summary
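# Optional optimization, sketched here as a suggestion rather than part of the
# original design: cache the summarization pipeline at module level so the
# DistilBART weights are not reloaded on every summarize_text() call.
# `_summarizer` and `get_summarizer` are illustrative names, not used above.
_summarizer = None

def get_summarizer():
    """Lazily build and cache the DistilBART pipeline (sketch)."""
    global _summarizer
    if _summarizer is None:
        _summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")
    return _summarizer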
def evaluate_summary(summary: str) -> str:
    """Evaluate summary quality with LLaMA."""
    prompt = f"Evaluate this summary for accuracy, completeness, and clarity: {summary[:200]}"
    evaluation = llm.invoke(prompt)  # invoke() replaces the deprecated llm(prompt) call
    logger.info("Evaluated summary quality")
    return evaluation
# Simple conversational chain; no retriever needed for follow-up questions
class ResearchAssistant:
    def __init__(self):
        self.prompt = PromptTemplate(
            input_variables=["chat_history", "query"],
            template="You are a research assistant. Based on the chat history and query, provide a helpful response.\n\nChat History: {chat_history}\nQuery: {query}\n\nResponse: "
        )
        self.chain = LLMChain(llm=llm, prompt=self.prompt, memory=memory)
    def process_query(self, query: str) -> tuple:
        """Process a research query with simple retries instead of a ReAct loop."""
        retries = 0
        max_retries = 3
        while retries < max_retries:
            try:
                # Step 1: Retrieve papers
                retrieved_papers = retrieve_relevant_papers(query)
                if "Could not retrieve papers" in retrieved_papers:
                    query = f"more detailed {query}"  # Rephrase and retry
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 2: Summarize
                summary = summarize_text(retrieved_papers)
                if len(summary.split()) < 10:  # Reject trivially short summaries
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 3: Evaluate
                evaluation = evaluate_summary(summary)
                # Save to memory and the database
                memory.save_context(
                    {"input": query},
                    {"output": f"Summary: {summary}\nEvaluation: {evaluation}\nAsk me anything about these findings!"}
                )
                cursor.execute(
                    "INSERT INTO papers (query, retrieved_papers, summary, evaluation) VALUES (?, ?, ?, ?)",
                    (query, retrieved_papers, summary, evaluation)
                )
                conn.commit()
                return summary, evaluation
            except Exception as e:
                logger.error(f"Error in processing: {str(e)}")
                retries += 1
                time.sleep(2)
        logger.error("Max retries reached; task failed.")
        return "Failed after retries.", "N/A"
    def chat(self, user_input: str) -> str:
        """Handle follow-up questions against the stored conversation."""
        if not memory.chat_memory.messages:
            return "Please start with a research query like 'large language model memory optimization'."
        return self.chain.run(query=user_input)
if __name__ == "__main__":
    assistant = ResearchAssistant()
    query = "large language model memory optimization"
    summary, evaluation = assistant.process_query(query)
    print("Summary:", summary)
    print("Evaluation:", evaluation)
    # Test a follow-up question
    follow_up = "Tell me more about memory optimization."
    print("Follow-up:", assistant.chat(follow_up))
    conn.close()
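# To inspect stored runs later, a separate session could read the table back.
# A minimal sketch (assumes research_data.db exists in the working directory):
#
#   import sqlite3
#   conn = sqlite3.connect("research_data.db")
#   for row in conn.execute("SELECT query, summary FROM papers"):
#       print(row)
#   conn.close()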