# Scraped Hugging Face Spaces page header -- Spaces status at capture: "Runtime error".
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from langchain.smith import RunEvalConfig, run_on_dataset | |
import os | |
from langchain_community.vectorstores import FAISS | |
from langchain.prompts import ChatPromptTemplate | |
from pathlib import Path | |
import json | |
from typing import Dict, List, Optional | |
from langchain_core.documents import Document | |
from langchain.callbacks.tracers import ConsoleCallbackHandler | |
class DesignRAG:
    """Retrieve stored design descriptions that match a design conversation.

    On construction the class embeds every ``metadata.json`` found under
    ``<parent of this file's directory>/data/designs`` into a FAISS vector
    store. ``query_similar_designs`` then asks the chat model to condense a
    conversation into a search query and returns the closest designs as text.
    """

    # Number of hits returned by the default retriever (``self.retriever``).
    _DEFAULT_K = 1

    # Prompt used to turn a conversation transcript into a search query.
    _QUERY_PROMPT = (
        "Based on this conversation history:\n"
        "{conversation}\n"
        "Extract the key design requirements and create a search query to find similar designs.\n"
        "Focus on:\n"
        "1. Visual style and aesthetics mentioned\n"
        "2. Design categories and themes discussed\n"
        "3. Key visual characteristics requested\n"
        "4. Overall mood and impact desired\n"
        "5. Any specific preferences or constraints\n"
        "Return only the search query text, no additional explanation or analysis."
    )

    def __init__(self):
        """Set up embeddings, the vector store, the retriever and the chat model.

        Raises:
            ValueError: if the ``OPENAI_API_KEY`` environment variable is
                missing or empty.
        """
        # Get API keys from environment
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OPENAI_API_KEY environment variable not set. "
                "Please set it in HuggingFace Spaces settings."
            )

        # Initialize embedding model with explicit API key
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)

        # Load design data and create vector store
        self.vector_store = self._create_vector_store()

        # Create retriever with tracing
        self.retriever = self._build_retriever(self._DEFAULT_K)

        # Create LLM with tracing
        self.llm = ChatOpenAI(
            temperature=0.2,
            tags=["design_llm"],  # Add tags for tracing
        )

    def _build_retriever(self, k: int):
        """Return a similarity retriever over the vector store yielding ``k`` hits."""
        return self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": k},
            tags=["design_retriever"],  # Add tags for tracing
        )

    @staticmethod
    def _metadata_to_text(metadata: Dict) -> str:
        """Render one design's metadata dict as the text that gets embedded.

        Missing keys fall back to the same defaults for every design
        (``'unknown'``, ``'No description available'``, empty lists).
        """
        lines = [
            f"Design {metadata.get('id', 'unknown')}:",
            f"Description: {metadata.get('description', 'No description available')}",
            f"Categories: {', '.join(metadata.get('categories', []))}",
            f"Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}",
        ]
        return "\n".join(lines).strip()

    @staticmethod
    def _format_conversation(conversation_history: List[str]) -> str:
        """Label messages alternately as User/Assistant (index 0 is the user)."""
        return "\n".join(
            f"{'User' if i % 2 == 0 else 'Assistant'}: {msg}"
            for i, msg in enumerate(conversation_history)
        )

    @staticmethod
    def _format_example(page_content: str, design_id) -> str:
        """Strip each line, drop blank lines and append the design's URL."""
        stripped = (line.strip() for line in page_content.strip().split("\n"))
        body = "\n".join(line for line in stripped if line)
        return f"{body}\nURL: https://csszengarden.com/{design_id}"

    def _create_vector_store(self) -> "FAISS":
        """Create FAISS vector store from design metadata.

        Files that cannot be read or parsed are reported and skipped. When no
        usable document is found, a store holding a single placeholder
        document is returned instead. Any other failure is printed and
        re-raised.
        """
        try:
            # Update path to look in data/designs
            designs_dir = Path(__file__).parent.parent / "data" / "designs"
            documents = []

            # Load all metadata files; sorted so the index order is deterministic.
            for metadata_path in sorted(designs_dir.glob("**/metadata.json")):
                try:
                    with open(metadata_path, "r", encoding="utf-8") as f:
                        metadata = json.load(f)

                    # NOTE: the sibling style.css is currently not included in
                    # the embedded text; only the metadata fields are.
                    documents.append(
                        Document(
                            page_content=self._metadata_to_text(metadata),
                            # Minimal metadata: id plus the design's directory.
                            metadata={
                                "id": metadata.get("id", "unknown"),
                                "path": str(metadata_path.parent),
                            },
                        )
                    )
                except Exception as e:
                    # Best effort: one broken design must not abort the load.
                    print(f"Error processing design {metadata_path}: {e}")
                    continue

            if not documents:
                print("Warning: No valid design documents found")
                # Create empty vector store with a placeholder document
                return FAISS.from_documents(
                    [Document(page_content="No designs available", metadata={"id": "placeholder"})],
                    self.embeddings,
                )

            print(f"Loaded {len(documents)} design documents")
            # Create and return vector store
            return FAISS.from_documents(documents, self.embeddings)
        except Exception as e:
            print(f"Error creating vector store: {str(e)}")
            raise

    async def query_similar_designs(self, conversation_history: List[str], num_examples: int = 1) -> str:
        """Find similar designs based on conversation history.

        Args:
            conversation_history: Messages in order; even indices are treated
                as user turns and odd indices as assistant turns.
            num_examples: How many designs to retrieve. Values below 1 are
                treated as 1.

        Returns:
            One text block per retrieved design (its non-blank content lines
            followed by a ``URL:`` line), separated by blank lines.
        """
        # Pipe the prompt into the model so the tag is attached to a real run
        # and the model receives a chat message rather than a flattened string.
        query_chain = (
            ChatPromptTemplate.from_template(self._QUERY_PROMPT) | self.llm
        ).with_config(tags=["query_generation"])

        # Generate optimized search query with tracing
        query_response = await query_chain.ainvoke(
            {"conversation": self._format_conversation(conversation_history)}
        )
        search_query = query_response.content
        print(f"Generated query: {search_query}")

        # The default retriever is fixed at _DEFAULT_K hits; build a dedicated
        # one when a different count is requested so num_examples is honoured.
        k = max(1, num_examples)
        retriever = self.retriever if k == self._DEFAULT_K else self._build_retriever(k)

        # Get relevant documents with tracing (async, so the event loop is not blocked)
        docs = await retriever.ainvoke(
            search_query,
            config={"callbacks": [ConsoleCallbackHandler()]},
        )

        # Format examples
        return "\n\n".join(
            self._format_example(doc.page_content, doc.metadata.get("id", "unknown"))
            for doc in docs
        )