import gradio as gr
import weaviate
import os
from openai import AsyncOpenAI
from dotenv import load_dotenv
import asyncio
from functools import wraps
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()

# Set up AsyncOpenAI client
openai_client = AsyncOpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# Initialize client as None
client = None

# Get the collection name from environment variable
COLLECTION_NAME = os.getenv('WEAVIATE_COLLECTION_NAME')

# Global variable to track connection status
connection_status = {"status": "Disconnected", "color": "red"}

# Function to initialize the Weaviate client
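# NOTE: this uses the weaviate-client v3 connection API (weaviate.Client);
# the v4 client exposes a different connection interface.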
async def initialize_weaviate_client(max_retries=3, retry_delay=5):
    global client, connection_status
    retries = 0
    while retries < max_retries:
        connection_status = {"status": "Connecting...", "color": "orange"}
        try:
            logger.info(f"Attempting to connect to Weaviate (Attempt {retries + 1}/{max_retries})")
            client = weaviate.Client(
                url=os.getenv('WCS_URL'),
                auth_client_secret=weaviate.auth.AuthApiKey(os.getenv('WCS_API_KEY')),
                additional_headers={
                    "X-OpenAI-Api-Key": os.getenv('OPENAI_API_KEY')
                }
            )
            # Test the connection
            await asyncio.to_thread(client.schema.get)
            connection_status = {"status": "Connected", "color": "green"}
            logger.info("Successfully connected to Weaviate")
            return connection_status
        except Exception as e:
            logger.error(f"Error connecting to Weaviate: {str(e)}")
            connection_status = {"status": f"Error: {str(e)}", "color": "red"}
            retries += 1
            if retries < max_retries:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error("Max retries reached. Could not connect to Weaviate.")
    return connection_status


# Async-compatible caching decorator
def async_lru_cache(maxsize=128):
    cache = {}

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            key = str(args) + str(kwargs)
            if key not in cache:
                if len(cache) >= maxsize:
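                    # dict preserves insertion order (Python 3.7+), so this
                    # evicts the oldest entry: FIFO eviction, not a true LRU.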
                    cache.pop(next(iter(cache)))
                cache[key] = await func(*args, **kwargs)
            return cache[key]
        return wrapper
    return decorator

@async_lru_cache(maxsize=1000)
async def get_embedding(text):
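    # Embeds the text via OpenAI (text-embedding-3-large returns 3072-dim
    # vectors); repeated inputs are served from the cache above.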
    response = await openai_client.embeddings.create(
        input=text,
        model="text-embedding-3-large"
    )
    return response.data[0].embedding

async def search_multimodal(query: str, limit: int = 30, alpha: float = 0.6):
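    # alpha balances Weaviate's hybrid search: 1.0 is pure vector search,
    # 0.0 is pure keyword (BM25); 0.6 here leans toward vector similarity.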
    query_vector = await get_embedding(query)
    
    try:
        response = await asyncio.to_thread(
            client.query.get(COLLECTION_NAME, ["content_type", "url", "source_document", "page_number",
                                               "paragraph_number", "text", "image_path", "description", "table_content"])
            .with_hybrid(query=query, vector=query_vector, alpha=alpha)
            .with_limit(limit)
            .do
        )
        return response['data']['Get'][COLLECTION_NAME]
    except Exception as e:
        print(f"An error occurred during the search: {str(e)}")
        return []

async def generate_response_stream(query: str, context: str):
    prompt = f"""
You are an AI assistant with extensive expertise in the semiconductor industry. Your knowledge spans a wide range of companies, technologies, and products, including but not limited to: System-on-Chip (SoC) designs, Field-Programmable Gate Arrays (FPGAs), Microcontrollers, Integrated Circuits (ICs), semiconductor manufacturing processes, and emerging technologies like quantum computing and neuromorphic chips.
Use the following context, your vast knowledge, and the user's question to generate an accurate, comprehensive, and insightful answer. While formulating your response, follow these steps internally:
1. Analyze the question to identify the main topic and the specific information requested.
2. Evaluate the provided context and identify the relevant information.
3. Retrieve additional relevant knowledge from your semiconductor industry expertise.
4. Reason through and formulate a response by combining the context with your knowledge.
5. Generate a detailed response that covers all aspects of the query.
6. Review and refine your answer for coherence and accuracy.
In your output, provide only the final, polished response. Do not include your step-by-step reasoning or mention the process you followed.
IMPORTANT: Ensure your response is grounded in factual information. Do not hallucinate or invent information. If you're unsure about any aspect of the answer or if the necessary information is not available in the provided context or your knowledge base, clearly state this uncertainty. It's better to admit lack of information than to provide inaccurate details.
Your response should:
- Be thorough and directly address all aspects of the user's question
- Rely solely on factual information from the provided context and your reliable knowledge
- Include specific examples, data points, or case studies only when you're certain of their accuracy
- Explain technical concepts clearly, considering the user may have varying levels of expertise
- Clearly indicate any areas where information is limited or uncertain
Context: {context}
User Question: {query}
Based on the above context and your extensive knowledge of the semiconductor industry, provide your detailed, accurate, and grounded response below. Remember, only include information you're confident is correct, and clearly state any uncertainties: 
    """

    # With AsyncOpenAI, create(stream=True) resolves to an async iterator of
    # chunks, hence the `await` inside the `async for`.
    async for chunk in await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an expert Semi Conductor industry analyst"},
            {"role": "user", "content": prompt}
        ],
        temperature=0,
        stream=True
    ):
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content

def process_search_result(item):
    if item['content_type'] == 'text':
        return f"Text from {item['source_document']} (Page {item['page_number']}, Paragraph {item['paragraph_number']}): {item['text']}\n\n"
    elif item['content_type'] == 'image':
        return f"Image Description from {item['source_document']} (Page {item['page_number']}, Path: {item['image_path']}): {item['description']}\n\n"
    elif item['content_type'] == 'table':
        return f"Table Description from {item['source_document']} (Page {item['page_number']}): {item['description']}\n\n"
    return ""

async def esg_analysis_stream(user_query: str):
    search_results = await search_multimodal(user_query)
    
    context_parts = await asyncio.gather(*[asyncio.to_thread(process_search_result, item) for item in search_results])
    context = "".join(context_parts)
    
    sources = []
    for item in search_results[:5]:  # Limit to top 5 sources
        source = {
            "type": item.get("content_type", "Unknown"),
            "document": item.get("source_document", "N/A"),
            "page": item.get("page_number", "N/A"),
        }
        if item.get("content_type") == 'text':
            source["paragraph"] = item.get("paragraph_number", "N/A")
        elif item.get("content_type") == 'image':
            source["image_path"] = item.get("image_path", "N/A")
        sources.append(source)

    return generate_response_stream(user_query, context), sources

def format_sources(sources):
    source_text = "## Top 5 Sources\n\n"
    for i, source in enumerate(sources, 1):
        source_text += f"### Source {i}\n"
        source_text += f"- **Type:** {source['type']}\n"
        source_text += f"- **Document:** {source['document']}\n"
        source_text += f"- **Page:** {source['page']}\n"
        if 'paragraph' in source:
            source_text += f"- **Paragraph:** {source['paragraph']}\n"
        if 'image_path' in source:
            source_text += f"- **Image Path:** {source['image_path']}\n"
        source_text += "\n"
    return source_text

# Custom CSS for the status box
custom_css = """
#status-box {
    position: absolute;
    top: 10px;
    right: 10px;
    background-color: white;
    padding: 5px 10px;
    border-radius: 5px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
    z-index: 1000;
    display: flex;
    align-items: center;
}
#status-light {
    width: 10px;
    height: 10px;
    border-radius: 50%;
    display: inline-block;
    margin-right: 5px;
}
#status-text {
    font-size: 14px;
    font-weight: bold;
}
"""

def get_connection_status():
    status = connection_status["status"]
    color = connection_status["color"]
    return f'<div id="status-box"><div id="status-light" style="background-color: {color};"></div><span id="status-text">{status}</span></div>'

async def check_connection():
    global connection_status
    try:
        if client:
            await asyncio.to_thread(client.schema.get)
            return {"status": "Connected", "color": "green"}
        else:
            return {"status": "Disconnected", "color": "red"}
    except Exception:
        return {"status": "Disconnected", "color": "red"}

# NOTE: this polling generator is defined but never scheduled; the UI below
# polls the connection state via iface.load(..., every=1) instead.
async def update_status():
    global connection_status
    while True:
        new_status = await check_connection()
        if new_status != connection_status:
            connection_status = new_status
            yield new_status
        await asyncio.sleep(5)  # Check every 5 seconds

async def gradio_interface(user_question):
    if connection_status["status"] != "Connected":
        return "Error: Database not connected. Please wait for the connection to be established.", ""

    response_generator, sources = await esg_analysis_stream(user_question)
    formatted_sources = format_sources(sources)
    
    full_response = ""
    async for response_chunk in response_generator:
        full_response += response_chunk
    
    return full_response, formatted_sources

with gr.Blocks(css=custom_css) as iface:
    status_indicator = gr.HTML(get_connection_status())
    
    with gr.Row():
        gr.Markdown("# Semiconductor Industry Analysis")
    
    gr.Markdown("Ask questions about the semiconductor industry and get AI-powered answers with sources.")
    
    user_question = gr.Textbox(lines=2, placeholder="Enter your question about the semiconductor industry...", interactive=False)
    ai_response = gr.Markdown(label="AI Response")
    sources_output = gr.Markdown(label="Sources")
    
    submit_btn = gr.Button("Submit", interactive=False)
    
    submit_btn.click(
        fn=gradio_interface,
        inputs=user_question,
        outputs=[ai_response, sources_output],
    )

    # Update status
    def update_status_indicator(status):
        return get_connection_status()  # Return the HTML string directly

    def update_input_state(status):
        is_connected = status["status"] == "Connected"
        return gr.update(interactive=is_connected), gr.update(interactive=is_connected)

    status_updater = gr.State(connection_status)
    
    iface.load(
        lambda: connection_status,
        outputs=[status_updater],
        every=1,
    )
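    # NOTE: assumes a Gradio version in which gr.State fires .change events
    # when the polled load() above replaces its value.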

    status_updater.change(
        fn=update_status_indicator,
        inputs=[status_updater],
        outputs=[status_indicator],
    )

    status_updater.change(
        fn=update_input_state,
        inputs=[status_updater],
        outputs=[user_question, submit_btn],
    )


async def main():
    # Check environment variables
    required_env_vars = ['WCS_URL', 'WCS_API_KEY', 'OPENAI_API_KEY', 'WEAVIATE_COLLECTION_NAME']
    for var in required_env_vars:
        if not os.getenv(var):
            logger.error(f"Environment variable {var} is not set!")
            return

    # Initialize the client before launching the interface
    await initialize_weaviate_client()
    
    # Launch the interface regardless of connection status.
    # Blocks.launch() is synchronous and blocking, so it must not be awaited.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)

if __name__ == "__main__":
    asyncio.run(main())