"""RAG helper for retrieving similar CSS Zen Garden designs.

Builds a FAISS vector store from local design metadata
(data/designs/**/metadata.json) and uses an LLM-generated search query to find
designs that match a conversation.
"""

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.smith import RunEvalConfig, run_on_dataset
import os
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from pathlib import Path
import json
from typing import Dict, List, Optional
from langchain_core.documents import Document
from langchain.callbacks.tracers import ConsoleCallbackHandler


class DesignRAG:
    def __init__(self):
        # Get API key from environment
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OPENAI_API_KEY environment variable not set. "
                "Please set it in HuggingFace Spaces settings."
            )

        # Initialize embedding model with explicit API key
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)

        # Load design data and create vector store
        self.vector_store = self._create_vector_store()

        # Create retriever with tracing tags
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 1},
            tags=["design_retriever"],
        )

        # Create LLM with tracing tags
        self.llm = ChatOpenAI(
            temperature=0.2,
            tags=["design_llm"],
        )

    def _create_vector_store(self) -> FAISS:
        """Create FAISS vector store from design metadata"""
        try:
            # Design metadata lives in <repo root>/data/designs
            designs_dir = Path(__file__).parent.parent / "data" / "designs"
            documents = []

            # Load all metadata files
            for metadata_path in designs_dir.glob("**/metadata.json"):
                try:
                    with open(metadata_path, "r") as f:
                        metadata = json.load(f)

                    # Create document text from metadata with safe gets
                    text = f"""
                    Design {metadata.get('id', 'unknown')}:
                    Description: {metadata.get('description', 'No description available')}
                    Categories: {', '.join(metadata.get('categories', []))}
                    Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
                    """

                    # Load associated CSS (currently disabled):
                    # css_path = metadata_path.parent / "style.css"
                    # if css_path.exists():
                    #     with open(css_path, "r") as f:
                    #         css = f.read()
                    #     text += f"\nCSS:\n{css}"

                    # Create Document object with minimal metadata
                    documents.append(
                        Document(
                            page_content=text.strip(),
                            metadata={
                                "id": metadata.get('id', 'unknown'),
                                "path": str(metadata_path.parent),
                            },
                        )
                    )
                except Exception as e:
                    print(f"Error processing design {metadata_path}: {e}")
                    continue

            if not documents:
                print("Warning: No valid design documents found")
                # Create an empty vector store with a placeholder document
                return FAISS.from_documents(
                    [Document(page_content="No designs available", metadata={"id": "placeholder"})],
                    self.embeddings,
                )

            print(f"Loaded {len(documents)} design documents")

            # Create and return vector store
            return FAISS.from_documents(documents, self.embeddings)

        except Exception as e:
            print(f"Error creating vector store: {str(e)}")
            raise

    async def query_similar_designs(
        self, conversation_history: List[str], num_examples: int = 1
    ) -> str:
        """Find similar designs based on conversation history"""
        from langsmith import Client

        # Create LangSmith client
        client = Client()

        # Create query generation prompt
        query_prompt = ChatPromptTemplate.from_template(
            """Based on this conversation history:
            {conversation}

            Extract the key design requirements and create a search query to find similar designs.
            Focus on:
            1. Visual style and aesthetics mentioned
            2. Design categories and themes discussed
            3. Key visual characteristics requested
            4. Overall mood and impact desired
            5. Any specific preferences or constraints

            Return only the search query text, no additional explanation or analysis."""
        )

        # Format conversation history, alternating user and assistant turns
        conversation_text = "\n".join(
            f"{'User' if i % 2 == 0 else 'Assistant'}: {msg}"
            for i, msg in enumerate(conversation_history)
        )

        # Generate an optimized search query with tracing
        query_chain = (query_prompt | self.llm).with_config(tags=["query_generation"])
        query_response = await query_chain.ainvoke({"conversation": conversation_text})
        print(f"Generated query: {query_response.content}")

        # Retrieve relevant documents with console tracing. The retriever's k is
        # fixed at construction time (k=1 above), so num_examples only caps the
        # number of results formatted below.
        docs = await self.retriever.ainvoke(
            query_response.content,
            config={"callbacks": [ConsoleCallbackHandler()]},
        )
        docs = docs[:num_examples]

        # Format examples, one block per design, each ending with its URL
        examples = []
        for doc in docs:
            design_id = doc.metadata.get("id", "unknown")
            content_lines = doc.page_content.strip().split("\n")
            examples.append(
                "\n".join(line.strip() for line in content_lines if line.strip())
                + f"\nURL: https://csszengarden.com/{design_id}"
            )

        return "\n\n".join(examples)
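

# --- Usage sketch (not part of the class) ----------------------------------
# A minimal example of how this module might be exercised locally. It assumes
# OPENAI_API_KEY is set and that data/designs/ contains metadata.json files;
# the sample conversation below is purely illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        rag = DesignRAG()
        history = [
            "I want a minimalist design with lots of whitespace.",
            "Got it - any colour preferences?",
            "Muted earth tones, and a serif typeface for headings.",
        ]
        examples = await rag.query_similar_designs(history, num_examples=1)
        print(examples)

    asyncio.run(_demo())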