Spaces:
Sleeping
Sleeping
| import json | |
| import tempfile | |
| import requests | |
| import streamlit as st | |
| from lagent.schema import AgentStatusCode | |
| from pyvis.network import Network | |
def create_network_graph(nodes, adjacency_list):
    """Build a pyvis ``Network`` from a node mapping and an adjacency list.

    Args:
        nodes: Mapping of node id to the content used as the node's
            ``title``; the id itself is used as the label.
        adjacency_list: Mapping of source node id to an iterable of
            neighbor dicts, each carrying the target id under ``"name"``.

    Returns:
        The populated ``Network`` with the physics control panel enabled.
    """
    graph = Network(height="500px", width="60%", bgcolor="white", font_color="black")

    for name in nodes:
        graph.add_node(name, label=name, title=nodes[name], color="#FF5733", size=25)

    for source, targets in adjacency_list.items():
        for target in targets:
            target_name = target["name"]
            # Only connect to targets that were declared as nodes.
            if target_name not in nodes:
                continue
            graph.add_edge(source, target_name)

    graph.show_buttons(filter_=["physics"])
    return graph
def draw_graph(net):
    """Save *net* as an HTML file in the temp directory and return its path.

    The file is created (and its name thereby reserved) before the graph is
    written, instead of using the deprecated, race-prone ``tempfile.mktemp``.
    The file is intentionally not deleted here; the caller reads it back.

    Args:
        net: Object exposing ``save_graph(path)`` (a pyvis ``Network`` in
            this file).

    Returns:
        Path of the ``.html`` file that ``net.save_graph`` was asked to write.
    """
    # delete=False keeps the file on disk after the handle is closed; closing
    # first lets save_graph reopen the path on every platform.
    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as tmp:
        path = tmp.name
    net.save_graph(path)
    return path
def streaming(raw_response):
    """Yield decoded events from a line-delimited streaming response.

    Empty lines, bare ``"\\r"`` lines and ``": ping - "`` keep-alive lines are
    skipped. A leading ``"data: "`` prefix is removed before the remainder of
    the line is parsed as JSON.

    Args:
        raw_response: Object exposing ``iter_lines(chunk_size=...,
            decode_unicode=..., delimiter=...)`` that yields ``bytes`` lines.

    Yields:
        ``(current_node, response, adjacency_list)`` tuples, where
        ``response`` is the per-node response when ``current_node`` is truthy
        and the top-level ``"response"`` object otherwise.
    """
    lines = raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\n")
    for raw_line in lines:
        if not raw_line:
            continue
        text = raw_line.decode("utf-8")
        if text == "\r":
            continue
        if text.startswith("data: "):
            text = text[len("data: "):]
        elif text.startswith(": ping - "):
            continue

        payload = json.loads(text)
        current_node = payload["current_node"]
        top_level = payload["response"]
        if current_node:
            node_response = top_level["formatted"]["node"][current_node]["response"]
        else:
            node_response = top_level
        yield current_node, node_response, top_level["formatted"]["adjacency_list"]
# Seed the per-session containers once; "queries" acts as the marker that the
# session has already been initialised.
if "queries" not in st.session_state:
    for _state_key in (
        "queries",
        "responses",
        "graphs_html",
        "nodes_list",
        "adjacency_list_list",
        "history",
        "already_used_keys",
    ):
        # A fresh list per key so the containers never alias each other.
        st.session_state[_state_key] = []
    del _state_key

# Page layout and title.
st.set_page_config(layout="wide")
st.title("MindSearch-思索")
def update_chat(query):
    """Send *query* to the backend and render the streamed result live.

    The query is echoed as a user chat message. If it has not been asked
    before in this session, it is POSTed to the local ``/solve`` endpoint and
    every streamed event updates the graph, the planner column and the
    searcher column through placeholders kept in ``st.session_state``. When
    the stream ends, the final graph HTML, node mapping and adjacency list
    are appended to the per-query lists in the session state.

    Args:
        query: The user's question. A query already present in
            ``st.session_state["queries"]`` is only echoed, not re-sent.
    """
    with st.chat_message("user"):
        st.write(query)
    if query in st.session_state["queries"]:
        return

    st.session_state["queries"].append(query)
    st.session_state["responses"].append([])
    history = None
    # Multi-turn conversation is not supported yet, so only the raw query is
    # sent (no message history).
    url = "http://localhost:8002/solve"
    headers = {"Content-Type": "application/json"}
    data = {"inputs": query}

    _nodes, _node_cnt = {}, 0
    # Defaults for a stream that yields no events, so the bookkeeping after
    # the loop never touches unbound names.
    graph_html = None
    adjacency_list = {}

    # The context manager closes the streamed connection even if rendering
    # raises part-way through the stream.
    with requests.post(
        url, headers=headers, data=json.dumps(data), timeout=20, stream=True
    ) as raw_response:
        for resp in streaming(raw_response):
            node_name, response, adjacency_list = resp

            # Register every node mentioned as a source or as a neighbor.
            for name in set(adjacency_list) | {
                val["name"] for vals in adjacency_list.values() for val in vals
            }:
                if name not in _nodes:
                    _nodes[name] = query if name == "root" else name
                elif response["stream_state"] == 0:
                    # NOTE(review): 0 is presumably the "finished" status code;
                    # confirm against AgentStatusCode.
                    _nodes[node_name or "response"] = response["formatted"] and response[
                        "formatted"
                    ].get("thought")

            # Re-render the graph only when the node set grew or on state 0.
            if len(_nodes) != _node_cnt or response["stream_state"] == 0:
                net = create_network_graph(_nodes, adjacency_list)
                graph_html_path = draw_graph(net)
                with open(graph_html_path, encoding="utf-8") as f:
                    graph_html = f.read()
                _node_cnt = len(_nodes)
            else:
                graph_html = None

            if "graph_placeholder" not in st.session_state:
                st.session_state["graph_placeholder"] = st.empty()
            if "expander_placeholder" not in st.session_state:
                st.session_state["expander_placeholder"] = st.empty()
            if graph_html:
                with st.session_state["expander_placeholder"].expander(
                    "Show Graph", expanded=False
                ):
                    st.session_state["graph_placeholder"]._html(graph_html, height=500)

            if "container_placeholder" not in st.session_state:
                st.session_state["container_placeholder"] = st.empty()
            with st.session_state["container_placeholder"].container():
                if "columns_placeholder" not in st.session_state:
                    st.session_state["columns_placeholder"] = st.empty()
                col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])

                # Left column: planner output (events without a node name).
                with col1:
                    if "planner_placeholder" not in st.session_state:
                        st.session_state["planner_placeholder"] = st.empty()
                    if "session_info_temp" not in st.session_state:
                        st.session_state["session_info_temp"] = ""
                    if not node_name:
                        if response["stream_state"] in [
                            AgentStatusCode.STREAM_ING,
                            AgentStatusCode.CODING,
                            AgentStatusCode.CODE_END,
                        ]:
                            content = response["formatted"]["thought"]
                            if response["formatted"]["tool_type"]:
                                action = response["formatted"]["action"]
                                if isinstance(action, dict):
                                    action = json.dumps(action, ensure_ascii=False, indent=4)
                                content += "\n" + action
                            st.session_state["session_info_temp"] = content.replace(
                                "<|action_start|><|interpreter|>\n", "\n"
                            )
                        elif response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            st.session_state["session_info_temp"] += "\n" + response["content"]
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["session_info_temp"]
                        )
                        if response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            # The step is complete: archive it and start a new buffer.
                            st.session_state["responses"][-1].append(
                                st.session_state["session_info_temp"]
                            )
                            st.session_state["session_info_temp"] = ""
                    else:
                        # Keep showing the in-progress planner text, or the last
                        # archived step; fall back to empty text when a node
                        # event arrives before any planner output exists.
                        planner_text = st.session_state["session_info_temp"]
                        if not planner_text and st.session_state["responses"][-1]:
                            planner_text = st.session_state["responses"][-1][-1]
                        st.session_state["planner_placeholder"].markdown(planner_text)

                # Right column: searcher output (events carrying a node name).
                with col2:
                    if "selectbox_placeholder" not in st.session_state:
                        st.session_state["selectbox_placeholder"] = st.empty()
                    if "searcher_placeholder" not in st.session_state:
                        st.session_state["searcher_placeholder"] = st.empty()
                    if node_name:
                        selected_node_key = (
                            f"selected_node_{len(st.session_state['queries'])}_{node_name}"
                        )
                        if selected_node_key not in st.session_state:
                            st.session_state[selected_node_key] = node_name
                        if selected_node_key not in st.session_state["already_used_keys"]:
                            # A widget key may only be used once per run, so the
                            # selectbox is created on the first event of a node.
                            node_names = list(_nodes.keys())
                            selected_node = st.session_state["selectbox_placeholder"].selectbox(
                                "Select a node:",
                                node_names,
                                key=f"key_{selected_node_key}",
                                index=node_names.index(node_name),
                            )
                            st.session_state["already_used_keys"].append(selected_node_key)
                        else:
                            selected_node = node_name
                        st.session_state[selected_node_key] = selected_node

                        # Per-node log of [kind, text] entries.
                        node_info_key = f"{selected_node}_info"
                        if node_info_key not in st.session_state:
                            st.session_state[node_info_key] = [["thought", ""]]
                        if response["stream_state"] in [AgentStatusCode.STREAM_ING]:
                            content = response["formatted"]["thought"]
                            st.session_state[node_info_key][-1][1] = content.replace(
                                "<|action_start|><|plugin|>\n", "\n```json\n"
                            )
                        elif response["stream_state"] in [
                            AgentStatusCode.PLUGIN_START,
                            AgentStatusCode.PLUGIN_END,
                        ]:
                            thought = response["formatted"]["thought"]
                            action = response["formatted"]["action"]
                            if isinstance(action, dict):
                                action = json.dumps(action, ensure_ascii=False, indent=4)
                            content = thought + "\n```json\n" + action
                            # NOTE(review): this condition can never hold inside
                            # the PLUGIN_START/PLUGIN_END branch, so the closing
                            # fence is never appended; PLUGIN_END may have been
                            # intended. Left unchanged pending confirmation.
                            if response["stream_state"] == AgentStatusCode.PLUGIN_RETURN:
                                content += "\n```"
                            st.session_state[node_info_key][-1][1] = content
                        elif (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            try:
                                content = json.loads(response["content"])
                            except json.decoder.JSONDecodeError:
                                content = response["content"]
                            st.session_state[node_info_key].append(
                                [
                                    "observation",
                                    (
                                        content
                                        if isinstance(content, str)
                                        else f"```json\n{json.dumps(content, ensure_ascii=False, indent=4)}\n```"
                                    ),
                                ]
                            )
                        st.session_state["searcher_placeholder"].markdown(
                            st.session_state[node_info_key][-1][1]
                        )
                        if (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            # Open a fresh "thought" entry for the next round.
                            st.session_state[node_info_key].append(["thought", ""])

    # Archive planner text that was still being accumulated when the stream
    # ended. The key is absent when the stream produced no events at all.
    if "session_info_temp" in st.session_state and st.session_state["session_info_temp"]:
        st.session_state["responses"][-1].append(st.session_state["session_info_temp"])
        st.session_state["session_info_temp"] = ""
    st.session_state["graphs_html"].append(graph_html)
    st.session_state["nodes_list"].append(_nodes)
    st.session_state["adjacency_list_list"].append(adjacency_list)
    st.session_state["history"] = history
def display_chat_history():
    """Re-render the most recent query's graph, planner text and node details.

    Only the latest query is shown. Its graph HTML and node mapping are looked
    up by the latest query's position in the per-query lists (the previous
    ``enumerate(queries[-1:])`` always produced index 0 and therefore paired
    the first query's graph and nodes with the latest response).

    Nothing is rendered when there is no query yet or when the layout
    placeholders have not been created in ``st.session_state``.
    """
    queries = st.session_state["queries"]
    if not queries:
        return
    required = ("container_placeholder", "columns_placeholder", "planner_placeholder")
    if not all(name in st.session_state for name in required):
        return

    latest = len(queries) - 1
    graph_html = st.session_state["graphs_html"][latest]
    nodes = st.session_state["nodes_list"][latest]

    if graph_html:
        with st.session_state["expander_placeholder"].expander("Show Graph", expanded=False):
            st.session_state["graph_placeholder"]._html(graph_html, height=500)

    with st.session_state["container_placeholder"].container():
        col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
        with col1:
            planner_steps = st.session_state["responses"][latest]
            if planner_steps:
                st.session_state["planner_placeholder"].markdown(planner_steps[-1])
        with col2:
            used_keys = st.session_state["already_used_keys"]
            # Without any recorded selection key there is no node panel to replay.
            if used_keys:
                selected_node_key = used_keys[-1]
                node_names = list(nodes.keys())
                remembered = st.session_state[selected_node_key]
                # Fall back to the first option if the remembered selection is
                # not part of this query's nodes.
                start_index = node_names.index(remembered) if remembered in node_names else 0
                st.session_state["selectbox_placeholder"] = st.empty()
                selected_node = st.session_state["selectbox_placeholder"].selectbox(
                    "Select a node:",
                    node_names,
                    key=f"replay_key_{latest}",
                    index=start_index,
                )
                st.session_state[selected_node_key] = selected_node
                node_info_key = f"{selected_node}_info"
                if (
                    selected_node not in ["root", "response"]
                    and selected_node in nodes
                    and node_info_key in st.session_state
                ):
                    for kind, text in st.session_state[node_info_key]:
                        if kind in ["thought", "answer"]:
                            st.session_state["searcher_placeholder"] = st.empty()
                            st.session_state["searcher_placeholder"].markdown(text)
                        elif kind == "observation":
                            st.session_state["observation_expander"] = st.empty()
                            with st.session_state["observation_expander"].expander("Results"):
                                st.write(text)
def clean_history():
    """Reset the chat bookkeeping and drop cached placeholders and node logs.

    The per-query lists are replaced with fresh empty lists, then every
    session-state entry whose key ends with ``"placeholder"`` or ``"_info"``
    is deleted. Other keys are left untouched.
    """
    for name in (
        "queries",
        "responses",
        "graphs_html",
        "nodes_list",
        "adjacency_list_list",
        "history",
        "already_used_keys",
    ):
        st.session_state[name] = []

    # Iterate over a snapshot of the keys: entries are deleted inside the loop
    # and a mapping must not change size while it is being iterated.
    for k in list(st.session_state):
        if k.endswith(("placeholder", "_info")):
            del st.session_state[k]
def main():
    """Lay out the page controls and dispatch one Streamlit run.

    Draws the sidebar title, the chat input and the "Clear History" button,
    clears the history when the button is pressed, forwards a submitted query
    to ``update_chat`` and finally calls ``display_chat_history``.
    """
    st.sidebar.title("Model Control")

    input_col, button_col = st.columns([4, 1])
    with input_col:
        user_input = st.chat_input("Enter your query:")
    with button_col:
        clear_requested = st.button("Clear History")
        if clear_requested:
            clean_history()

    if user_input:
        update_chat(user_input)
    display_chat_history()


if __name__ == "__main__":
    main()