# AI_agent.py
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from src.data_ingestion import DataIngestion
from src.RAG import RetrievalModule
from transformers import pipeline
import sqlite3
import time
from src.logger import logger
# Load LLaMA with llama.cpp—simple chatter
llm = LlamaCpp(
    model_path="/Users/nitin/Downloads/llama-2-7b-chat.Q4_0.gguf",  # Update this!
    n_ctx=512,      # Fits 8 GB
    n_threads=4,    # Fast on M3 Pro
    temperature=0.7,
    max_tokens=150,
    verbose=True
)
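# Optional sanity check (a sketch, not part of the pipeline): confirm the GGUF
# model loads and responds before wiring it into the chain. The model_path above
# is machine-specific, so update it first.
# print(llm.invoke("Reply with one word: ready?"))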
# Instances
data_ingestion = DataIngestion()
retrieval_module = RetrievalModule()
# Memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=False)  # String history fits the plain-text prompt below
# Database Setup
conn = sqlite3.connect("research_data.db")
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS papers (
        query TEXT,
        retrieved_papers TEXT,
        summary TEXT,
        evaluation TEXT
    )
    """
)
conn.commit()
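# Hypothetical inspection helper (an illustration, not used by the pipeline):
# read back recent runs from SQLite, e.g. when debugging from a REPL.
def fetch_saved_runs(limit: int = 5) -> list:
    """Return up to `limit` stored (query, summary, evaluation) rows."""
    cursor.execute("SELECT query, summary, evaluation FROM papers LIMIT ?", (limit,))
    return cursor.fetchall()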
# Tools (just functions now)
def retrieve_relevant_papers(topic: str) -> str:
    """Fetch papers for a topic and return the most relevant sections."""
    titles, abstracts = data_ingestion.fetch_papers(topic)
    if not abstracts:
        logger.warning(f"No papers retrieved for topic: {topic}")
        return "Could not retrieve papers."
    retrieval_module.build_vector_store(abstracts)
    relevant_sections = retrieval_module.retrieve_relevant(topic)
    logger.info(f"Retrieved {len(relevant_sections)} relevant sections for {topic}")
    return "\n".join(relevant_sections)
def summarize_text(text: str) -> str:
    """Summarize text using DistilBART."""
    # Note: the pipeline is rebuilt on every call; see the cached-factory sketch below.
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")
    text = text[:500]  # Keep the input short so DistilBART handles it comfortably
    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    logger.info("Generated summary for retrieved papers")
    return summary
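# Optional sketch: DistilBART is reloaded on every summarize_text() call above.
# If that becomes a bottleneck, one simple fix is a cached factory like this
# (the import normally belongs at the top of the file); swap the pipeline(...)
# call in summarize_text for _get_summarizer() to use it.
from functools import lru_cache

@lru_cache(maxsize=1)
def _get_summarizer():
    """Build the summarization pipeline once and reuse it on later calls."""
    return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")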
def evaluate_summary(summary: str) -> str:
    """Evaluate summary quality with LLaMA."""
    prompt = f"Evaluate this summary for accuracy, completeness, and clarity: {summary[:200]}"
    evaluation = llm.invoke(prompt)  # invoke() avoids the deprecated __call__ style
    logger.info("Evaluated summary quality")
    return evaluation
# Simple Conversational Chain—no retriever needed
class ResearchAssistant:
    def __init__(self):
        self.prompt = PromptTemplate(
            input_variables=["chat_history", "query"],
            template=(
                "You are a research assistant. Based on the chat history and query, "
                "provide a helpful response.\n\n"
                "Chat History: {chat_history}\nQuery: {query}\n\nResponse: "
            )
        )
        self.chain = LLMChain(llm=llm, prompt=self.prompt, memory=memory)
    def process_query(self, query: str) -> tuple:
        """Process a query with simple retries—no ReAct mess."""
        retries = 0
        max_retries = 3
        while retries < max_retries:
            try:
                # Step 1: Retrieve papers
                retrieved_papers = retrieve_relevant_papers(query)
                if "Could not retrieve papers" in retrieved_papers:
                    query = f"more detailed {query}"  # Rephrase and try again
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 2: Summarize
                summary = summarize_text(retrieved_papers)
                if len(summary.split()) < 10:  # Reject degenerate summaries
                    retries += 1
                    time.sleep(2)
                    continue
                # Step 3: Evaluate
                evaluation = evaluate_summary(summary)
                # Save to memory and DB
                memory.save_context(
                    {"input": query},
                    {"output": f"Summary: {summary}\nEvaluation: {evaluation}\nAsk me anything about these findings!"}
                )
                cursor.execute(
                    "INSERT INTO papers (query, retrieved_papers, summary, evaluation) VALUES (?, ?, ?, ?)",
                    (query, retrieved_papers, summary, evaluation)
                )
                conn.commit()
                return summary, evaluation
            except Exception as e:
                logger.error(f"Error in processing: {str(e)}")
                retries += 1
                time.sleep(2)
        logger.error("Max retries reached—task failed.")
        return "Failed after retries.", "N/A"
    def chat(self, user_input: str) -> str:
        """Handle follow-up chats."""
        if not memory.chat_memory.messages:
            return "Please start with a research query like 'large language model memory optimization'."
        return self.chain.run(query=user_input)
if __name__ == "__main__":
    assistant = ResearchAssistant()
    query = "large language model memory optimization"
    summary, evaluation = assistant.process_query(query)
    print("Summary:", summary)
    print("Evaluation:", evaluation)
    # Test a follow-up question
    follow_up = "Tell me more about memory optimization."
    print("Follow-up:", assistant.chat(follow_up))
    conn.close()