# NOTE: removed Hugging Face Spaces page residue ("Spaces: / Sleeping / Sleeping")
# that was captured along with this file and is not valid Python.
import os | |
from openai import AsyncOpenAI | |
import chainlit as cl | |
from dotenv import load_dotenv | |
from swarms import Agent | |
from swarms.utils.function_caller_model import OpenAIFunctionCaller | |
from pydantic import BaseModel, Field | |
from swarms.structs.conversation import Conversation | |
# Import prompts | |
from prompts import ( | |
MASTER_AGENT_SYS_PROMPT, | |
SUPERVISOR_AGENT_SYS_PROMPT, | |
COUNSELOR_AGENT_SYS_PROMPT, | |
BUDDY_AGENT_SYS_PROMPT | |
) | |
# RAG imports | |
from langchain_community.document_loaders import PyMuPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_openai.embeddings import OpenAIEmbeddings | |
from langchain_community.vectorstores import Qdrant | |
import tiktoken | |
import logging | |
# Load environment variables from a local .env file (API keys, etc.).
load_dotenv()

# Configure root logging once, at import time.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
class CallLog(BaseModel):
    """Structured routing decision produced by the master (router) agent."""

    # Which specialist agent should take the task.
    agent_name: str = Field(
        description="The name of the agent to call: either Counselor-Agent or Buddy-Agent"
    )
    # What the selected agent is being asked to do.
    task: str = Field(description="The task for the selected agent to handle")
# Initialize RAG Components
async def initialize_rag(file_id: str = "1Co1QBoPlWUfSShlS8Evw8e6t1PHtAPGT"):
    """Load the knowledge-base PDF, chunk it, embed it, and return a retriever.

    Args:
        file_id: Google Drive file id of the source PDF. Defaults to the
            original hard-coded knowledge base, so existing callers are
            unaffected.

    Returns:
        A LangChain retriever backed by an in-memory Qdrant collection.

    Raises:
        Exception: anything raised while loading, splitting, or embedding is
            logged and re-raised.
    """
    try:
        logging.info("Starting initialize_rag")
        direct_url = f"https://drive.google.com/uc?export=download&id={file_id}"
        logging.info("Attempting to load document from: %s", direct_url)
        docs = PyMuPDFLoader(direct_url).load()
        logging.info("Successfully loaded document, got %d pages", len(docs))

        # Hoist the tokenizer: encoding_for_model() is expensive and the
        # original rebuilt it inside the lambda for every length measurement.
        encoding = tiktoken.encoding_for_model("gpt-4")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=0,
            # Chunk size is measured in gpt-4 tokens, not characters.
            length_function=lambda text: len(encoding.encode(text)),
        )
        split_chunks = text_splitter.split_documents(docs)
        logging.info("Split into %d chunks", len(split_chunks))

        embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
        qdrant_vectorstore = Qdrant.from_documents(
            split_chunks,
            embedding_model,
            location=":memory:",  # ephemeral store; rebuilt on every startup
            collection_name="knowledge_base",
        )
        return qdrant_vectorstore.as_retriever()
    except Exception as e:
        logging.error("Error in initialize_rag: %s", e)
        # Bare raise preserves the original traceback (vs. `raise e`).
        raise
class EnhancedAgent(Agent):
    """A swarms ``Agent`` that prepends retrieved knowledge-base context to each task."""

    def __init__(self, agent_name, system_prompt, retriever):
        super().__init__(
            agent_name=agent_name,
            system_prompt=system_prompt,
            max_loops=1,
            model_name="gpt-4",
        )
        # Retriever used to pull supporting passages for every incoming query.
        self.retriever = retriever

    async def process_with_context(self, task: str) -> str:
        """Fetch RAG context for *task*, then run the agent on the enriched prompt.

        On failure the error text is returned as the response rather than raised,
        so the chat loop keeps going.
        """
        try:
            # Get context from RAG using ainvoke instead of aget_relevant_documents
            retrieved = await self.retriever.ainvoke(task)
            knowledge = "\n".join(doc.page_content for doc in retrieved)
            enhanced_task = f"""
            Context from knowledge base:
            {knowledge}
            User query:
            {task}
            """
            # Critical change: Don't await self.run
            return self.run(task=enhanced_task)
        except Exception as e:
            logging.error(f"Error in process_with_context: {str(e)}")
            return str(e)  # Return error message instead of raising
class SwarmWithRAG:
    """Routes each user message to a specialist agent chosen by a master router."""

    def __init__(self, retriever):
        # Function-calling router: returns a CallLog naming the specialist.
        self.master_agent = OpenAIFunctionCaller(
            base_model=CallLog,
            system_prompt=MASTER_AGENT_SYS_PROMPT,
        )
        self.counselor_agent = EnhancedAgent(
            agent_name="Counselor-Agent",
            system_prompt=COUNSELOR_AGENT_SYS_PROMPT,
            retriever=retriever,
        )
        self.buddy_agent = EnhancedAgent(
            agent_name="Buddy-Agent",
            system_prompt=BUDDY_AGENT_SYS_PROMPT,
            retriever=retriever,
        )
        # Lookup table keyed by the names the master agent is allowed to emit.
        self.agents = {
            "Counselor-Agent": self.counselor_agent,
            "Buddy-Agent": self.buddy_agent,
        }

    async def process(self, message: str) -> str:
        """Route *message* to the selected agent and return its reply.

        Errors are folded into the returned string instead of raised.
        """
        try:
            # Get agent selection
            routing = self.master_agent.run(message)
            selected = self.agents.get(routing.agent_name)
            if not selected:
                return f"No agent found for {routing.agent_name}"
            # Process with context - note we're awaiting here
            return await selected.process_with_context(routing.task)
        except Exception as e:
            logging.error(f"Error in process: {str(e)}")
            return f"Error processing your request: {str(e)}"
@cl.on_chat_start
async def start_chat():
    """Chainlit session-start hook: build the RAG retriever and swarm.

    Fix: without the ``@cl.on_chat_start`` decorator this coroutine was never
    registered with Chainlit, so the swarm was never stored in the session.
    Stores the swarm under the ``"swarm"`` session key and greets the user;
    on failure it reports the initialization error instead of raising.
    """
    try:
        retriever = await initialize_rag()
        swarm = SwarmWithRAG(retriever)
        cl.user_session.set("swarm", swarm)
        await cl.Message(
            content="Hi there! I’m SARASWATI, your AI guidance system. Whether you’re exploring career paths or looking for ways to improve your mental well-being, I’m here to support you. How can I assist you today?"
        ).send()
    except Exception as e:
        error_msg = f"Error initializing chat: {str(e)}"
        logging.error(error_msg)
        await cl.Message(
            content=f"System initialization error: {error_msg}"
        ).send()
@cl.on_message
async def main(message: cl.Message):
    """Chainlit message hook: forward the user's text to the swarm and reply.

    Fix: without the ``@cl.on_message`` decorator this handler was never
    registered, so incoming messages were ignored. Also guards against a
    missing session swarm (e.g. when initialization failed) instead of
    relying on an AttributeError.
    """
    try:
        swarm = cl.user_session.get("swarm")
        if swarm is None:
            # Initialization never ran or failed; fail gracefully.
            await cl.Message(
                content="I apologize, but I encountered an error. Please try again."
            ).send()
            return
        response = await swarm.process(message.content)
        await cl.Message(content=response).send()
    except Exception as e:
        logging.error(f"Error in message processing: {str(e)}")
        await cl.Message(
            content="I apologize, but I encountered an error. Please try again."
        ).send()
if __name__ == "__main__":
    # Chainlit apps are normally launched from the CLI:
    #   chainlit run <this_file>.py
    # NOTE(review): `cl.run()` does not appear to be part of Chainlit's public
    # API — confirm; running this file directly will likely raise
    # AttributeError. Kept as-is pending that check.
    cl.run()