import chainlit as cl
from langchain_anthropic import ChatAnthropic
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from pathlib import Path
import os
from huggingface_hub import HfApi, hf_hub_download
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

# Module-level cache for the vectorstore instance
_vectorstore = None

# Configurable model selection
MODELS = {
    "gpt-3.5-turbo": ChatOpenAI(model="gpt-3.5-turbo"),
    "gpt-4": ChatOpenAI(model="gpt-4"),
    # Anthropic expects a dated model ID; a bare "claude-3-sonnet" is rejected by the API
    "claude-3-sonnet": ChatAnthropic(model="claude-3-sonnet-20240229"),
}
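
# Note: the models above are constructed at import time, so OPENAI_API_KEY and
# ANTHROPIC_API_KEY should be set in the environment before this module is
# imported; both clients read their keys from the environment by default.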

# Custom prompt for vegan ethics
VEGAN_ETHICS_TEMPLATE = """
You are an ethical philosopher specializing in animal rights and veganism.
Use the following pieces of context to answer the user's question.
Always guide your response toward applying the "Name the Trait" argument:
1. Ask what trait in non-human animals, if present in humans, would justify treating humans the same way
2. Examine if this trait is morally relevant or arbitrary
3. Challenge the user to maintain logical consistency in their ethical framework
4. Use scientific evidence from the context to support your points
Context: {context}
Question: {question}
Answer:
"""

def get_vectorstore(persist_dir: str = "vector_store"):
    """Create or return the cached vectorstore instance."""
    global _vectorstore
    if _vectorstore is not None:
        return _vectorstore
    # Initialize vector store with persistence
    persist_dir = Path(persist_dir)
    client = QdrantClient(
        path=str(persist_dir),
        force_disable_check_same_thread=True  # Important for concurrent access
    )
    # Check if collection exists
    collections = client.get_collections().collections
    collection_names = [c.name for c in collections]
    if "vegan_ethics" not in collection_names:
        print(f"Creating new vector store in {persist_dir}")
        client.create_collection(
            collection_name="vegan_ethics",
            vectors_config=VectorParams(
                size=1536,  # Matches the 1536-dimensional default OpenAI embedding model
                distance=Distance.COSINE,
            ),
        )
    _vectorstore = QdrantVectorStore(
        client=client,
        embedding=OpenAIEmbeddings(),
        collection_name="vegan_ethics"
    )
    return _vectorstore
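
# get_vectorstore() is shared by the FastAPI endpoint and the Chainlit handlers.
# The module-level cache matters here: a file-backed Qdrant instance locks its
# storage directory, so opening the same path from a second client in the same
# process would fail.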

async def process_and_load_documents(vectorstore, repo_id="Frikster42/name-that-trait", data_folder="data"):
    """Download source documents from the Hugging Face Hub, load them, and index them."""
    # Create a single TaskList for the entire process
    tasks = cl.TaskList()
    tasks.status = "Initializing..."
    await tasks.send()
    msg = cl.Message(content="Loading documents from Hugging Face repository... please be patient...")
    await msg.send()
    data_dir = Path(data_folder)
    data_dir.mkdir(exist_ok=True)
    # Get list of files in the repository
    api = HfApi()
    dataset_files = api.list_repo_files(repo_id, repo_type="dataset")
    dataset_pdf_files = [f for f in dataset_files if f.endswith('.pdf')]
    # Download phase
    tasks.status = "Downloading files..."
    await tasks.send()
    for pdf_file in dataset_pdf_files:
        task = cl.Task(title=f"Downloading {pdf_file}")
        await tasks.add_task(task)
        hf_hub_download(
            repo_id=repo_id,
            filename=pdf_file,
            local_dir=str(data_dir),
            repo_type="dataset"
        )
        task.status = cl.TaskStatus.DONE
        await tasks.send()
    # Loading phase: pick up PDFs (and any plain-text files) from the data folder.
    # The original filtered for .pdf only, which made its TextLoader branch unreachable.
    documents = []
    local_files = [f for f in os.listdir(data_folder) if f.endswith(('.pdf', '.txt'))]
    tasks.status = "Loading files..."
    await tasks.send()
    for filename in local_files:
        task = cl.Task(title=f"Loading {filename}")
        await tasks.add_task(task)
        filepath = os.path.join(data_folder, filename)
        if filename.endswith('.pdf'):
            loader = PyPDFLoader(filepath)
        else:
            loader = TextLoader(filepath)
        documents.extend(loader.load())
        task.status = cl.TaskStatus.DONE
        await tasks.send()
    # Split and process documents
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)
    if chunks:
        tasks.status = "Processing chunks..."
        await tasks.send()
        # Embed in batches of 100 so each embedding request stays bounded
        batch_size = 100
        num_batches = (len(chunks) + batch_size - 1) // batch_size
        for i in range(0, len(chunks), batch_size):
            task = cl.Task(title=f"Processing batch {(i//batch_size)+1}/{num_batches}")
            await tasks.add_task(task)
            batch = chunks[i:i + batch_size]
            vectorstore.add_documents(batch)
            task.status = cl.TaskStatus.DONE
            await tasks.send()
    tasks.status = "Completed"
    await tasks.send()
    msg = cl.Message(content="✅ Documents loaded successfully!")
    await msg.send()
    return vectorstore
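
# Note: hf_hub_download caches files, so re-running the download phase is cheap.
# add_documents, however, appends new vectors without deduplication, so reloading
# into a non-empty collection would duplicate chunks; start() below guards against
# this by only loading when the collection is empty.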

# Create FastAPI app
app = FastAPI(title="Vegan Ethics RAG API")

class QueryRequest(BaseModel):
    question: str
    model_name: Optional[str] = "gpt-3.5-turbo"

@app.post("/query")  # Route path not given in the original; "/query" is assumed
async def query_endpoint(request: QueryRequest):
    """REST entry point: retrieve context, then answer with the requested model."""
    try:
        # Validate model selection before doing any heavier work
        if request.model_name not in MODELS:
            raise HTTPException(status_code=400, detail=f"Invalid model name. Choose from: {list(MODELS.keys())}")
        # Get or create vectorstore instance
        vectorstore = get_vectorstore()
        # Create prompt template
        prompt = ChatPromptTemplate.from_messages([
            ("system", VEGAN_ETHICS_TEMPLATE),
            ("user", "{question}")
        ])
        # Get relevant documents
        docs = vectorstore.similarity_search(request.question, k=3)
        context = "\n".join(doc.page_content for doc in docs)
        # Generate response
        chain = prompt | MODELS[request.model_name]
        response = await chain.ainvoke({
            "context": context,
            "question": request.question
        })
        return {
            "response": response.content,
            "model_used": request.model_name
        }
    except HTTPException:
        # Re-raise as-is; the original's bare `except Exception` turned the
        # 400 validation error above into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
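
# Example request against the assumed /query route (server address is illustrative):
#   curl -X POST http://localhost:8000/query \
#     -H "Content-Type: application/json" \
#     -d '{"question": "Why is sentience morally relevant?", "model_name": "gpt-4"}'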

@cl.on_chat_start
async def start():
    """Chainlit entry point: prepare the vectorstore and session state for a new chat."""
    # Get or create vectorstore instance
    vectorstore = get_vectorstore()
    # Load documents if needed
    collection_info = vectorstore.client.get_collection("vegan_ethics")
    if collection_info.points_count == 0:
        await cl.Message(content="Vector store is empty, loading documents...").send()
        vectorstore = await process_and_load_documents(vectorstore)
    # Create prompt template
    prompt = ChatPromptTemplate.from_messages([
        ("system", VEGAN_ETHICS_TEMPLATE),
        ("user", "{question}")
    ])
    # Store components in session
    cl.user_session.set("vectorstore", vectorstore)
    cl.user_session.set("prompt", prompt)
    cl.user_session.set("model_name", "gpt-3.5-turbo")
    # UI for model selection
    actions = [
        cl.Action(
            name="model_select",
            label="Current Model: gpt-3.5-turbo",
            tooltip="Change the AI model",
            payload={"current_model": "gpt-3.5-turbo"}
        )
    ]
    await cl.Message(
        content="Welcome to the Vegan Ethics Assistant. Ask any question about veganism, ethics, or animal consumption.",
        actions=actions
    ).send()

@cl.action_callback("model_select")
async def on_action(action: cl.Action):
    """Cycle to the next model in MODELS and update the session."""
    models_list = list(MODELS.keys())
    current_index = models_list.index(action.payload["current_model"])
    next_index = (current_index + 1) % len(models_list)
    next_model_name = models_list[next_index]
    cl.user_session.set("model_name", next_model_name)
    actions = [
        cl.Action(
            name="model_select",
            label=f"Current Model: {next_model_name}",
            tooltip="Change the AI model",
            payload={"current_model": next_model_name}
        )
    ]
    await cl.Message(content=f"Model switched to {next_model_name}", actions=actions).send()

@cl.on_message
async def main(message: cl.Message):
    """Answer a chat message using retrieval-augmented generation."""
    vectorstore = cl.user_session.get("vectorstore")
    prompt = cl.user_session.get("prompt")
    model_name = cl.user_session.get("model_name")
    # Get relevant documents
    docs = vectorstore.similarity_search(message.content, k=3)
    context = "\n".join(doc.page_content for doc in docs)
    # Generate response
    chain = prompt | MODELS[model_name]
    response = await chain.ainvoke({
        "context": context,
        "question": message.content
    })
    await cl.Message(content=response.content).send()
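
# Serving note: `chainlit run app.py` starts the chat UI, while the REST API defined
# on `app` would need its own server (e.g. `uvicorn app:app`). To serve both from one
# process, Chainlit can be mounted onto the FastAPI app via chainlit.utils.mount_chainlit;
# the original does not show which setup is intended.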