File size: 5,987 Bytes
31b2366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32f62dd
31b2366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import asyncio
import logging
import os

from openai import AsyncOpenAI
import chainlit as cl
from dotenv import load_dotenv
from swarms import Agent
from swarms.utils.function_caller_model import OpenAIFunctionCaller
from pydantic import BaseModel, Field
from swarms.structs.conversation import Conversation

# RAG imports
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
import tiktoken

# Import prompts
from prompts import (
    MASTER_AGENT_SYS_PROMPT,
    SUPERVISOR_AGENT_SYS_PROMPT,
    COUNSELOR_AGENT_SYS_PROMPT,
    BUDDY_AGENT_SYS_PROMPT
)

# Pull variables from a local .env file into os.environ
# (presumably the OpenAI API key used by the embedding/LLM clients — confirm).
load_dotenv()

# Configure the root logger once: every logging.* call in this module
# writes through it at INFO level with timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Structured routing decision produced by the master agent
# (passed as base_model=CallLog to OpenAIFunctionCaller in SwarmWithRAG).
# NOTE: documented with comments on purpose — pydantic copies a class
# docstring into the model's JSON-schema "description", which would alter
# the schema handed to the function caller.
class CallLog(BaseModel):
    # Looked up in SwarmWithRAG.agents, so it must equal one of its keys
    # ("Counselor-Agent" / "Buddy-Agent"). The Field description is part of the
    # schema and presumably steers the LLM's choice — edit with care.
    agent_name: str = Field(description="The name of the agent to call: either Counselor-Agent or Buddy-Agent")
    # Text forwarded to the chosen agent's process_with_context().
    task: str = Field(description="The task for the selected agent to handle")

# Initialize RAG Components
def _build_retriever(direct_url: str, chunk_size: int):
    """Synchronously load, chunk and embed the PDF at *direct_url*.

    This does blocking network and CPU work (download, parsing, embedding
    requests), so callers on an event loop should run it in a worker thread.

    Returns:
        A retriever over an in-memory Qdrant collection named "knowledge_base".
    """
    docs = PyMuPDFLoader(direct_url).load()
    logging.info(f"Successfully loaded document, got {len(docs)} pages")

    # Resolve the tokenizer once rather than on every chunk-length measurement.
    encoding = tiktoken.encoding_for_model("gpt-4")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=0,
        length_function=lambda text: len(encoding.encode(text)),
    )
    split_chunks = text_splitter.split_documents(docs)
    logging.info(f"Split into {len(split_chunks)} chunks")

    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
    qdrant_vectorstore = Qdrant.from_documents(
        split_chunks,
        embedding_model,
        location=":memory:",
        collection_name="knowledge_base",
    )
    return qdrant_vectorstore.as_retriever()


async def initialize_rag(
    file_id: str = "1Co1QBoPlWUfSShlS8Evw8e6t1PHtAPGT",
    chunk_size: int = 200,
):
    """Build a knowledge-base retriever from a PDF hosted on Google Drive.

    The document is fetched through Drive's direct-download URL, loaded with
    PyMuPDF and split into chunks of at most ``chunk_size`` tokens. Tokens are
    counted with the GPT-4 tiktoken encoding. The chunks are embedded with
    ``text-embedding-3-small`` and stored in an in-memory Qdrant collection.
    All of that work is blocking, so it runs via ``asyncio.to_thread``. This
    keeps the event loop (and every other chat session) responsive while the
    index is built.

    Args:
        file_id: Google Drive file id of the knowledge-base PDF.
        chunk_size: Maximum chunk length, in tokens.

    Returns:
        A LangChain retriever over the indexed chunks.

    Raises:
        Exception: any download/parse/embedding error is logged with its
            traceback and re-raised unchanged.
    """
    try:
        logging.info("Starting initialize_rag")
        direct_url = f"https://drive.google.com/uc?export=download&id={file_id}"

        logging.info(f"Attempting to load document from: {direct_url}")
        return await asyncio.to_thread(_build_retriever, direct_url, chunk_size)

    except Exception as e:
        logging.exception(f"Error in initialize_rag: {str(e)}")
        raise

class EnhancedAgent(Agent):
    """A swarms ``Agent`` that prefixes each task with retrieved knowledge-base context.

    Args:
        agent_name: Passed through to ``Agent``. SwarmWithRAG also uses it as
            the routing key for this agent.
        system_prompt: Passed through to ``Agent``.
        retriever: Object with an async ``ainvoke(query)`` that returns
            documents exposing ``page_content``.
        model_name: LLM model name for the underlying ``Agent``.
        max_loops: ``max_loops`` setting for the underlying ``Agent``.
    """

    def __init__(self, agent_name, system_prompt, retriever, model_name="gpt-4", max_loops=1):
        super().__init__(
            agent_name=agent_name,
            system_prompt=system_prompt,
            max_loops=max_loops,
            model_name=model_name,
        )
        self.retriever = retriever

    async def process_with_context(self, task: str) -> str:
        """Answer ``task`` after prepending context retrieved for it.

        Returns:
            The agent's reply. If retrieval or the agent call fails, the error
            is logged with its traceback and its text is returned instead of
            raised (callers display whatever string comes back).
        """
        try:
            context_docs = await self.retriever.ainvoke(task)
            context = "\n".join(doc.page_content for doc in context_docs)

            enhanced_task = f"""
            Context from knowledge base:
            {context}
            
            User query:
            {task}
            """

            # Agent.run is a synchronous call (it must not be awaited) and blocks
            # for the whole LLM round-trip; run it in a worker thread so the
            # event loop keeps serving other sessions meanwhile.
            return await asyncio.to_thread(self.run, task=enhanced_task)

        except Exception as e:
            logging.exception(f"Error in process_with_context: {str(e)}")
            return str(e)  # Return error message instead of raising

class SwarmWithRAG:
    """Route each user message to one RAG-enabled sub-agent.

    A master ``OpenAIFunctionCaller`` produces a ``CallLog`` that names the
    sub-agent and the task to give it. The chosen ``EnhancedAgent`` then
    answers that task with knowledge-base context.
    """

    def __init__(self, retriever):
        self.master_agent = OpenAIFunctionCaller(
            base_model=CallLog,
            system_prompt=MASTER_AGENT_SYS_PROMPT,
        )

        self.counselor_agent = EnhancedAgent(
            agent_name="Counselor-Agent",
            system_prompt=COUNSELOR_AGENT_SYS_PROMPT,
            retriever=retriever,
        )

        self.buddy_agent = EnhancedAgent(
            agent_name="Buddy-Agent",
            system_prompt=BUDDY_AGENT_SYS_PROMPT,
            retriever=retriever,
        )

        # Keys must match the names the master agent can return in CallLog.agent_name.
        self.agents = {
            "Counselor-Agent": self.counselor_agent,
            "Buddy-Agent": self.buddy_agent,
        }

    async def process(self, message: str) -> str:
        """Pick a sub-agent for ``message`` and return its answer.

        Returns:
            The sub-agent's reply; ``"No agent found for <name>"`` when the
            master names an unknown agent; an ``"Error processing your
            request: ..."`` message when no routing decision is produced or
            any step fails. Errors are logged, never raised.
        """
        try:
            # The master's run() is synchronous (network round-trip); keep it
            # off the event loop.
            function_call = await asyncio.to_thread(self.master_agent.run, message)

            if function_call is None:
                # NOTE(review): the function caller appears able to return None
                # instead of raising (e.g. on a failed or refused completion) —
                # confirm against the installed swarms version.
                logging.warning("Master agent returned no routing decision")
                return "Error processing your request: no agent could be selected for this message."

            agent = self.agents.get(function_call.agent_name)
            if agent is None:
                return f"No agent found for {function_call.agent_name}"

            return await agent.process_with_context(function_call.task)

        except Exception as e:
            logging.exception(f"Error in process: {str(e)}")
            return f"Error processing your request: {str(e)}"
        
@cl.on_chat_start
async def start_chat():
    """Prepare a RAG-backed swarm for a new chat session and greet the user.

    On success the swarm is stored in the Chainlit user session under
    ``"swarm"``, where the message handler reads it back. On any failure
    nothing is stored; the error is logged and echoed to the user rather
    than raised.
    """
    welcome_text = (
        "Hi there! I’m SARASWATI, your AI guidance system. Whether you’re "
        "exploring career paths or looking for ways to improve your mental "
        "well-being, I’m here to support you. How can I assist you today?"
    )

    try:
        session_swarm = SwarmWithRAG(await initialize_rag())
        cl.user_session.set("swarm", session_swarm)
        await cl.Message(content=welcome_text).send()
    except Exception as exc:
        detail = f"Error initializing chat: {exc!s}"
        logging.error(detail)
        await cl.Message(content=f"System initialization error: {detail}").send()

@cl.on_message
async def main(message: cl.Message):
    """Answer one incoming chat message using this session's swarm.

    The swarm's reply is sent back as a new message. If the session has no
    swarm, the user is told to start a new chat. That happens when
    ``start_chat`` failed, because it only stores the swarm after a
    successful setup, so retrying in this session cannot succeed. Any other
    error is logged with its traceback and answered with a generic apology.
    """
    try:
        swarm = cl.user_session.get("swarm")
        if swarm is None:
            logging.warning("No swarm in user session; chat initialization likely failed")
            await cl.Message(
                content="The assistant could not be initialized for this chat. Please start a new chat session."
            ).send()
            return

        response = await swarm.process(message.content)
        await cl.Message(content=response).send()

    except Exception as e:
        logging.exception(f"Error in message processing: {str(e)}")
        await cl.Message(
            content="I apologize, but I encountered an error. Please try again."
        ).send()

if __name__ == "__main__":
    # The chainlit package exposes no top-level `run` (so `cl.run()` raises
    # AttributeError). Chainlit apps are normally started with
    # `chainlit run <file>`; to also support `python <file>`, hand this file to
    # Chainlit's CLI runner.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)