# AI_agent.py
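"""Research assistant agent: fetches papers on a topic, summarizes them with
DistilBART, evaluates the summary with a local LLaMA model, and supports
follow-up chat grounded in the conversation history."""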
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from src.data_ingestion import DataIngestion
from src.RAG import RetrievalModule
from transformers import pipeline
import sqlite3
import time
from src.logger import logger

# Load LLaMA 2 7B (chat, Q4_0 quantized) locally via llama.cpp
llm = LlamaCpp(
    model_path="/Users/nitin/Downloads/llama-2-7b-chat.Q4_0.gguf",  # Update to your local model path
    n_ctx=512,  # Small context window keeps memory use low (fits in 8 GB RAM)
    n_threads=4,  # CPU threads; tune for your machine (4 works well on an M3 Pro)
    temperature=0.7,
    max_tokens=150,
    verbose=True
)
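# Note: with n_ctx=512, a long chat history or a large retrieved context can
# exceed the context window, since ConversationBufferMemory grows without bound.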

# Instances
data_ingestion = DataIngestion()
retrieval_module = RetrievalModule()

# Memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
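# The buffer memory stores the full message history and exposes it to prompts
# as {chat_history}; it is shared by process_query() and chat() below.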

# Database Setup
conn = sqlite3.connect("research_data.db")
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS papers (
        query TEXT,
        retrieved_papers TEXT,
        summary TEXT,
        evaluation TEXT
    )
    """
)
conn.commit()

# Tool functions (plain Python functions rather than LangChain agent tools)
def retrieve_relevant_papers(topic: str) -> str:
    """Fetch and retrieve relevant papers."""
    titles, abstracts = data_ingestion.fetch_papers(topic)
    if not abstracts:
        logger.warning(f"No papers retrieved for topic: {topic}")
        return "Could not retrieve papers."
    retrieval_module.build_vector_store(abstracts)
    relevant_sections = retrieval_module.retrieve_relevant(topic)
    logger.info(f"Retrieved {len(relevant_sections)} relevant papers for {topic}")
    return "\n".join(relevant_sections)

def summarize_text(text: str) -> str:
    """Summarize text using DistilBART."""
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device="mps")
    text = text[:500]  # Truncate to 500 characters to keep summarization fast
    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    logger.info("Generated summary for retrieved papers")
    return summary
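
# Note: pipeline() reloads DistilBART on every summarize_text() call; hoisting
# the summarizer to module level (like llm above) would keep it warm across queries.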

def evaluate_summary(summary: str) -> str:
    """Evaluate summary quality with LLaMA."""
    prompt = f"Evaluate this summary for accuracy, completeness, and clarity: {summary[:200]}"
    evaluation = llm.invoke(prompt)  # invoke() is the current LangChain call interface
    logger.info("Evaluated summary quality")
    return evaluation

# Simple conversational chain; retrieval is handled explicitly in process_query()
class ResearchAssistant:
    def __init__(self):
        self.prompt = PromptTemplate(
            input_variables=["chat_history", "query"],
            template="You are a research assistant. Based on the chat history and query, provide a helpful response.\n\nChat History: {chat_history}\nQuery: {query}\n\nResponse: "
        )
        self.chain = LLMChain(llm=llm, prompt=self.prompt, memory=memory)
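        # With memory attached, LLMChain fills {chat_history} automatically on
        # each run and appends the new exchange back into the buffer.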

    def process_query(self, query: str) -> tuple:
        """Process a query through a fixed retrieve-summarize-evaluate pipeline with retries (no ReAct agent loop)."""
        retries = 0
        max_retries = 3

        while retries < max_retries:
            try:
                # Step 1: Retrieve papers
                retrieved_papers = retrieve_relevant_papers(query)
                if "Could not retrieve papers" in retrieved_papers:
                    query = f"more detailed {query}"
                    retries += 1
                    time.sleep(2)
                    continue

                # Step 2: Summarize
                summary = summarize_text(retrieved_papers)
                if len(summary.split()) < 10:  # Retry if the summary came back suspiciously short
                    retries += 1
                    time.sleep(2)
                    continue

                # Step 3: Evaluate
                evaluation = evaluate_summary(summary)

                # Save to memory and DB
                memory.save_context(
                    {"input": query},
                    {"output": f"Summary: {summary}\nEvaluation: {evaluation}\nAsk me anything about these findings!"}
                )
                cursor.execute(
                    "INSERT INTO papers (query, retrieved_papers, summary, evaluation) VALUES (?, ?, ?, ?)",
                    (query, retrieved_papers, summary, evaluation)
                )
                conn.commit()
                return summary, evaluation

            except Exception as e:
                logger.error(f"Error in processing: {str(e)}")
                retries += 1
                time.sleep(2)

        logger.error("Max retries reached—task failed.")
        return "Failed after retries.", "N/A"

    def chat(self, user_input: str) -> str:
        """Handle follow-up chats."""
        if not memory.chat_memory.messages:
            return "Please start with a research query like 'large language model memory optimization'."
        return self.chain.run(query=user_input)

if __name__ == "__main__":
    assistant = ResearchAssistant()
    query = "large language model memory optimization"
    summary, evaluation = assistant.process_query(query)
    print("Summary:", summary)
    print("Evaluation:", evaluation)
    # Test follow-up
    follow_up = "Tell me more about memory optimization."
    print("Follow-up:", assistant.chat(follow_up))
    conn.close()