Spaces:
Build error
Build error
| import streamlit as st | |
| from apis import generate_sql, generate_chart | |
# Page configuration: browser-tab title and icon, plus the wide layout.
# Kept at module level so it executes before main() renders any element.
# NOTE(review): the original icon was mis-encoded UTF-8 (shown as "π");
# the intended emoji could not be recovered, so "📊" is a best guess.
st.set_page_config(
    page_title="Wren AI Cloud API Demo",
    page_icon="📊",
    layout="wide",
)
# Icons passed to st.toast must be a single valid emoji (or a Material icon);
# the original values were mis-encoded UTF-8 (mojibake) such as "π¨", which
# st.toast rejects with an exception.
_SUCCESS_ICON = "✅"
_ERROR_ICON = "🚨"
_SQL_EXPANDER_LABEL = "📝 Generated SQL Query"
_CHART_EXPANDER_LABEL = "📋 Chart Specification"


def _init_session_state():
    """Create the chat history and the conversation thread id on first run."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = ""


def _show_sql(sql_query):
    """Render *sql_query* as highlighted SQL inside a collapsed expander."""
    with st.expander(_SQL_EXPANDER_LABEL, expanded=False):
        st.code(sql_query, language="sql")


def _show_chart(vega_spec):
    """Show the raw Vega-Lite spec in an expander, then draw the chart.

    Any rendering error is reported with a toast so a bad spec does not take
    down the rest of the page.
    """
    try:
        with st.expander(_CHART_EXPANDER_LABEL, expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:  # UI boundary: surface the failure, keep the chat usable
        st.toast(f"Error rendering chart: {e}", icon=_ERROR_ICON)


def _render_history():
    """Replay every stored message, including any SQL/chart attachments."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["role"] != "user":
                if "sql" in message:
                    _show_sql(message["sql"])
                if "vega_spec" in message:
                    _show_chart(message["vega_spec"])


def _add_assistant_message(content, **attachments):
    """Append an assistant reply to the history and display it immediately.

    *attachments* (``sql=...`` or ``vega_spec=...``) are stored next to the
    text so ``_render_history`` can redraw them on later reruns.
    """
    message = {"role": "assistant", "content": content, **attachments}
    st.session_state.messages.append(message)
    st.write(content)


def _handle_prompt(prompt, api_key, project_id, sample_size):
    """Answer one user question: generate SQL, then a chart for that SQL.

    Args:
        prompt: the natural-language question typed by the user.
        api_key: Wren AI Cloud API key from the sidebar.
        project_id: Wren AI Cloud project id from the sidebar.
        sample_size: number of data points forwarded to ``generate_chart``.
    """
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Generating SQL query..."):
            sql_response, error = generate_sql(
                api_key, project_id, prompt, st.session_state.thread_id
            )

        if not sql_response:
            st.toast(f"Error generating SQL: {error}", icon=_ERROR_ICON)
            # Previously only stored in history, so the user saw no reply
            # until the next rerun; now it is also written immediately.
            _add_assistant_message(
                "Sorry, I couldn't process your request. Please check your API credentials and try again."
            )
            return

        sql_query = sql_response.get("sql", "")
        # The thread id returned by the API is reused on the next call.
        st.session_state.thread_id = sql_response.get("threadId", "")

        if not sql_query:
            st.toast(
                "No SQL query was generated. Please try rephrasing your question.",
                icon=_ERROR_ICON,
            )
            _add_assistant_message(
                "I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data."
            )
            return

        st.toast("SQL query generated successfully!", icon=_SUCCESS_ICON)
        _add_assistant_message(
            f"I've generated a SQL query for your question: '{prompt}'",
            sql=sql_query,
        )
        _show_sql(sql_query)

        with st.spinner("Generating chart..."):
            chart_response, error = generate_chart(
                api_key,
                project_id,
                prompt,
                sql_query,
                thread_id=st.session_state.thread_id,
                sample_size=sample_size,
            )

        if not chart_response:
            st.toast(f"Failed to generate chart: {error}", icon=_ERROR_ICON)
            return

        vega_spec = chart_response.get("vegaSpec", {})
        if not vega_spec:
            st.toast(
                "Failed to generate chart. Please check your query and try again.",
                icon=_ERROR_ICON,
            )
            return

        st.toast("Chart generated successfully!", icon=_SUCCESS_ICON)
        _add_assistant_message(
            f"I've generated a Chart for your question: '{prompt}'",
            vega_spec=vega_spec,
        )
        _show_chart(vega_spec)


def main():
    """Render the Wren AI demo page: sidebar config, chat history and input.

    Returns early, after showing setup instructions, until both an API key
    and a project id have been entered in the sidebar.
    """
    st.title("📊 Wren AI Cloud API Demo")
    st.markdown("Ask questions about your data and get both SQL queries and beautiful charts!")

    # Sidebar for API configuration
    with st.sidebar:
        st.header("🔧 Configuration")
        api_key = st.text_input(
            "API Key",
            type="password",
            placeholder="sk-your-api-key-here",
            help="Enter your Wren AI Cloud API key",
        )
        project_id = st.text_input(
            "Project ID",
            placeholder="1234",
            help="Enter your Wren AI Cloud project ID",
        )
        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts",
        )

    if not api_key or not project_id:
        st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
        st.info(
            "**How to get started:**\n"
            "1. Enter your Wren AI Cloud API Key in the sidebar\n"
            "2. Enter your Project ID\n"
            "3. Ask questions about your data in natural language\n"
            "4. Get SQL queries and interactive charts automatically!"
        )
        return

    _init_session_state()
    _render_history()

    if prompt := st.chat_input("Ask a question about your data..."):
        _handle_prompt(prompt, api_key, project_id, sample_size)

    # Clear chat button (only shown once credentials are present).
    if st.sidebar.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.session_state.thread_id = ""
        st.rerun()
# Launch the app only when this file is executed directly, never on import.
if __name__ == "__main__":
    main()