|
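"""GraphRAG Local UI.

A Gradio application that wraps the GraphRAG CLI: it manages input files,
runs indexing, visualizes the resulting knowledge graph, and answers
global, local, and direct-chat queries against the indexed data.
"""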
import gradio as gr |
|
from gradio.helpers import Progress |
|
import asyncio |
|
import subprocess |
|
import yaml |
|
import os |
|
import networkx as nx |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import plotly.io as pio |
|
import lancedb |
|
import random |
|
import io |
|
import shutil |
|
import logging |
|
import queue |
|
import threading |
|
import time |
|
from collections import deque |
|
import re |
|
import glob |
|
from datetime import datetime |
|
import json |
|
import requests |
|
import aiohttp |
|
from openai import OpenAI |
|
from openai import AsyncOpenAI |
|
import pyarrow.parquet as pq |
|
import pandas as pd |
|
import sys |
|
import colorsys |
|
from dotenv import load_dotenv, set_key |
|
import argparse |
|
import socket |
|
import tiktoken |
|
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey |
|
from graphrag.query.indexer_adapters import ( |
|
read_indexer_covariates, |
|
read_indexer_entities, |
|
read_indexer_relationships, |
|
read_indexer_reports, |
|
read_indexer_text_units, |
|
) |
|
from graphrag.llm.openai import create_openai_chat_llm |
|
from graphrag.llm.openai.factories import create_openai_embedding_llm |
|
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings |
|
from graphrag.query.llm.oai.chat_openai import ChatOpenAI |
|
from graphrag.llm.openai.openai_configuration import OpenAIConfiguration |
|
from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM |
|
from graphrag.query.llm.oai.typing import OpenaiApiType |
|
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext |
|
from graphrag.query.structured_search.local_search.search import LocalSearch |
|
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext |
|
from graphrag.query.structured_search.global_search.search import GlobalSearch |
|
from graphrag.vector_stores.lancedb import LanceDBVectorStore |
|
import textwrap |
|
|
|
|
|
|
|
|
|
import warnings |
|
warnings.filterwarnings("ignore", category=UserWarning, module="gradio_client.documentation") |
|
|
|
|
|
load_dotenv('indexing/.env') |
|
|
|
|
|
# Re-export the values loaded from indexing/.env so child processes inherit them,
# skipping any that are unset (os.environ.setdefault requires a string value).
for _key in ("LLM_API_BASE", "LLM_API_KEY", "LLM_MODEL",
             "EMBEDDINGS_API_BASE", "EMBEDDINGS_API_KEY", "EMBEDDINGS_MODEL"):
    _value = os.getenv(_key)
    if _value is not None:
        os.environ.setdefault(_key, _value)
|
|
|
|
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) |
|
sys.path.insert(0, project_root) |
|
|
|
|
|
|
|
log_queue = queue.Queue() |
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
|
|
llm = None |
|
text_embedder = None |
|
|
|
class QueueHandler(logging.Handler): |
|
def __init__(self, log_queue): |
|
super().__init__() |
|
self.log_queue = log_queue |
|
|
|
def emit(self, record): |
|
self.log_queue.put(self.format(record)) |
|
queue_handler = QueueHandler(log_queue) |
|
logging.getLogger().addHandler(queue_handler) |
|
|
|
|
|
|
|
def initialize_models(): |
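    """Create the global chat LLM and embeddings clients from environment settings.

    Returns the model lists, service types, and API base URLs the UI needs
    to populate its dropdowns, plus the text embedder itself.
    """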
|
global llm, text_embedder |
|
|
|
llm_api_base = os.getenv("LLM_API_BASE") |
|
llm_api_key = os.getenv("LLM_API_KEY") |
|
embeddings_api_base = os.getenv("EMBEDDINGS_API_BASE") |
|
embeddings_api_key = os.getenv("EMBEDDINGS_API_KEY") |
|
|
|
llm_service_type = os.getenv("LLM_SERVICE_TYPE", "openai_chat").lower() |
|
embeddings_service_type = os.getenv("EMBEDDINGS_SERVICE_TYPE", "openai").lower() |
|
|
|
llm_model = os.getenv("LLM_MODEL") |
|
embeddings_model = os.getenv("EMBEDDINGS_MODEL") |
|
|
|
logging.info("Fetching models...") |
|
models = fetch_models(llm_api_base, llm_api_key, llm_service_type) |
|
|
|
|
|
llm_models = models |
|
embeddings_models = models |
|
|
|
|
|
if llm_service_type == "openai_chat": |
|
llm = ChatOpenAI( |
|
api_key=llm_api_key, |
|
api_base=f"{llm_api_base}/v1", |
|
model=llm_model, |
|
api_type=OpenaiApiType.OpenAI, |
|
max_retries=20, |
|
) |
|
|
|
openai_client = OpenAI( |
|
api_key=embeddings_api_key or "dummy_key", |
|
base_url=f"{embeddings_api_base}/v1" |
|
) |
|
|
|
|
|
    # OpenAIEmbeddingsLLM expects an OpenAIConfiguration object rather than a
    # raw dict (the unused OpenAIConfiguration import above suggests this was
    # the intent).
    text_embedder = OpenAIEmbeddingsLLM(
        client=openai_client,
        configuration=OpenAIConfiguration({
            "model": embeddings_model,
            "api_type": "open_ai",
            "api_base": embeddings_api_base,
            "api_key": embeddings_api_key or None,
        })
    )
|
|
|
return llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder |
|
|
|
def find_latest_output_folder(): |
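    """Return (path, name) of the newest timestamped run folder under ./indexing/output.

    Run folders are expected to be named like 20240101-120000 and must
    contain an 'artifacts' subfolder.
    """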
|
root_dir = "./indexing/output" |
|
folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))] |
|
|
|
if not folders: |
|
raise ValueError("No output folders found") |
|
|
|
|
|
sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True) |
|
|
|
latest_folder = None |
|
timestamp = None |
|
|
|
for folder in sorted_folders: |
|
try: |
|
|
|
timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S") |
|
latest_folder = folder |
|
break |
|
except ValueError: |
|
|
|
continue |
|
|
|
if latest_folder is None: |
|
raise ValueError("No valid timestamp folders found") |
|
|
|
latest_path = os.path.join(root_dir, latest_folder) |
|
artifacts_path = os.path.join(latest_path, "artifacts") |
|
|
|
if not os.path.exists(artifacts_path): |
|
raise ValueError(f"Artifacts folder not found in {latest_path}") |
|
|
|
return latest_path, latest_folder |
|
|
|
def initialize_data(): |
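    """Load the latest GraphRAG parquet artifacts into module-level DataFrames.

    Any table that cannot be found is initialized as an empty DataFrame so
    the UI can still start. Returns the name of the run folder that was
    loaded, or None on failure.
    """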
|
global entity_df, relationship_df, text_unit_df, report_df, covariate_df |
|
|
|
tables = { |
|
"entity_df": "create_final_nodes", |
|
"relationship_df": "create_final_edges", |
|
"text_unit_df": "create_final_text_units", |
|
"report_df": "create_final_reports", |
|
"covariate_df": "create_final_covariates" |
|
} |
|
|
|
timestamp = None |
|
|
|
try: |
|
latest_output_folder, timestamp = find_latest_output_folder() |
|
artifacts_folder = os.path.join(latest_output_folder, "artifacts") |
|
|
|
for df_name, file_prefix in tables.items(): |
|
file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet") |
|
matching_files = glob.glob(file_pattern) |
|
|
|
if matching_files: |
|
latest_file = max(matching_files, key=os.path.getctime) |
|
df = pd.read_parquet(latest_file) |
|
globals()[df_name] = df |
|
logging.info(f"Successfully loaded {df_name} from {latest_file}") |
|
else: |
|
logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.") |
|
globals()[df_name] = pd.DataFrame() |
|
|
|
except Exception as e: |
|
logging.error(f"Error initializing data: {str(e)}") |
|
for df_name in tables.keys(): |
|
globals()[df_name] = pd.DataFrame() |
|
|
|
return timestamp |
|
|
|
|
|
current_timestamp = initialize_data() |
|
|
|
|
|
def find_available_port(start_port, max_attempts=100): |
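    """Return the first free TCP port at or above start_port, found by bind-probing."""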
|
for port in range(start_port, start_port + max_attempts): |
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
|
try: |
|
s.bind(('', port)) |
|
return port |
|
except OSError: |
|
continue |
|
raise IOError("No free ports found") |
|
|
|
def start_api_server(port): |
|
subprocess.Popen([sys.executable, "api_server.py", "--port", str(port)]) |
|
|
|
def wait_for_api_server(port): |
|
max_retries = 30 |
|
for _ in range(max_retries): |
|
        try:
            response = requests.get(f"http://localhost:{port}", timeout=2)
            if response.status_code == 200:
                print(f"API server is up and running on port {port}")
                return
            else:
                print(f"Unexpected response from API server: {response.status_code}")
        except requests.ConnectionError:
            pass
        time.sleep(1)
    print("Failed to connect to API server")
|
|
|
def load_settings(): |
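    """Load indexing/settings.yaml, returning an empty dict if the file is missing."""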
|
try: |
|
with open("indexing/settings.yaml", "r") as f: |
|
return yaml.safe_load(f) or {} |
|
except FileNotFoundError: |
|
return {} |
|
|
|
def update_setting(key, value): |
|
settings = load_settings() |
|
try: |
|
settings[key] = json.loads(value) |
|
except json.JSONDecodeError: |
|
settings[key] = value |
|
|
|
try: |
|
with open("indexing/settings.yaml", "w") as f: |
|
yaml.dump(settings, f, default_flow_style=False) |
|
return f"Setting '{key}' updated successfully" |
|
except Exception as e: |
|
return f"Error updating setting '{key}': {str(e)}" |
|
|
|
def create_setting_component(key, value): |
|
with gr.Accordion(key, open=False): |
|
if isinstance(value, (dict, list)): |
|
value_str = json.dumps(value, indent=2) |
|
lines = value_str.count('\n') + 1 |
|
else: |
|
value_str = str(value) |
|
lines = 1 |
|
|
|
text_area = gr.TextArea(value=value_str, label="Value", lines=lines, max_lines=20) |
|
update_btn = gr.Button("Update", variant="primary") |
|
status = gr.Textbox(label="Status", visible=False) |
|
|
|
update_btn.click( |
|
fn=update_setting, |
|
inputs=[gr.Textbox(value=key, visible=False), text_area], |
|
outputs=[status] |
|
).then( |
|
fn=lambda: gr.update(visible=True), |
|
outputs=[status] |
|
) |
|
|
|
|
|
|
|
def get_openai_client():
    return OpenAI(
        base_url=os.getenv("LLM_API_BASE"),
        api_key=os.getenv("LLM_API_KEY"),
    )
|
|
|
async def chat_with_openai(messages, model, temperature, max_tokens, api_base): |
|
client = AsyncOpenAI( |
|
base_url=api_base, |
|
api_key=os.getenv("LLM_API_KEY") |
|
) |
|
|
|
try: |
|
response = await client.chat.completions.create( |
|
model=model, |
|
messages=messages, |
|
temperature=temperature, |
|
max_tokens=max_tokens |
|
) |
|
return response.choices[0].message.content |
|
    except Exception as e:
        logging.error(f"Error in chat_with_openai: {str(e)}")
        return f"An error occurred: {str(e)}"
|
|
|
def chat_with_llm(query, history, system_message, temperature, max_tokens, model, api_base): |
|
try: |
|
messages = [{"role": "system", "content": system_message}] |
|
for item in history: |
|
if isinstance(item, tuple) and len(item) == 2: |
|
human, ai = item |
|
messages.append({"role": "user", "content": human}) |
|
messages.append({"role": "assistant", "content": ai}) |
|
messages.append({"role": "user", "content": query}) |
|
|
|
logging.info(f"Sending chat request to {api_base} with model {model}") |
|
client = OpenAI(base_url=api_base, api_key=os.getenv("LLM_API_KEY", "dummy-key")) |
|
response = client.chat.completions.create( |
|
model=model, |
|
messages=messages, |
|
temperature=temperature, |
|
max_tokens=max_tokens |
|
) |
|
return response.choices[0].message.content |
|
except Exception as e: |
|
logging.error(f"Error in chat_with_llm: {str(e)}") |
|
logging.error(f"Attempted with model: {model}, api_base: {api_base}") |
|
raise RuntimeError(f"Chat request failed: {str(e)}") |
|
|
|
def run_graphrag_query(cli_args): |
|
try: |
|
command = ' '.join(cli_args) |
|
logging.info(f"Executing command: {command}") |
|
result = subprocess.run(cli_args, capture_output=True, text=True, check=True) |
|
return result.stdout.strip() |
|
except subprocess.CalledProcessError as e: |
|
logging.error(f"Error running GraphRAG query: {e}") |
|
logging.error(f"Command output (stdout): {e.stdout}") |
|
logging.error(f"Command output (stderr): {e.stderr}") |
|
raise RuntimeError(f"GraphRAG query failed: {e.stderr}") |
|
|
|
def parse_query_response(response: str): |
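    """Split a GraphRAG CLI response into its JSON metadata header and content.

    Returns a formatted summary (query type, execution time, tokens used)
    followed by the filtered content; falls back to the raw response if the
    expected metadata/content split is not present.
    """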
|
try: |
|
|
|
parts = response.split("\n\n", 1) |
|
if len(parts) < 2: |
|
return response |
|
|
|
metadata_str, content = parts |
|
metadata = json.loads(metadata_str) |
|
|
|
|
|
query_type = metadata.get("query_type", "Unknown") |
|
execution_time = metadata.get("execution_time", "N/A") |
|
tokens_used = metadata.get("tokens_used", "N/A") |
|
|
|
|
|
content_lines = content.split('\n') |
|
filtered_content = '\n'.join([line for line in content_lines if not line.startswith("INFO:") and not line.startswith("creating llm client")]) |
|
|
|
|
|
parsed_response = f""" |
|
Query Type: {query_type} |
|
Execution Time: {execution_time} seconds |
|
Tokens Used: {tokens_used} |
|
|
|
{filtered_content.strip()} |
|
""" |
|
return parsed_response |
|
except Exception as e: |
|
print(f"Error parsing query response: {str(e)}") |
|
return response |
|
|
|
def send_message(query_type, query, history, system_message, temperature, max_tokens, preset, community_level, response_type, custom_cli_args, selected_folder): |
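    """Handle one chat turn.

    'global' and 'local' queries are executed through the GraphRAG CLI; any
    other query type is sent directly to the configured LLM. The
    (query, response) pair is appended to the chat history.
    """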
|
try: |
|
if query_type in ["global", "local"]: |
|
cli_args = construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder) |
|
logging.info(f"Executing {query_type} search with command: {' '.join(cli_args)}") |
|
result = run_graphrag_query(cli_args) |
|
parsed_result = parse_query_response(result) |
|
logging.info(f"Parsed query result: {parsed_result}") |
|
else: |
|
llm_model = os.getenv("LLM_MODEL") |
|
api_base = os.getenv("LLM_API_BASE") |
|
logging.info(f"Executing direct chat with model: {llm_model}") |
|
|
|
try: |
|
result = chat_with_llm(query, history, system_message, temperature, max_tokens, llm_model, api_base) |
|
parsed_result = result |
|
logging.info(f"Direct chat result: {parsed_result[:100]}...") |
|
except Exception as chat_error: |
|
logging.error(f"Error in chat_with_llm: {str(chat_error)}") |
|
raise RuntimeError(f"Direct chat failed: {str(chat_error)}") |
|
|
|
history.append((query, parsed_result)) |
|
except Exception as e: |
|
error_message = f"An error occurred: {str(e)}" |
|
logging.error(error_message) |
|
logging.exception("Exception details:") |
|
history.append((query, error_message)) |
|
|
|
return history, gr.update(value=""), update_logs() |
|
|
|
def construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder): |
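    """Build the `python -m graphrag.query` argument list for the selected preset.

    Named presets map to fixed --community_level/--response_type pairs;
    'Custom Query' uses the user-supplied values plus any extra CLI arguments.
    The query string itself is appended as the final positional argument.
    """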
|
if not selected_folder: |
|
raise ValueError("No folder selected. Please select an output folder before querying.") |
|
|
|
artifacts_folder = os.path.join("./indexing/output", selected_folder, "artifacts") |
|
if not os.path.exists(artifacts_folder): |
|
raise ValueError(f"Artifacts folder not found in {artifacts_folder}") |
|
|
|
base_args = [ |
|
"python", "-m", "graphrag.query", |
|
"--data", artifacts_folder, |
|
"--method", query_type, |
|
] |
|
|
|
|
|
if preset.startswith("Default"): |
|
base_args.extend(["--community_level", "2", "--response_type", "Multiple Paragraphs"]) |
|
elif preset.startswith("Detailed"): |
|
base_args.extend(["--community_level", "4", "--response_type", "Multi-Page Report"]) |
|
elif preset.startswith("Quick"): |
|
base_args.extend(["--community_level", "1", "--response_type", "Single Paragraph"]) |
|
elif preset.startswith("Bullet"): |
|
base_args.extend(["--community_level", "2", "--response_type", "List of 3-7 Points"]) |
|
elif preset.startswith("Comprehensive"): |
|
base_args.extend(["--community_level", "5", "--response_type", "Multi-Page Report"]) |
|
elif preset.startswith("High-Level"): |
|
base_args.extend(["--community_level", "1", "--response_type", "Single Page"]) |
|
elif preset.startswith("Focused"): |
|
base_args.extend(["--community_level", "3", "--response_type", "Multiple Paragraphs"]) |
|
elif preset == "Custom Query": |
|
base_args.extend([ |
|
"--community_level", str(community_level), |
|
"--response_type", f'"{response_type}"', |
|
]) |
|
if custom_cli_args: |
|
base_args.extend(custom_cli_args.split()) |
|
|
|
|
|
base_args.append(query) |
|
|
|
return base_args |
|
|
|
|
|
|
|
|
|
|
|
|
|
def upload_file(file): |
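    """Move an uploaded file into indexing/input and refresh the file list dropdown."""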
|
if file is not None: |
|
input_dir = os.path.join("indexing", "input") |
|
os.makedirs(input_dir, exist_ok=True) |
|
|
|
|
|
original_filename = file.name |
|
|
|
|
|
destination_path = os.path.join(input_dir, os.path.basename(original_filename)) |
|
|
|
|
|
shutil.move(file.name, destination_path) |
|
|
|
logging.info(f"File uploaded and moved to: {destination_path}") |
|
status = f"File uploaded: {os.path.basename(original_filename)}" |
|
else: |
|
status = "No file uploaded" |
|
|
|
|
|
updated_file_list = [f["path"] for f in list_input_files()] |
|
|
|
return status, gr.update(choices=updated_file_list), update_logs() |
|
|
|
def list_input_files(): |
|
input_dir = os.path.join("indexing", "input") |
|
files = [] |
|
if os.path.exists(input_dir): |
|
files = os.listdir(input_dir) |
|
return [{"name": f, "path": os.path.join(input_dir, f)} for f in files] |
|
|
|
def delete_file(file_path): |
|
try: |
|
os.remove(file_path) |
|
logging.info(f"File deleted: {file_path}") |
|
status = f"File deleted: {os.path.basename(file_path)}" |
|
except Exception as e: |
|
logging.error(f"Error deleting file: {str(e)}") |
|
status = f"Error deleting file: {str(e)}" |
|
|
|
|
|
updated_file_list = [f["path"] for f in list_input_files()] |
|
|
|
return status, gr.update(choices=updated_file_list), update_logs() |
|
|
|
def read_file_content(file_path): |
|
try: |
|
if file_path.endswith('.parquet'): |
|
df = pd.read_parquet(file_path) |
|
|
|
|
|
info = f"Parquet File: {os.path.basename(file_path)}\n" |
|
info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n" |
|
info += "Column Names:\n" + "\n".join(df.columns) + "\n\n" |
|
|
|
|
|
info += "First 5 rows:\n" |
|
info += df.head().to_string() + "\n\n" |
|
|
|
|
|
info += "Basic Statistics:\n" |
|
info += df.describe().to_string() |
|
|
|
return info |
|
else: |
|
with open(file_path, 'r', encoding='utf-8', errors='replace') as file: |
|
content = file.read() |
|
return content |
|
except Exception as e: |
|
logging.error(f"Error reading file: {str(e)}") |
|
return f"Error reading file: {str(e)}" |
|
|
|
def save_file_content(file_path, content): |
|
try: |
|
with open(file_path, 'w') as file: |
|
file.write(content) |
|
logging.info(f"File saved: {file_path}") |
|
status = f"File saved: {os.path.basename(file_path)}" |
|
except Exception as e: |
|
logging.error(f"Error saving file: {str(e)}") |
|
status = f"Error saving file: {str(e)}" |
|
return status, update_logs() |
|
|
|
def manage_data(): |
|
db = lancedb.connect("./indexing/lancedb") |
|
tables = db.table_names() |
|
table_info = "" |
|
if tables: |
|
        table = db.open_table(tables[0])
|
table_info = f"Table: {tables[0]}\nSchema: {table.schema}" |
|
|
|
input_files = list_input_files() |
|
|
|
return { |
|
"database_info": f"Tables: {', '.join(tables)}\n\n{table_info}", |
|
"input_files": input_files |
|
} |
|
|
|
|
|
def find_latest_graph_file(root_dir): |
|
pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml") |
|
graph_files = glob.glob(pattern) |
|
if not graph_files: |
|
|
|
output_dir = os.path.join(root_dir, "output") |
|
run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"] |
|
if run_dirs: |
|
latest_run = max(run_dirs) |
|
pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml") |
|
graph_files = glob.glob(pattern) |
|
|
|
if not graph_files: |
|
return None |
|
|
|
|
|
latest_file = max(graph_files, key=os.path.getmtime) |
|
return latest_file |
|
|
|
def update_visualization(folder_name, file_name, layout_type, node_size, edge_width, node_color_attribute, color_scheme, show_labels, label_size): |
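    """Render the selected GraphML file from a run's artifacts folder as a Plotly figure.

    Supports 3D/2D spring and circular layouts; 2D layouts are drawn as
    Scatter3d traces with z fixed at 0.
    """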
|
root_dir = "./indexing" |
|
if not folder_name or not file_name: |
|
return None, "Please select a folder and a GraphML file." |
|
file_name = file_name.split("] ")[1] if "]" in file_name else file_name |
|
graph_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) |
|
if not graph_path.endswith('.graphml'): |
|
return None, "Please select a GraphML file for visualization." |
|
try: |
|
|
|
graph = nx.read_graphml(graph_path) |
|
|
|
|
|
if layout_type == "3D Spring": |
|
pos = nx.spring_layout(graph, dim=3, seed=42, k=0.5) |
|
elif layout_type == "2D Spring": |
|
pos = nx.spring_layout(graph, dim=2, seed=42, k=0.5) |
|
else: |
|
pos = nx.circular_layout(graph) |
|
|
|
|
|
if layout_type == "3D Spring": |
|
x_nodes = [pos[node][0] for node in graph.nodes()] |
|
y_nodes = [pos[node][1] for node in graph.nodes()] |
|
z_nodes = [pos[node][2] for node in graph.nodes()] |
|
else: |
|
x_nodes = [pos[node][0] for node in graph.nodes()] |
|
y_nodes = [pos[node][1] for node in graph.nodes()] |
|
z_nodes = [0] * len(graph.nodes()) |
|
|
|
|
|
x_edges, y_edges, z_edges = [], [], [] |
|
for edge in graph.edges(): |
|
x_edges.extend([pos[edge[0]][0], pos[edge[1]][0], None]) |
|
y_edges.extend([pos[edge[0]][1], pos[edge[1]][1], None]) |
|
if layout_type == "3D Spring": |
|
z_edges.extend([pos[edge[0]][2], pos[edge[1]][2], None]) |
|
else: |
|
z_edges.extend([0, 0, None]) |
|
|
|
|
|
if node_color_attribute == "Degree": |
|
node_colors = [graph.degree(node) for node in graph.nodes()] |
|
else: |
|
node_colors = [random.random() for _ in graph.nodes()] |
|
        node_colors = np.array(node_colors, dtype=float)
        # Normalize to [0, 1] to match the colorbar's Low/High ticks, guarding
        # against a zero range (e.g. every node has the same degree).
        color_range = node_colors.max() - node_colors.min() if node_colors.size else 0
        if color_range > 0:
            node_colors = (node_colors - node_colors.min()) / color_range
|
|
|
|
|
edge_trace = go.Scatter3d( |
|
x=x_edges, y=y_edges, z=z_edges, |
|
mode='lines', |
|
line=dict(color='lightgray', width=edge_width), |
|
hoverinfo='none' |
|
) |
|
|
|
|
|
node_trace = go.Scatter3d( |
|
x=x_nodes, y=y_nodes, z=z_nodes, |
|
mode='markers+text' if show_labels else 'markers', |
|
marker=dict( |
|
size=node_size, |
|
color=node_colors, |
|
colorscale=color_scheme, |
|
colorbar=dict( |
|
title='Node Degree' if node_color_attribute == "Degree" else "Random Value", |
|
thickness=10, |
|
x=1.1, |
|
tickvals=[0, 1], |
|
ticktext=['Low', 'High'] |
|
), |
|
line=dict(width=1) |
|
), |
|
text=[node for node in graph.nodes()], |
|
textposition="top center", |
|
textfont=dict(size=label_size, color='black'), |
|
hoverinfo='text' |
|
) |
|
|
|
|
|
fig = go.Figure(data=[edge_trace, node_trace]) |
|
|
|
|
|
fig.update_layout( |
|
title=f'{layout_type} Graph Visualization: {os.path.basename(graph_path)}', |
|
showlegend=False, |
|
scene=dict( |
|
xaxis=dict(showbackground=False, showticklabels=False, title=''), |
|
yaxis=dict(showbackground=False, showticklabels=False, title=''), |
|
zaxis=dict(showbackground=False, showticklabels=False, title='') |
|
), |
|
margin=dict(l=0, r=0, b=0, t=40), |
|
annotations=[ |
|
dict( |
|
showarrow=False, |
|
text=f"Interactive {layout_type} visualization of GraphML data", |
|
xref="paper", |
|
yref="paper", |
|
x=0, |
|
y=0 |
|
) |
|
], |
|
autosize=True |
|
) |
|
|
|
        fig.update_layout(height=600)
|
return fig, f"Graph visualization generated successfully. Using file: {graph_path}" |
|
except Exception as e: |
|
return go.Figure(), f"Error visualizing graph: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
def update_logs(): |
|
logs = [] |
|
while not log_queue.empty(): |
|
logs.append(log_queue.get()) |
|
return "\n".join(logs) |
|
|
|
|
|
|
|
def fetch_models(base_url, api_key, service_type): |
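    """Fetch the available model names from the given endpoint.

    Ollama servers are queried via GET {base_url}/tags; OpenAI-compatible
    servers via GET {base_url}/models with bearer auth. Always returns a
    non-empty list: failures are reported as placeholder strings so the
    model dropdowns remain usable.
    """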
|
try: |
|
if service_type.lower() == "ollama": |
|
response = requests.get(f"{base_url}/tags", timeout=10) |
|
else: |
|
headers = { |
|
"Authorization": f"Bearer {api_key}", |
|
"Content-Type": "application/json" |
|
} |
|
response = requests.get(f"{base_url}/models", headers=headers, timeout=10) |
|
|
|
logging.info(f"Raw API response: {response.text}") |
|
|
|
if response.status_code == 200: |
|
data = response.json() |
|
if service_type.lower() == "ollama": |
|
models = [model.get('name', '') for model in data.get('models', data) if isinstance(model, dict)] |
|
else: |
|
models = [model.get('id', '') for model in data.get('data', []) if isinstance(model, dict)] |
|
|
|
models = [model for model in models if model] |
|
|
|
if not models: |
|
logging.warning(f"No models found in {service_type} API response") |
|
return ["No models available"] |
|
|
|
logging.info(f"Successfully fetched {service_type} models: {models}") |
|
return models |
|
else: |
|
logging.error(f"Error fetching {service_type} models. Status code: {response.status_code}, Response: {response.text}") |
|
return ["Error fetching models"] |
|
except requests.RequestException as e: |
|
logging.error(f"Exception while fetching {service_type} models: {str(e)}") |
|
return ["Error: Connection failed"] |
|
except Exception as e: |
|
logging.error(f"Unexpected error in fetch_models: {str(e)}") |
|
return ["Error: Unexpected issue"] |
|
|
|
def update_model_choices(base_url, api_key, service_type, settings_key):
    models = fetch_models(base_url, api_key, service_type)

    if not models:
        logging.warning(f"No models fetched for {service_type}.")

    # The 'llm' section stores the model at its top level, while the
    # 'embeddings' section nests it under an inner 'llm' block.
    section = load_settings().get(settings_key, {})
    current_model = section.get('model') or section.get('llm', {}).get('model')

    if current_model and current_model not in models:
        models.append(current_model)

    return gr.update(choices=models, value=current_model if current_model in models else (models[0] if models else None))
|
|
|
def update_llm_model_choices(base_url, api_key, service_type): |
|
return update_model_choices(base_url, api_key, service_type, 'llm') |
|
|
|
def update_embeddings_model_choices(base_url, api_key, service_type): |
|
return update_model_choices(base_url, api_key, service_type, 'embeddings') |
|
|
|
|
|
|
|
|
|
def update_llm_settings(llm_model, embeddings_model, context_window, system_message, temperature, max_tokens, |
|
llm_api_base, llm_api_key, |
|
embeddings_api_base, embeddings_api_key, embeddings_service_type): |
|
try: |
|
|
|
settings = load_settings() |
|
        settings.setdefault('llm', {}).update({
|
"type": "openai", |
|
"model": llm_model, |
|
"api_base": llm_api_base, |
|
"api_key": "${GRAPHRAG_API_KEY}", |
|
"temperature": temperature, |
|
"max_tokens": max_tokens, |
|
"provider": "openai_chat" |
|
}) |
|
        settings.setdefault('embeddings', {}).setdefault('llm', {}).update({
|
"type": "openai_embedding", |
|
"model": embeddings_model, |
|
"api_base": embeddings_api_base, |
|
"api_key": "${GRAPHRAG_API_KEY}", |
|
"provider": embeddings_service_type |
|
}) |
|
|
|
with open("indexing/settings.yaml", 'w') as f: |
|
yaml.dump(settings, f, default_flow_style=False) |
|
|
|
|
|
update_env_file("LLM_API_BASE", llm_api_base) |
|
update_env_file("LLM_API_KEY", llm_api_key) |
|
update_env_file("LLM_MODEL", llm_model) |
|
update_env_file("EMBEDDINGS_API_BASE", embeddings_api_base) |
|
update_env_file("EMBEDDINGS_API_KEY", embeddings_api_key) |
|
update_env_file("EMBEDDINGS_MODEL", embeddings_model) |
|
update_env_file("CONTEXT_WINDOW", str(context_window)) |
|
update_env_file("SYSTEM_MESSAGE", system_message) |
|
update_env_file("TEMPERATURE", str(temperature)) |
|
update_env_file("MAX_TOKENS", str(max_tokens)) |
|
update_env_file("LLM_SERVICE_TYPE", "openai_chat") |
|
update_env_file("EMBEDDINGS_SERVICE_TYPE", embeddings_service_type) |
|
|
|
|
|
load_dotenv(override=True) |
|
|
|
return "LLM and embeddings settings updated successfully in both settings.yaml and .env files." |
|
except Exception as e: |
|
return f"Error updating LLM and embeddings settings: {str(e)}" |
|
|
|
def update_env_file(key, value): |
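    """Set or append a KEY=value line in indexing/.env, preserving all other lines."""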
|
env_path = 'indexing/.env' |
|
with open(env_path, 'r') as file: |
|
lines = file.readlines() |
|
|
|
updated = False |
|
for i, line in enumerate(lines): |
|
if line.startswith(f"{key}="): |
|
lines[i] = f"{key}={value}\n" |
|
updated = True |
|
break |
|
|
|
if not updated: |
|
lines.append(f"{key}={value}\n") |
|
|
|
with open(env_path, 'w') as file: |
|
file.writelines(lines) |
|
|
|
custom_css = """ |
|
html, body { |
|
margin: 0; |
|
padding: 0; |
|
height: 100vh; |
|
overflow: hidden; |
|
} |
|
|
|
.gradio-container { |
|
margin: 0 !important; |
|
padding: 0 !important; |
|
width: 100vw !important; |
|
max-width: 100vw !important; |
|
height: 100vh !important; |
|
max-height: 100vh !important; |
|
overflow: auto; |
|
display: flex; |
|
flex-direction: column; |
|
} |
|
|
|
#main-container { |
|
flex: 1; |
|
display: flex; |
|
overflow: hidden; |
|
} |
|
|
|
#left-column, #right-column { |
|
height: 100%; |
|
overflow-y: auto; |
|
padding: 10px; |
|
} |
|
|
|
#left-column { |
|
flex: 1; |
|
} |
|
|
|
#right-column { |
|
flex: 2; |
|
display: flex; |
|
flex-direction: column; |
|
} |
|
|
|
#chat-container { |
|
flex: 0 0 auto; /* Don't allow this to grow */ |
|
height: 100%; |
|
display: flex; |
|
flex-direction: column; |
|
overflow: hidden; |
|
border: 1px solid var(--color-accent); |
|
border-radius: 8px; |
|
padding: 10px; |
|
overflow-y: auto; |
|
} |
|
|
|
#chatbot { |
|
overflow-y: hidden; |
|
height: 100%; |
|
} |
|
|
|
#chat-input-row { |
|
margin-top: 10px; |
|
} |
|
|
|
#visualization-plot { |
|
width: 100%; |
|
aspect-ratio: 1 / 1; |
|
max-height: 600px; /* Adjust this value as needed */ |
|
} |
|
|
|
|
|
|
/* Chat input styling */ |
|
#chat-input-row { |
|
display: flex; |
|
flex-direction: column; |
|
} |
|
|
|
#chat-input-row > div { |
|
width: 100% !important; |
|
} |
|
|
|
#chat-input-row input[type="text"] { |
|
width: 100% !important; |
|
} |
|
|
|
/* Adjust padding for all containers */ |
|
.gr-box, .gr-form, .gr-panel { |
|
padding: 10px !important; |
|
} |
|
|
|
/* Ensure all textboxes and textareas have full height */ |
|
.gr-textbox, .gr-textarea { |
|
height: auto !important; |
|
min-height: 100px !important; |
|
} |
|
|
|
/* Ensure all dropdowns have full width */ |
|
.gr-dropdown { |
|
width: 100% !important; |
|
} |
|
|
|
:root { |
|
--color-background: #2C3639; |
|
--color-foreground: #3F4E4F; |
|
--color-accent: #A27B5C; |
|
--color-text: #DCD7C9; |
|
} |
|
|
|
body, .gradio-container { |
|
background-color: var(--color-background); |
|
color: var(--color-text); |
|
} |
|
|
|
.gr-button { |
|
background-color: var(--color-accent); |
|
color: var(--color-text); |
|
} |
|
|
|
.gr-input, .gr-textarea, .gr-dropdown { |
|
background-color: var(--color-foreground); |
|
color: var(--color-text); |
|
border: 1px solid var(--color-accent); |
|
} |
|
|
|
.gr-panel { |
|
background-color: var(--color-foreground); |
|
border: 1px solid var(--color-accent); |
|
} |
|
|
|
.gr-box { |
|
border-radius: 8px; |
|
margin-bottom: 10px; |
|
background-color: var(--color-foreground); |
|
} |
|
|
|
.gr-padded { |
|
padding: 10px; |
|
} |
|
|
|
.gr-form { |
|
background-color: var(--color-foreground); |
|
} |
|
|
|
.gr-input-label, .gr-radio-label { |
|
color: var(--color-text); |
|
} |
|
|
|
.gr-checkbox-label { |
|
color: var(--color-text); |
|
} |
|
|
|
.gr-markdown { |
|
color: var(--color-text); |
|
} |
|
|
|
.gr-accordion { |
|
background-color: var(--color-foreground); |
|
border: 1px solid var(--color-accent); |
|
} |
|
|
|
.gr-accordion-header { |
|
background-color: var(--color-accent); |
|
color: var(--color-text); |
|
} |
|
|
|
#visualization-container { |
|
display: flex; |
|
flex-direction: column; |
|
border: 2px solid var(--color-accent); |
|
border-radius: 8px; |
|
margin-top: 20px; |
|
padding: 10px; |
|
background-color: var(--color-foreground); |
|
height: calc(100vh - 300px); /* Adjust this value as needed */ |
|
} |
|
|
|
#visualization-plot { |
|
width: 100%; |
|
height: 100%; |
|
} |
|
|
|
#vis-controls-row { |
|
display: flex; |
|
justify-content: space-between; |
|
align-items: center; |
|
margin-top: 10px; |
|
} |
|
|
|
#vis-controls-row > * { |
|
flex: 1; |
|
margin: 0 5px; |
|
} |
|
|
|
#vis-status { |
|
margin-top: 10px; |
|
} |
|
|
|
#log-container { |
|
background-color: var(--color-foreground); |
|
border: 1px solid var(--color-accent); |
|
border-radius: 8px; |
|
padding: 10px; |
|
margin-top: 20px; |
|
|
overflow-y: auto; |
|
} |
|
|
|
.setting-accordion .label-wrap { |
|
cursor: pointer; |
|
} |
|
|
|
.setting-accordion .icon { |
|
transition: transform 0.3s ease; |
|
} |
|
|
|
.setting-accordion[open] .icon { |
|
transform: rotate(90deg); |
|
} |
|
|
|
.gr-form.gr-box { |
|
border: none !important; |
|
background: none !important; |
|
} |
|
|
|
.model-params { |
|
border-top: 1px solid var(--color-accent); |
|
margin-top: 10px; |
|
padding-top: 10px; |
|
} |
|
""" |
|
|
|
def list_output_files(root_dir): |
|
output_dir = os.path.join(root_dir, "output") |
|
files = [] |
|
for root, _, filenames in os.walk(output_dir): |
|
for filename in filenames: |
|
files.append(os.path.join(root, filename)) |
|
return files |
|
|
|
def update_file_list(): |
|
files = list_input_files() |
|
return gr.update(choices=[f["path"] for f in files]) |
|
|
|
def update_file_content(file_path): |
|
if not file_path: |
|
return "" |
|
try: |
|
with open(file_path, 'r', encoding='utf-8') as file: |
|
content = file.read() |
|
return content |
|
except Exception as e: |
|
logging.error(f"Error reading file: {str(e)}") |
|
return f"Error reading file: {str(e)}" |
|
|
|
def list_output_folders(root_dir): |
|
output_dir = os.path.join(root_dir, "output") |
|
folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))] |
|
return sorted(folders, reverse=True) |
|
|
|
def list_folder_contents(folder_path): |
|
contents = [] |
|
for item in os.listdir(folder_path): |
|
item_path = os.path.join(folder_path, item) |
|
if os.path.isdir(item_path): |
|
contents.append(f"[DIR] {item}") |
|
else: |
|
_, ext = os.path.splitext(item) |
|
contents.append(f"[{ext[1:].upper()}] {item}") |
|
return contents |
|
|
|
def update_output_folder_list(): |
|
root_dir = "./" |
|
folders = list_output_folders(root_dir) |
|
return gr.update(choices=folders, value=folders[0] if folders else None) |
|
|
|
def update_folder_content_list(folder_name): |
|
root_dir = "./" |
|
if not folder_name: |
|
return gr.update(choices=[]) |
|
contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, "artifacts")) |
|
return gr.update(choices=contents) |
|
|
|
def handle_content_selection(folder_name, selected_item): |
|
root_dir = "./" |
|
if isinstance(selected_item, list) and selected_item: |
|
selected_item = selected_item[0] |
|
|
|
if isinstance(selected_item, str) and selected_item.startswith("[DIR]"): |
|
dir_name = selected_item[6:] |
|
sub_contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, dir_name)) |
|
return gr.update(choices=sub_contents), "", "" |
|
elif isinstance(selected_item, str): |
|
file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item |
|
file_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) |
|
file_size = os.path.getsize(file_path) |
|
file_type = os.path.splitext(file_name)[1] |
|
file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}" |
|
content = read_file_content(file_path) |
|
return gr.update(), file_info, content |
|
else: |
|
return gr.update(), "", "" |
|
|
|
def initialize_selected_folder(folder_name): |
|
root_dir = "./" |
|
if not folder_name: |
|
return "Please select a folder first.", gr.update(choices=[]) |
|
folder_path = os.path.join(root_dir, "output", folder_name, "artifacts") |
|
if not os.path.exists(folder_path): |
|
return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[]) |
|
contents = list_folder_contents(folder_path) |
|
return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents) |
|
|
|
|
|
settings = load_settings() |
|
default_model = settings.get('llm', {}).get('model')
|
cli_args = gr.State({}) |
|
stop_indexing = threading.Event() |
|
indexing_thread = None |
|
|
|
def start_indexing(*args):
    global indexing_thread, stop_indexing
    # run_indexing is a generator that Gradio consumes through the .then()
    # chain registered on the Run button; spawning it in a thread here would
    # only create a generator object that is never iterated. We just reset
    # the stop flag and toggle the buttons.
    stop_indexing = threading.Event()
    indexing_thread = None
    return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False)
|
|
|
def stop_indexing_process(): |
|
global indexing_thread |
|
logging.info("Stop indexing requested") |
|
stop_indexing.set() |
|
if indexing_thread and indexing_thread.is_alive(): |
|
logging.info("Waiting for indexing thread to finish") |
|
indexing_thread.join(timeout=10) |
|
logging.info("Indexing thread finished" if not indexing_thread.is_alive() else "Indexing thread did not finish within timeout") |
|
indexing_thread = None |
|
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True) |
|
|
|
def refresh_indexing(): |
|
global indexing_thread, stop_indexing |
|
if indexing_thread and indexing_thread.is_alive(): |
|
logging.info("Cannot refresh: Indexing is still running") |
|
return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), "Cannot refresh: Indexing is still running" |
|
else: |
|
stop_indexing = threading.Event() |
|
indexing_thread = None |
|
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True), "Indexing process refreshed. You can start indexing again." |
|
|
|
|
|
|
|
def run_indexing(root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_args): |
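    """Run `python -m graphrag.index` as a subprocess and stream its output.

    Generator: each yielded tuple is (log text, progress message, run/stop/
    refresh button updates, iterations completed). Setting the module-level
    stop_indexing event terminates the subprocess cooperatively.
    """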
|
cmd = ["python", "-m", "graphrag.index", "--root", "./indexing"] |
|
|
|
|
|
if custom_args: |
|
cmd.extend(custom_args.split()) |
|
|
|
logging.info(f"Executing command: {' '.join(cmd)}") |
|
|
|
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, encoding='utf-8', universal_newlines=True) |
|
|
|
|
|
    output = []
    iterations_completed = 0
|
|
|
while True: |
|
if stop_indexing.is_set(): |
|
process.terminate() |
|
process.wait(timeout=5) |
|
if process.poll() is None: |
|
process.kill() |
|
return ("\n".join(output + ["Indexing stopped by user."]), |
|
"Indexing stopped.", |
|
100, |
|
gr.update(interactive=True), |
|
gr.update(interactive=False), |
|
gr.update(interactive=True), |
|
str(iterations_completed)) |
|
|
|
try: |
|
line = process.stdout.readline() |
|
if not line and process.poll() is not None: |
|
break |
|
|
|
if line: |
|
line = line.strip() |
|
output.append(line) |
|
|
|
if "Processing file" in line: |
|
progress_value += 1 |
|
iterations_completed += 1 |
|
elif "Indexing completed" in line: |
|
progress_value = 100 |
|
elif "ERROR" in line: |
|
line = f"🚨 ERROR: {line}" |
|
|
|
yield ("\n".join(output), |
|
line, |
|
progress_value, |
|
gr.update(interactive=False), |
|
gr.update(interactive=True), |
|
gr.update(interactive=False), |
|
str(iterations_completed)) |
|
except Exception as e: |
|
logging.error(f"Error during indexing: {str(e)}") |
|
return ("\n".join(output + [f"Error: {str(e)}"]), |
|
"Error occurred during indexing.", |
|
100, |
|
gr.update(interactive=True), |
|
gr.update(interactive=False), |
|
gr.update(interactive=True), |
|
str(iterations_completed)) |
|
|
|
if process.returncode != 0 and not stop_indexing.is_set(): |
|
final_output = "\n".join(output + [f"Error: Process exited with return code {process.returncode}"]) |
|
final_progress = "Indexing failed. Check output for details." |
|
else: |
|
final_output = "\n".join(output) |
|
final_progress = "Indexing completed successfully!" |
|
|
|
    yield (final_output,
           final_progress,
           gr.update(interactive=True),
           gr.update(interactive=False),
           gr.update(interactive=True),
           str(iterations_completed))
|
|
|
global_vector_store_wrapper = None |
|
|
|
def create_gradio_interface(): |
|
global global_vector_store_wrapper |
|
llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder = initialize_models() |
|
settings = load_settings() |
|
|
|
|
|
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False, visible=False) |
|
|
|
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo: |
|
gr.Markdown("# GraphRAG Local UI", elem_id="title") |
|
|
|
with gr.Row(elem_id="main-container"): |
|
with gr.Column(scale=1, elem_id="left-column"): |
|
with gr.Tabs(): |
|
with gr.TabItem("Data Management"): |
|
with gr.Accordion("File Upload (.txt)", open=True): |
|
file_upload = gr.File(label="Upload .txt File", file_types=[".txt"]) |
|
upload_btn = gr.Button("Upload File", variant="primary") |
|
upload_output = gr.Textbox(label="Upload Status", visible=False) |
|
|
|
with gr.Accordion("File Management", open=True): |
|
file_list = gr.Dropdown(label="Select File", choices=[], interactive=True) |
|
refresh_btn = gr.Button("Refresh File List", variant="secondary") |
|
|
|
file_content = gr.TextArea(label="File Content", lines=10) |
|
|
|
with gr.Row(): |
|
delete_btn = gr.Button("Delete Selected File", variant="stop") |
|
save_btn = gr.Button("Save Changes", variant="primary") |
|
|
|
operation_status = gr.Textbox(label="Operation Status", visible=False) |
|
|
|
|
|
|
|
with gr.TabItem("Indexing"): |
|
root_dir = gr.Textbox(label="Root Directory", value="./") |
|
config_file = gr.File(label="Config File (optional)") |
|
with gr.Row(): |
|
verbose = gr.Checkbox(label="Verbose", value=True) |
|
nocache = gr.Checkbox(label="No Cache", value=True) |
|
with gr.Row(): |
|
resume = gr.Textbox(label="Resume Timestamp (optional)") |
|
reporter = gr.Dropdown(label="Reporter", choices=["rich", "print", "none"], value=None) |
|
with gr.Row(): |
|
emit_formats = gr.CheckboxGroup(label="Emit Formats", choices=["json", "csv", "parquet"], value=None) |
|
with gr.Row(): |
|
run_index_button = gr.Button("Run Indexing") |
|
stop_index_button = gr.Button("Stop Indexing", variant="stop") |
|
refresh_index_button = gr.Button("Refresh Indexing", variant="secondary") |
|
|
|
with gr.Accordion("Custom CLI Arguments", open=True): |
|
custom_cli_args = gr.Textbox( |
|
label="Custom CLI Arguments", |
|
placeholder="--arg1 value1 --arg2 value2", |
|
lines=3 |
|
) |
|
cli_guide = gr.Markdown( |
|
textwrap.dedent(""" |
|
### CLI Argument Key Guide: |
|
- `--root <path>`: Set the root directory for the project |
|
- `--config <path>`: Specify a custom configuration file |
|
- `--verbose`: Enable verbose output |
|
- `--nocache`: Disable caching |
|
- `--resume <timestamp>`: Resume from a specific timestamp |
|
- `--reporter <type>`: Set the reporter type (rich, print, none) |
|
- `--emit <formats>`: Specify output formats (json, csv, parquet) |
|
|
|
Example: `--verbose --nocache --emit json,csv` |
|
""") |
|
) |
|
|
|
index_output = gr.Textbox(label="Indexing Output", lines=20, max_lines=30) |
|
index_progress = gr.Textbox(label="Indexing Progress", lines=3) |
|
iterations_completed = gr.Textbox(label="Iterations Completed", value="0") |
|
refresh_status = gr.Textbox(label="Refresh Status", visible=True) |
|
|
|
run_index_button.click( |
|
fn=start_indexing, |
|
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], |
|
outputs=[run_index_button, stop_index_button, refresh_index_button] |
|
).then( |
|
fn=run_indexing, |
|
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], |
|
outputs=[index_output, index_progress, run_index_button, stop_index_button, refresh_index_button, iterations_completed] |
|
) |
|
|
|
stop_index_button.click( |
|
fn=stop_indexing_process, |
|
outputs=[run_index_button, stop_index_button, refresh_index_button] |
|
) |
|
|
|
refresh_index_button.click( |
|
fn=refresh_indexing, |
|
outputs=[run_index_button, stop_index_button, refresh_index_button, refresh_status] |
|
) |
|
|
|
with gr.TabItem("Indexing Outputs/Visuals"): |
|
output_folder_list = gr.Dropdown(label="Select Output Folder (Select GraphML File to Visualize)", choices=list_output_folders("./indexing"), interactive=True) |
|
refresh_folder_btn = gr.Button("Refresh Folder List", variant="secondary") |
|
initialize_folder_btn = gr.Button("Initialize Selected Folder", variant="primary") |
|
folder_content_list = gr.Dropdown(label="Select File or Directory", choices=[], interactive=True) |
|
file_info = gr.Textbox(label="File Information", interactive=False) |
|
output_content = gr.TextArea(label="File Content", lines=20, interactive=False) |
|
initialization_status = gr.Textbox(label="Initialization Status") |
|
|
|
with gr.TabItem("LLM Settings"): |
|
llm_base_url = gr.Textbox(label="LLM API Base URL", value=os.getenv("LLM_API_BASE")) |
|
llm_api_key = gr.Textbox(label="LLM API Key", value=os.getenv("LLM_API_KEY"), type="password") |
|
llm_service_type = gr.Radio( |
|
label="LLM Service Type", |
|
choices=["openai", "ollama"], |
|
value="openai", |
|
visible=False |
|
) |
|
|
|
llm_model_dropdown = gr.Dropdown( |
|
label="LLM Model", |
|
choices=[], |
|
value=settings['llm'].get('model'), |
|
allow_custom_value=True |
|
) |
|
refresh_llm_models_btn = gr.Button("Refresh LLM Models", variant="secondary") |
|
|
|
embeddings_base_url = gr.Textbox(label="Embeddings API Base URL", value=os.getenv("EMBEDDINGS_API_BASE")) |
|
embeddings_api_key = gr.Textbox(label="Embeddings API Key", value=os.getenv("EMBEDDINGS_API_KEY"), type="password") |
|
embeddings_service_type = gr.Radio( |
|
label="Embeddings Service Type", |
|
choices=["openai", "ollama"], |
|
                            value=os.getenv("EMBEDDINGS_SERVICE_TYPE", "openai"),
|
visible=False, |
|
) |
|
|
|
embeddings_model_dropdown = gr.Dropdown( |
|
label="Embeddings Model", |
|
choices=[], |
|
value=settings.get('embeddings', {}).get('llm', {}).get('model'), |
|
allow_custom_value=True |
|
) |
|
refresh_embeddings_models_btn = gr.Button("Refresh Embedding Models", variant="secondary") |
|
system_message = gr.Textbox( |
|
lines=5, |
|
label="System Message", |
|
value=os.getenv("SYSTEM_MESSAGE", "You are a helpful AI assistant.") |
|
) |
|
context_window = gr.Slider( |
|
label="Context Window", |
|
minimum=512, |
|
maximum=32768, |
|
step=512, |
|
value=int(os.getenv("CONTEXT_WINDOW", 4096)) |
|
) |
|
temperature = gr.Slider( |
|
label="Temperature", |
|
minimum=0.0, |
|
maximum=2.0, |
|
step=0.1, |
|
                            value=float(settings.get('llm', {}).get('temperature', 0.5))
|
) |
|
max_tokens = gr.Slider( |
|
label="Max Tokens", |
|
minimum=1, |
|
maximum=8192, |
|
step=1, |
|
                            value=int(settings.get('llm', {}).get('max_tokens', 1024))
|
) |
|
update_settings_btn = gr.Button("Update LLM Settings", variant="primary") |
|
llm_settings_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
llm_base_url.change( |
|
fn=update_model_choices, |
|
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], |
|
outputs=llm_model_dropdown |
|
) |
|
|
|
embeddings_service_type.change( |
|
fn=update_embeddings_model_choices, |
|
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type], |
|
outputs=embeddings_model_dropdown |
|
) |
|
|
|
embeddings_base_url.change( |
|
fn=update_model_choices, |
|
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], |
|
outputs=embeddings_model_dropdown |
|
) |
|
|
|
update_settings_btn.click( |
|
fn=update_llm_settings, |
|
inputs=[ |
|
llm_model_dropdown, |
|
embeddings_model_dropdown, |
|
context_window, |
|
system_message, |
|
temperature, |
|
max_tokens, |
|
llm_base_url, |
|
llm_api_key, |
|
embeddings_base_url, |
|
embeddings_api_key, |
|
embeddings_service_type |
|
], |
|
outputs=[llm_settings_status] |
|
) |
|
|
|
|
|
refresh_llm_models_btn.click( |
|
fn=update_model_choices, |
|
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], |
|
outputs=[llm_model_dropdown] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
refresh_embeddings_models_btn.click( |
|
fn=update_model_choices, |
|
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], |
|
outputs=[embeddings_model_dropdown] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
with gr.TabItem("YAML Settings"): |
|
settings = load_settings() |
|
with gr.Group(): |
|
for key, value in settings.items(): |
|
if key != 'llm': |
|
create_setting_component(key, value) |
|
|
|
with gr.Group(elem_id="log-container"): |
|
gr.Markdown("### Logs") |
|
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False) |
|
|
|
with gr.Column(scale=2, elem_id="right-column"): |
|
with gr.Group(elem_id="chat-container"): |
|
chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot") |
|
with gr.Row(elem_id="chat-input-row"): |
|
with gr.Column(scale=1): |
|
query_input = gr.Textbox( |
|
label="Input", |
|
placeholder="Enter your query here...", |
|
elem_id="query-input" |
|
) |
|
query_btn = gr.Button("Send Query", variant="primary") |
|
|
|
with gr.Accordion("Query Parameters", open=True): |
|
query_type = gr.Radio( |
|
["global", "local", "direct"], |
|
label="Query Type", |
|
value="global", |
|
info="Global: community-based search, Local: entity-based search, Direct: LLM chat" |
|
) |
|
preset_dropdown = gr.Dropdown( |
|
label="Preset Query Options", |
|
choices=[ |
|
"Default Global Search", |
|
"Default Local Search", |
|
"Detailed Global Analysis", |
|
"Detailed Local Analysis", |
|
"Quick Global Summary", |
|
"Quick Local Summary", |
|
"Global Bullet Points", |
|
"Local Bullet Points", |
|
"Comprehensive Global Report", |
|
"Comprehensive Local Report", |
|
"High-Level Global Overview", |
|
"High-Level Local Overview", |
|
"Focused Global Insight", |
|
"Focused Local Insight", |
|
"Custom Query" |
|
], |
|
value="Default Global Search", |
|
info="Select a preset or choose 'Custom Query' for manual configuration" |
|
) |
|
selected_folder = gr.Dropdown( |
|
label="Select Index Folder to Chat With", |
|
choices=list_output_folders("./indexing"), |
|
value=None, |
|
interactive=True |
|
) |
|
refresh_folder_btn = gr.Button("Refresh Folders", variant="secondary") |
|
clear_chat_btn = gr.Button("Clear Chat", variant="secondary") |
|
|
|
with gr.Group(visible=False) as custom_options: |
|
community_level = gr.Slider( |
|
label="Community Level", |
|
minimum=1, |
|
maximum=10, |
|
value=2, |
|
step=1, |
|
info="Higher values use reports on smaller communities" |
|
) |
|
response_type = gr.Dropdown( |
|
label="Response Type", |
|
choices=[ |
|
"Multiple Paragraphs", |
|
"Single Paragraph", |
|
"Single Sentence", |
|
"List of 3-7 Points", |
|
"Single Page", |
|
"Multi-Page Report" |
|
], |
|
value="Multiple Paragraphs", |
|
info="Specify the desired format of the response" |
|
) |
|
custom_cli_args = gr.Textbox( |
|
label="Custom CLI Arguments", |
|
placeholder="--arg1 value1 --arg2 value2", |
|
info="Additional CLI arguments for advanced users" |
|
) |
|
|
|
def update_custom_options(preset): |
|
if preset == "Custom Query": |
|
return gr.update(visible=True) |
|
else: |
|
return gr.update(visible=False) |
|
|
|
preset_dropdown.change(fn=update_custom_options, inputs=[preset_dropdown], outputs=[custom_options]) |
|
|
|
|
|
|
|
|
|
with gr.Group(elem_id="visualization-container"): |
|
vis_output = gr.Plot(label="Graph Visualization", elem_id="visualization-plot") |
|
with gr.Row(elem_id="vis-controls-row"): |
|
vis_btn = gr.Button("Visualize Graph", variant="secondary") |
|
|
|
|
|
with gr.Accordion("Visualization Settings", open=False): |
|
layout_type = gr.Dropdown(["3D Spring", "2D Spring", "Circular"], label="Layout Type", value="3D Spring") |
|
node_size = gr.Slider(1, 20, 7, label="Node Size", step=1) |
|
edge_width = gr.Slider(0.1, 5, 0.5, label="Edge Width", step=0.1) |
|
node_color_attribute = gr.Dropdown(["Degree", "Random"], label="Node Color Attribute", value="Degree") |
|
color_scheme = gr.Dropdown(["Viridis", "Plasma", "Inferno", "Magma", "Cividis"], label="Color Scheme", value="Viridis") |
|
show_labels = gr.Checkbox(label="Show Node Labels", value=True) |
|
label_size = gr.Slider(5, 20, 10, label="Label Size", step=1) |
|
|
|
|
|
|
|
upload_btn.click(fn=upload_file, inputs=[file_upload], outputs=[upload_output, file_list, log_output]) |
|
refresh_btn.click(fn=update_file_list, outputs=[file_list]).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
file_list.change(fn=update_file_content, inputs=[file_list], outputs=[file_content]).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
delete_btn.click(fn=delete_file, inputs=[file_list], outputs=[operation_status, file_list, log_output]) |
|
save_btn.click(fn=save_file_content, inputs=[file_list, file_content], outputs=[operation_status, log_output]) |
|
|
|
refresh_folder_btn.click( |
|
fn=lambda: gr.update(choices=list_output_folders("./indexing")), |
|
outputs=[selected_folder] |
|
) |
|
|
|
clear_chat_btn.click( |
|
fn=lambda: ([], ""), |
|
outputs=[chatbot, query_input] |
|
) |
|
|
|
refresh_folder_btn.click( |
|
fn=update_output_folder_list, |
|
outputs=[output_folder_list] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
output_folder_list.change( |
|
fn=update_folder_content_list, |
|
inputs=[output_folder_list], |
|
outputs=[folder_content_list] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
folder_content_list.change( |
|
fn=handle_content_selection, |
|
inputs=[output_folder_list, folder_content_list], |
|
outputs=[folder_content_list, file_info, output_content] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
initialize_folder_btn.click( |
|
fn=initialize_selected_folder, |
|
inputs=[output_folder_list], |
|
outputs=[initialization_status, folder_content_list] |
|
).then( |
|
fn=update_logs, |
|
outputs=[log_output] |
|
) |
|
|
|
vis_btn.click( |
|
fn=update_visualization, |
|
inputs=[ |
|
output_folder_list, |
|
folder_content_list, |
|
layout_type, |
|
node_size, |
|
edge_width, |
|
node_color_attribute, |
|
color_scheme, |
|
show_labels, |
|
label_size |
|
], |
|
outputs=[vis_output, gr.Textbox(label="Visualization Status")] |
|
) |
|
|
|
query_btn.click( |
|
fn=send_message, |
|
inputs=[ |
|
query_type, |
|
query_input, |
|
chatbot, |
|
system_message, |
|
temperature, |
|
max_tokens, |
|
preset_dropdown, |
|
community_level, |
|
response_type, |
|
custom_cli_args, |
|
selected_folder |
|
], |
|
outputs=[chatbot, query_input, log_output] |
|
) |
|
|
|
query_input.submit( |
|
fn=send_message, |
|
inputs=[ |
|
query_type, |
|
query_input, |
|
chatbot, |
|
system_message, |
|
temperature, |
|
max_tokens, |
|
preset_dropdown, |
|
community_level, |
|
response_type, |
|
custom_cli_args, |
|
selected_folder |
|
], |
|
outputs=[chatbot, query_input, log_output] |
|
) |
|
|
|
|
|
|
demo.load(js=""" |
|
function addShiftEnterListener() { |
|
const queryInput = document.getElementById('query-input'); |
|
if (queryInput) { |
|
queryInput.addEventListener('keydown', function(event) { |
|
if (event.key === 'Enter' && event.shiftKey) { |
|
event.preventDefault(); |
|
const submitButton = queryInput.closest('.gradio-container').querySelector('button.primary'); |
|
if (submitButton) { |
|
submitButton.click(); |
|
} |
|
} |
|
}); |
|
} |
|
} |
|
document.addEventListener('DOMContentLoaded', addShiftEnterListener); |
|
""") |
|
|
|
return demo.queue() |
|
|
|
async def main(): |
|
api_port = 8088 |
|
    gradio_port = find_available_port(7860)
|
|
|
|
|
print(f"Starting API server on port {api_port}") |
|
start_api_server(api_port) |
|
|
|
|
|
threading.Thread(target=wait_for_api_server, args=(api_port,)).start() |
|
|
|
|
|
demo = create_gradio_interface() |
|
|
|
print(f"Starting Gradio app on port {gradio_port}") |
|
|
|
demo.launch(server_port=gradio_port, share=True) |
|
|
|
|
|
demo = create_gradio_interface() |
|
app = demo.app |
|
|
|
if __name__ == "__main__": |
|
initialize_data() |
|
demo.launch(server_port=7860, share=True) |
|
|