Technologic101 committed
Commit 4180985 · Parent: 91a124e

task: [wip] set up graph and node structure

src/app.py CHANGED
@@ -1,7 +1,7 @@
 import chainlit as cl
 from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage, SystemMessage
-from chains.design_rag import DesignRAG
+from nodes.design_rag import DesignRAG
 
 # Initialize components
 design_rag = DesignRAG()
@@ -14,8 +14,7 @@ For every user message, analyze their design preferences and requirements, consi
 3. Layout and structural needs
 4. Key visual elements
 5. Intended audience and user experience
-
-First briefly explain how you understand their requirements, then show the closest match."""
+"""
 
 @cl.on_chat_start
 async def init():
@@ -28,7 +27,7 @@ async def init():
     )
 
     # Store the LLM in the user session
-    cl.user_session.set("llm", llm)
+    cl.user_session.set("design_llm", llm)
 
     # init conversation history for each user
     cl.user_session.set("conversation_history", [
@@ -41,9 +40,9 @@ async def init():
 @cl.on_message
 async def main(message: cl.Message):
     # Get the LLM from the user session
-    llm = cl.user_session.get("llm")
-
+    llm = cl.user_session.get("design_llm")
     conversation_history = cl.user_session.get("conversation_history")
+
     # Add user message to history
     conversation_history.append(HumanMessage(content=message.content))
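The handler above only stores the user message; the session LLM is fetched but not yet invoked in this hunk. A minimal sketch of how `main()` could continue, assuming `llm` is the ChatOpenAI instance created in `init()` (this reply step is an illustration, not part of the commit):

# Hypothetical continuation of main(); `llm` and `conversation_history`
# are the session values fetched above.
response = await llm.ainvoke(conversation_history)   # run the chat model on the full history
conversation_history.append(response)                # keep the AI reply in the session history
cl.user_session.set("conversation_history", conversation_history)
await cl.Message(content=response.content).send()    # send the answer back to the Chainlit UI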
 
src/graph.py ADDED
@@ -0,0 +1,18 @@
+from typing import Annotated
+
+from typing_extensions import TypedDict
+
+from langgraph.graph import StateGraph, START, END
+from langgraph.graph.message import add_messages
+
+
+class State(TypedDict):
+    # Messages have the type "list". The `add_messages` function
+    # in the annotation defines how this state key should be updated
+    # (in this case, it appends messages to the list, rather than overwriting them)
+    messages: Annotated[list, add_messages]
+
+
+graph_builder = StateGraph(State)
+
+
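graph.py stops at the builder: State carries a `messages` list whose `add_messages` reducer appends new messages rather than overwriting them, and `graph_builder` is created but no nodes or edges are registered yet, hence the [wip] tag. A rough sketch of the wiring this structure points toward, reusing the names from graph.py above; the `designer` node function here is hypothetical and not part of this commit:

# Hypothetical node and edges appended to graph.py, for illustration only.
from langchain_core.messages import AIMessage

def designer(state: State) -> dict:
    # Whatever is returned under "messages" is appended by add_messages,
    # so a node only emits its new message(s).
    return {"messages": [AIMessage(content="placeholder design suggestion")]}

graph_builder.add_node("designer", designer)
graph_builder.add_edge(START, "designer")
graph_builder.add_edge("designer", END)
graph = graph_builder.compile()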
src/nodes/analyzer.py ADDED
File without changes
src/nodes/design_rag.py ADDED
@@ -0,0 +1,160 @@
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain.smith import RunEvalConfig, run_on_dataset
+import os
+
+from langchain_community.vectorstores import FAISS
+from langchain.prompts import ChatPromptTemplate
+from pathlib import Path
+import json
+from typing import Dict, List, Optional
+from langchain_core.documents import Document
+from langchain.callbacks.tracers import ConsoleCallbackHandler
+
+class DesignRAG:
+    def __init__(self):
+        # Get API keys from environment
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "OPENAI_API_KEY environment variable not set. "
+                "Please set it in HuggingFace Spaces settings."
+            )
+
+        # Initialize embedding model with explicit API key
+        self.embeddings = OpenAIEmbeddings(
+            openai_api_key=api_key
+        )
+
+        # Load design data and create vector store
+        self.vector_store = self._create_vector_store()
+
+        # Create retriever with tracing
+        self.retriever = self.vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs={"k": 1},
+            tags=["design_retriever"]  # Add tags for tracing
+        )
+
+        # Create LLM with tracing
+        self.llm = ChatOpenAI(
+            temperature=0.2,
+            tags=["design_llm"]  # Add tags for tracing
+        )
+
+    def _create_vector_store(self) -> FAISS:
+        """Create FAISS vector store from design metadata"""
+        try:
+            # Update path to look in data/designs
+            designs_dir = Path(__file__).parent.parent / "data" / "designs"
+
+            documents = []
+
+            # Load all metadata files
+            for design_dir in designs_dir.glob("**/metadata.json"):
+                try:
+                    with open(design_dir, "r") as f:
+                        metadata = json.load(f)
+
+                    # Create document text from metadata with safe gets
+                    text = f"""
+                    Design {metadata.get('id', 'unknown')}:
+                    Description: {metadata.get('description', 'No description available')}
+                    Categories: {', '.join(metadata.get('categories', []))}
+                    Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
+                    """
+
+                    # Load associated CSS
+                    '''
+                    css_path = design_dir.parent / "style.css"
+                    if css_path.exists():
+                        with open(css_path, "r") as f:
+                            css = f.read()
+                        text += f"\nCSS:\n{css}"
+                    '''
+
+                    # Create Document object with minimal metadata
+                    documents.append(
+                        Document(
+                            page_content=text.strip(),
+                            metadata={
+                                "id": metadata.get('id', 'unknown'),
+                                "path": str(design_dir.parent)
+                            }
+                        )
+                    )
+                except Exception as e:
+                    print(f"Error processing design {design_dir}: {e}")
+                    continue
+
+            if not documents:
+                print("Warning: No valid design documents found")
+                # Create empty vector store with a placeholder document
+                return FAISS.from_documents(
+                    [Document(page_content="No designs available", metadata={"id": "placeholder"})],
+                    self.embeddings
+                )
+
+            print(f"Loaded {len(documents)} design documents")
+            # Create and return vector store
+            return FAISS.from_documents(documents, self.embeddings)
+        except Exception as e:
+            print(f"Error creating vector store: {str(e)}")
+            raise
+
+    async def query_similar_designs(self, conversation_history: List[str], num_examples: int = 1) -> str:
+        """Find similar designs based on conversation history"""
+        from langsmith import Client
+        from langchain.callbacks.tracers import ConsoleCallbackHandler
+
+        # Create LangSmith client
+        client = Client()
+
+        # Create query generation prompt with tracing
+        query_prompt = ChatPromptTemplate.from_template(
+            """Based on this conversation history:
+            {conversation}
+            Extract the key design requirements and create a search query to find similar designs.
+            Focus on:
+            1. Visual style and aesthetics mentioned
+            2. Design categories and themes discussed
+            3. Key visual characteristics requested
+            4. Overall mood and impact desired
+            5. Any specific preferences or constraints
+            Return only the search query text, no additional explanation or analysis."""
+        ).with_config(tags=["query_generation"])
+
+        # Format conversation history
+        conversation_text = "\n".join([
+            f"{'User' if i % 2 == 0 else 'Assistant'}: {msg}"
+            for i, msg in enumerate(conversation_history)
+        ])
+
+        # Generate optimized search query with tracing
+        query_response = await self.llm.ainvoke(
+            query_prompt.format(
+                conversation=conversation_text
+            )
+        )
+
+        print(f"Generated query: {query_response.content}")
+
+        # Get relevant documents with tracing
+        docs = self.retriever.get_relevant_documents(
+            query_response.content,
+            k=num_examples,
+            callbacks=[ConsoleCallbackHandler()]  # Use standard callback instead
+        )
+
+        # Format examples
+        examples = []
+        for doc in docs:
+            design_id = doc.metadata.get("id", "unknown")
+            content_lines = doc.page_content.strip().split("\n")
+            examples.append(
+                "\n".join(line.strip() for line in content_lines if line.strip()) +
+                f"\nURL: https://csszengarden.com/{design_id}"
+            )
+
+        return "\n\n".join(examples)
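With DesignRAG in place, the chat flow in src/app.py (or a future graph node) can hand it the conversation and get back a formatted description of the closest CSS Zen Garden design plus its URL. A minimal usage sketch, assuming OPENAI_API_KEY is set and src/data/designs/ contains metadata.json files; the sample history strings below are illustrative only:

import asyncio
from nodes.design_rag import DesignRAG

async def demo():
    rag = DesignRAG()  # builds the FAISS store from data/designs at construction time
    history = [
        "I want a minimalist, airy layout with lots of whitespace",  # user turn
        "Understood - any colour or typography preferences?",        # assistant turn
        "Muted earth tones and serif headings",
    ]
    # Generates a search query from the history, retrieves the top match,
    # and returns it formatted with a csszengarden.com URL.
    print(await rag.query_similar_designs(history, num_examples=1))

asyncio.run(demo())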
src/nodes/designer.py ADDED
File without changes