# AI_agent.py
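"""Research assistant agent: fetches papers for a topic, summarizes them with
DistilBART, and evaluates the summaries with a local LLaMA model."""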
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from src.data_ingestion import DataIngestion
from src.RAG import RetrievalModule
from transformers import pipeline
import sqlite3
import time
from src.logger import logger
# Load LLaMA via llama.cpp for lightweight local chat
llm = LlamaCpp(
    model_path="/Users/nitin/Downloads/llama-2-7b-chat.Q4_0.gguf",  # Update this path!
    n_ctx=512,     # Context window sized to fit in 8 GB of RAM
    n_threads=4,   # Tuned for an M3 Pro
    temperature=0.7,
    max_tokens=150,
    verbose=True,
)
# Instances
data_ingestion = DataIngestion()
retrieval_module = RetrievalModule()
# Memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Database setup
conn = sqlite3.connect("research_data.db")
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS papers (
        query TEXT,
        retrieved_papers TEXT,
        summary TEXT,
        evaluation TEXT
    )
    """
)
conn.commit()
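# Note: this module-level connection is fine for the single-threaded demo
# below. If the assistant is ever served from multiple threads, pass
# check_same_thread=False to sqlite3.connect or open per-thread connections.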
# Tools (plain functions rather than LangChain Tool objects)
def retrieve_relevant_papers(topic: str) -> str:
    """Fetch papers for a topic and return the most relevant sections."""
    titles, abstracts = data_ingestion.fetch_papers(topic)
    if not abstracts:
        logger.warning(f"No papers retrieved for topic: {topic}")
        return "Could not retrieve papers."
    retrieval_module.build_vector_store(abstracts)
    relevant_sections = retrieval_module.retrieve_relevant(topic)
    logger.info(f"Retrieved {len(relevant_sections)} relevant sections for {topic}")
    return "\n".join(relevant_sections)
def summarize_text(text: str) -> str:
    """Summarize text using DistilBART."""
    # NOTE: the pipeline reloads the model on every call; hoist it to module
    # level if summarization is invoked frequently.
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")
    text = text[:500]  # Truncate so the input fits the model comfortably
    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    logger.info("Generated summary for retrieved papers")
    return summary
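# device="mps" assumes an Apple Silicon build of PyTorch. A portable fallback
# (sketch, assuming torch is installed alongside transformers):
#     import torch
#     device = "mps" if torch.backends.mps.is_available() else "cpu"
#     summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device=device)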
def evaluate_summary(summary: str) -> str:
    """Evaluate summary quality with LLaMA."""
    prompt = f"Evaluate this summary for accuracy, completeness, and clarity: {summary[:200]}"
    evaluation = llm.invoke(prompt)  # .invoke() replaces the deprecated llm(prompt) call
    logger.info("Evaluated summary quality")
    return evaluation
# Simple conversational chain; no retriever needed for follow-up chat
class ResearchAssistant:
    def __init__(self):
        self.prompt = PromptTemplate(
            input_variables=["chat_history", "query"],
            template=(
                "You are a research assistant. Based on the chat history and query, "
                "provide a helpful response.\n\n"
                "Chat History: {chat_history}\nQuery: {query}\n\nResponse: "
            ),
        )
        self.chain = LLMChain(llm=llm, prompt=self.prompt, memory=memory)
    def process_query(self, query: str) -> tuple:
        """Run retrieve -> summarize -> evaluate with retries (no ReAct loop needed)."""
        retries = 0
        max_retries = 3
        while retries < max_retries:
            try:
                # Step 1: Retrieve papers
                retrieved_papers = retrieve_relevant_papers(query)
                if "Could not retrieve papers" in retrieved_papers:
                    # Rephrase the query and retry
                    query = f"more detailed {query}"
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 2: Summarize; retry if the summary is suspiciously short
                summary = summarize_text(retrieved_papers)
                if len(summary.split()) < 10:
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 3: Evaluate
                evaluation = evaluate_summary(summary)
                # Save to conversation memory and the database
                memory.save_context(
                    {"input": query},
                    {"output": f"Summary: {summary}\nEvaluation: {evaluation}\nAsk me anything about these findings!"},
                )
                cursor.execute(
                    "INSERT INTO papers (query, retrieved_papers, summary, evaluation) VALUES (?, ?, ?, ?)",
                    (query, retrieved_papers, summary, evaluation),
                )
                conn.commit()
                return summary, evaluation
            except Exception as e:
                logger.error(f"Error in processing: {str(e)}")
                retries += 1
                time.sleep(2)
        logger.error("Max retries reached; task failed.")
        return "Failed after retries.", "N/A"
    def chat(self, user_input: str) -> str:
        """Handle follow-up chats."""
        if not memory.chat_memory.messages:
            return "Please start with a research query like 'large language model memory optimization'."
        return self.chain.run(query=user_input)
if __name__ == "__main__":
    assistant = ResearchAssistant()
    query = "large language model memory optimization"
    summary, evaluation = assistant.process_query(query)
    print("Summary:", summary)
    print("Evaluation:", evaluation)
    # Test a follow-up question
    follow_up = "Tell me more about memory optimization."
    print("Follow-up:", assistant.chat(follow_up))
    conn.close()
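# A minimal interactive loop could be layered on top (sketch, untested):
#     assistant = ResearchAssistant()
#     while True:
#         user_input = input("You: ")
#         if user_input.strip().lower() in {"quit", "exit"}:
#             break
#         print("Assistant:", assistant.chat(user_input))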