|
import gradio as gr |
|
import requests |
|
import logging |
|
import os |
|
import json |
|
import shutil |
|
import glob |
|
import queue |
|
import lancedb |
|
from datetime import datetime |
|
from dotenv import load_dotenv, set_key |
|
import yaml |
|
import pandas as pd |
|
from typing import List, Optional |
|
from pydantic import BaseModel |
|
|
|
|
|
# In-memory buffer of formatted log lines; update_logs() drains it.
log_queue = queue.Queue()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load the .env file first so the os.getenv() lookups below can see its values.
load_dotenv('indexing/.env')

# Service endpoints and the working directory. Each falls back to the literal
# default shown here when the variable is not set in the environment.
API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8012')
LLM_API_BASE = os.getenv('LLM_API_BASE', 'http://localhost:11434')
EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', 'http://localhost:11434')
ROOT_DIR = os.getenv('ROOT_DIR', 'indexing')
|
|
|
|
|
class IndexingRequest(BaseModel):
    """Request body that ``start_indexing`` posts to ``/v1/index``."""

    # Model names and the API base URLs they are served from.
    llm_model: str
    embed_model: str
    llm_api_base: str
    embed_api_base: str
    # Root directory of the indexing workspace.
    root: str
    verbose: bool = False
    nocache: bool = False
    # The UI labels this field "Resume Timestamp (optional)".
    resume: Optional[str] = None
    # The UI offers "rich", "print" and "none".
    reporter: str = "rich"
    # The UI offers "json", "csv" and "parquet".
    emit: List[str] = ["parquet"]
    # Free-form extra CLI arguments; the UI placeholder is "--arg1 value1 --arg2 value2".
    custom_args: Optional[str] = None
|
|
|
class PromptTuneRequest(BaseModel):
    """Request body that ``start_prompt_tuning`` posts to ``/v1/prompt_tune``.

    The ``root`` and ``output`` defaults are built from the module-level
    ``ROOT_DIR``.  They used to be plain (non-f) strings, so the default was
    the literal text ``./{ROOT_DIR}`` rather than a real path.
    """

    root: str = f"./{ROOT_DIR}"
    domain: Optional[str] = None
    # The UI offers "random", "top" and "all".
    method: str = "random"
    limit: int = 15
    language: Optional[str] = None
    max_tokens: int = 2000
    chunk_size: int = 200
    no_entity_types: bool = False
    output: str = f"./{ROOT_DIR}/prompts"
|
|
|
class QueueHandler(logging.Handler):
    """Logging handler that stores every formatted record on a queue."""

    def __init__(self, log_queue):
        """Remember *log_queue* as the destination for emitted records."""
        logging.Handler.__init__(self)
        self.log_queue = log_queue

    def emit(self, record):
        """Format *record* with this handler's formatter and enqueue the text."""
        formatted = self.format(record)
        self.log_queue.put(formatted)
|
# Attach a QueueHandler to the root logger so records that reach it are also
# buffered in log_queue for display in the UI.
queue_handler = QueueHandler(log_queue)
logging.getLogger().addHandler(queue_handler)
|
|
|
|
|
def update_logs():
    """Drain ``log_queue`` and return its entries joined by newlines.

    Returns:
        One string holding every queued log line in FIFO order, or ``""``
        when nothing is queued.
    """
    drained = []
    while True:
        if log_queue.empty():
            break
        drained.append(log_queue.get())
    return "\n".join(drained)
|
|
|
|
|
def load_settings():
    """Collect runtime settings from the environment and an optional YAML file.

    The YAML path comes from ``GRAPHRAG_CONFIG`` (default ``config.yaml``).
    For every setting the environment variable wins; the YAML value, then a
    hard-coded default (where one exists), is used otherwise.

    Returns:
        dict with the keys ``llm_model``, ``embedding_model``,
        ``community_level`` (int), ``token_limit`` (int), ``api_key``,
        ``api_base``, ``embeddings_api_base`` and ``api_type``.  Values with
        no default are ``None`` when neither source supplies them.

    Raises:
        ValueError: if ``COMMUNITY_LEVEL`` or ``TOKEN_LIMIT`` is not an integer.
    """
    config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml')

    config = {}
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as config_file:
            # safe_load() returns None for an empty document; keep a dict so
            # the .get() lookups below cannot fail.
            config = yaml.safe_load(config_file) or {}

    settings = {
        'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')),
        'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')),
        'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))),
        'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))),
        'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')),
        'api_base': os.getenv('LLM_API_BASE', config.get('api_base')),
        'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')),
        'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')),
    }

    return settings
|
|
|
|
|
|
|
def list_output_files(root_dir):
    """Return the path of every file found under ``<root_dir>/output``.

    The tree is walked recursively; a missing ``output`` directory yields an
    empty list.
    """
    output_dir = os.path.join(root_dir, "output")
    return [
        os.path.join(current_dir, name)
        for current_dir, _, names in os.walk(output_dir)
        for name in names
    ]
|
|
|
def update_file_list():
    """Return a Gradio update whose choices are the current input-file paths."""
    paths = []
    for entry in list_input_files():
        paths.append(entry["path"])
    return gr.update(choices=paths)
|
|
|
def update_file_content(file_path):
    """Return the UTF-8 text of *file_path* for display.

    Returns:
        ``""`` when *file_path* is empty/None, the file's text on success, or
        an ``"Error reading file: ..."`` message (also logged) on failure.
    """
    if not file_path:
        return ""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except Exception as exc:
        message = f"Error reading file: {str(exc)}"
        logging.error(message)
        return message
|
|
|
def list_output_folders():
    """Return the sub-directory names of ``<ROOT_DIR>/output``, newest name first.

    Names are sorted in reverse lexicographic order.  An empty list is
    returned when the output directory does not exist yet (for example before
    the first indexing run) instead of raising ``FileNotFoundError``.
    """
    output_dir = os.path.join(ROOT_DIR, "output")
    if not os.path.isdir(output_dir):
        return []
    folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))]
    return sorted(folders, reverse=True)
|
|
|
def update_output_folder_list():
    """Return a Gradio update listing the output folders.

    The first folder is pre-selected; the value is ``None`` when there are none.
    """
    available = list_output_folders()
    if available:
        selected = available[0]
    else:
        selected = None
    return gr.update(choices=available, value=selected)
|
|
|
def list_folder_contents(folder_name):
    """List the entries of ``<ROOT_DIR>/output/<folder_name>/artifacts``.

    Args:
        folder_name: name of a run folder, relative to the output directory.

    Returns:
        One label per entry: ``"[DIR] name"`` for directories and
        ``"[EXT] name"`` (upper-cased extension without the dot) for files.
        Empty when the artifacts directory does not exist.
    """
    artifacts_dir = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")
    if not os.path.exists(artifacts_dir):
        return []

    labels = []
    for entry in os.listdir(artifacts_dir):
        if os.path.isdir(os.path.join(artifacts_dir, entry)):
            tag = "DIR"
        else:
            tag = os.path.splitext(entry)[1][1:].upper()
        labels.append(f"[{tag}] {entry}")
    return labels
|
|
|
def update_folder_content_list(folder_name):
    """Return a Gradio update listing the artifacts of the selected folder.

    *folder_name* may be a string or a list (the first element is used).
    An empty or missing selection clears the choices.
    """
    if not folder_name:
        return gr.update(choices=[])
    if isinstance(folder_name, list):
        folder_name = folder_name[0]
    return gr.update(choices=list_folder_contents(folder_name))
|
|
|
def handle_content_selection(folder_name, selected_item):
    """React to a selection in the "Folder Contents" dropdown.

    Args:
        folder_name: selected run folder under ``<ROOT_DIR>/output``.
        selected_item: a label such as ``"[DIR] name"`` or ``"[EXT] name"``,
            or a list whose first element is used.

    Returns:
        A 3-tuple ``(choices_update, file_info, content)``.  For a directory
        the first element carries the directory's entries and the other two
        are empty.  For a file the first element is a no-op update and the
        other two describe and preview the file.

    Fixes over the previous version:
        * The directory branch passed an already-complete path to
          ``list_folder_contents``, which prepends ``ROOT_DIR/output`` and
          appends ``artifacts`` again, so it always produced an empty list.
          The sub-directory is now listed directly.  Entries are labelled
          with their path relative to ``artifacts`` so that a later
          selection of one of them still resolves.
        * A missing file no longer raises from ``os.path.getsize``.
        * The label prefix is split off once, so names containing ``"] "``
          are not truncated.
    """
    if isinstance(selected_item, list) and selected_item:
        selected_item = selected_item[0]

    # Nothing usable selected (this also covers the value reset done below).
    if not folder_name or not isinstance(selected_item, str) or not selected_item:
        return gr.update(), "", ""

    artifacts_dir = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")

    if selected_item.startswith("[DIR]"):
        rel_dir = selected_item[len("[DIR] "):]
        dir_path = os.path.join(artifacts_dir, rel_dir)
        sub_contents = []
        if os.path.isdir(dir_path):
            for item in sorted(os.listdir(dir_path)):
                rel_item = os.path.join(rel_dir, item)
                if os.path.isdir(os.path.join(dir_path, item)):
                    sub_contents.append(f"[DIR] {rel_item}")
                else:
                    ext = os.path.splitext(item)[1]
                    sub_contents.append(f"[{ext[1:].upper()}] {rel_item}")
        # Clear the value: the directory label is not among the new choices.
        return gr.update(choices=sub_contents, value=None), "", ""

    file_name = selected_item.split("] ", 1)[1] if "] " in selected_item else selected_item
    file_path = os.path.join(artifacts_dir, file_name)
    if not os.path.isfile(file_path):
        return gr.update(), f"File not found: {file_name}", ""

    file_size = os.path.getsize(file_path)
    file_type = os.path.splitext(file_name)[1]
    file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}"
    content = read_file_content(file_path)
    return gr.update(), file_info, content
|
|
|
def initialize_selected_folder(folder_name):
    """Validate the chosen output folder and list its artifacts.

    Args:
        folder_name: name of a run folder under ``<ROOT_DIR>/output``.

    Returns:
        ``(status_message, choices_update)``.

    ``list_folder_contents`` expects the bare folder name and builds the
    ``artifacts`` path itself.  The previous code passed the full artifacts
    path, which made the lookup miss and always reported 0 items.
    """
    if not folder_name:
        return "Please select a folder first.", gr.update(choices=[])

    folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")
    if not os.path.exists(folder_path):
        return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[])

    contents = list_folder_contents(folder_name)
    return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents)
|
|
|
def upload_file(file):
    """Move an uploaded file into ``<ROOT_DIR>/input``.

    Args:
        file: upload object exposing the temporary path as ``file.name``,
            or ``None`` when nothing was uploaded.

    Returns:
        ``(status_message, file_list_update, drained_logs)``.
    """
    if file is None:
        status = "No file uploaded"
    else:
        input_dir = os.path.join(ROOT_DIR, 'input')
        os.makedirs(input_dir, exist_ok=True)

        # Keep only the base name so the file lands directly in input_dir.
        base_name = os.path.basename(file.name)
        destination_path = os.path.join(input_dir, base_name)
        shutil.move(file.name, destination_path)

        logging.info(f"File uploaded and moved to: {destination_path}")
        status = f"File uploaded: {base_name}"

    current_paths = [entry["path"] for entry in list_input_files()]
    return status, gr.update(choices=current_paths), update_logs()
|
|
|
def list_input_files():
    """Describe the regular files directly inside ``<ROOT_DIR>/input``.

    Returns:
        A list of ``{"name": ..., "path": ...}`` dicts; empty when the input
        directory does not exist.
    """
    input_dir = os.path.join(ROOT_DIR, 'input')
    if not os.path.exists(input_dir):
        return []

    entries = []
    for name in os.listdir(input_dir):
        if os.path.isfile(os.path.join(input_dir, name)):
            entries.append({"name": name, "path": os.path.join(input_dir, name)})
    return entries
|
|
|
def delete_file(file_path):
    """Delete *file_path* and report the outcome.

    Any failure is logged and turned into the status message rather than
    raised.

    Returns:
        ``(status_message, file_list_update, drained_logs)``.
    """
    try:
        os.remove(file_path)
        logging.info(f"File deleted: {file_path}")
        status = f"File deleted: {os.path.basename(file_path)}"
    except Exception as exc:
        logging.error(f"Error deleting file: {str(exc)}")
        status = f"Error deleting file: {str(exc)}"

    remaining_paths = []
    for entry in list_input_files():
        remaining_paths.append(entry["path"])

    return status, gr.update(choices=remaining_paths), update_logs()
|
|
|
def read_file_content(file_path):
    """Return a textual preview of *file_path*.

    ``.parquet`` files are loaded with pandas and summarised (shape, column
    names, first five rows, ``describe()`` statistics).  Any other file is
    read as UTF-8 text with undecodable bytes replaced.

    Returns:
        The preview text, or an ``"Error reading file: ..."`` message (also
        logged) when anything fails.
    """
    try:
        if not file_path.endswith('.parquet'):
            with open(file_path, 'r', encoding='utf-8', errors='replace') as handle:
                return handle.read()

        frame = pd.read_parquet(file_path)
        sections = [
            f"Parquet File: {os.path.basename(file_path)}\n",
            f"Rows: {len(frame)}, Columns: {len(frame.columns)}\n\n",
            "Column Names:\n" + "\n".join(frame.columns) + "\n\n",
            "First 5 rows:\n",
            frame.head().to_string() + "\n\n",
            "Basic Statistics:\n",
            frame.describe().to_string(),
        ]
        return "".join(sections)
    except Exception as exc:
        logging.error(f"Error reading file: {str(exc)}")
        return f"Error reading file: {str(exc)}"
|
|
|
def save_file_content(file_path, content):
    """Overwrite *file_path* with *content* and report the outcome.

    The file is written as UTF-8 so it round-trips with
    ``update_file_content`` / ``read_file_content``, which both read UTF-8.
    The previous version used the platform default encoding.

    Returns:
        ``(status_message, drained_logs)``.  Failures are logged and
        reported in the status message rather than raised.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        logging.info(f"File saved: {file_path}")
        status = f"File saved: {os.path.basename(file_path)}"
    except Exception as e:
        logging.error(f"Error saving file: {str(e)}")
        status = f"Error saving file: {str(e)}"
    return status, update_logs()
|
|
|
def manage_data():
    """Summarise the LanceDB store at ``<ROOT_DIR>/lancedb`` and the input files.

    Returns:
        dict with ``"database_info"`` (table names plus the name and schema
        of the first table, if any) and ``"input_files"`` (the result of
        ``list_input_files()``).
    """
    connection = lancedb.connect(f"{ROOT_DIR}/lancedb")
    table_names = connection.table_names()

    if table_names:
        first_name = table_names[0]
        first_table = connection[first_name]
        table_info = f"Table: {first_name}\nSchema: {first_table.schema}"
    else:
        table_info = ""

    input_files = list_input_files()
    database_info = f"Tables: {', '.join(table_names)}\n\n{table_info}"
    return {"database_info": database_info, "input_files": input_files}
|
|
|
|
|
def find_latest_graph_file(root_dir):
    """Return the most recently modified ``.graphml`` artifact, or ``None``.

    Looks in ``<root_dir>/output/*/artifacts``.  If that finds nothing, the
    run directory with the greatest name is checked explicitly.

    Args:
        root_dir: workspace root that contains the ``output`` directory.

    Returns:
        Path of the newest graph file by modification time, or ``None`` when
        there is none.  A missing ``output`` directory also yields ``None``;
        previously the fallback ``os.listdir`` raised ``FileNotFoundError``.
    """
    output_dir = os.path.join(root_dir, "output")
    graph_files = glob.glob(os.path.join(output_dir, "*", "artifacts", "*.graphml"))

    if not graph_files and os.path.isdir(output_dir):
        run_dirs = [
            d for d in os.listdir(output_dir)
            if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"
        ]
        if run_dirs:
            latest_run = max(run_dirs)
            graph_files = glob.glob(os.path.join(output_dir, latest_run, "artifacts", "*.graphml"))

    if not graph_files:
        return None

    return max(graph_files, key=os.path.getmtime)
|
|
|
def find_latest_output_folder():
    """Locate the newest timestamp-named run folder under ``<ROOT_DIR>/output``.

    Folders are ordered by creation time, newest first, and the first one
    whose name parses as ``%Y%m%d-%H%M%S`` is chosen.

    Returns:
        ``(latest_path, latest_folder)``: the folder's path and its name.

    Raises:
        ValueError: if there are no folders, none has a timestamp name, or
            the chosen folder has no ``artifacts`` sub-directory.
    """
    output_root = f"{ROOT_DIR}/output"

    candidates = []
    for name in os.listdir(output_root):
        if os.path.isdir(os.path.join(output_root, name)):
            candidates.append(name)
    if not candidates:
        raise ValueError("No output folders found")

    def created_at(name):
        return os.path.getctime(os.path.join(output_root, name))

    candidates.sort(key=created_at, reverse=True)

    for name in candidates:
        try:
            # Only the parse success matters; the parsed value is not used.
            datetime.strptime(name, "%Y%m%d-%H%M%S")
        except ValueError:
            continue
        latest_folder = name
        break
    else:
        raise ValueError("No valid timestamp folders found")

    latest_path = os.path.join(output_root, latest_folder)
    if not os.path.exists(os.path.join(latest_path, "artifacts")):
        raise ValueError(f"Artifacts folder not found in {latest_path}")

    return latest_path, latest_folder
|
|
|
def initialize_data():
    """Load the newest run's parquet artifacts into module-level DataFrames.

    For each of ``entity_df``, ``relationship_df``, ``text_unit_df``,
    ``report_df`` and ``covariate_df`` the most recently created parquet file
    with the matching prefix is read from the latest output folder's
    ``artifacts`` directory.  A missing file gives an empty DataFrame.  If
    anything raises, the error is logged and all five become empty
    DataFrames.

    Returns:
        The folder name returned by ``find_latest_output_folder()``, or
        ``None`` when that lookup failed.
    """
    global entity_df, relationship_df, text_unit_df, report_df, covariate_df

    tables = {
        "entity_df": "create_final_nodes",
        "relationship_df": "create_final_edges",
        "text_unit_df": "create_final_text_units",
        "report_df": "create_final_reports",
        "covariate_df": "create_final_covariates"
    }

    module_vars = globals()
    timestamp = None

    try:
        latest_output_folder, timestamp = find_latest_output_folder()
        artifacts_folder = os.path.join(latest_output_folder, "artifacts")

        for df_name, file_prefix in tables.items():
            candidates = glob.glob(os.path.join(artifacts_folder, f"{file_prefix}*.parquet"))
            if not candidates:
                logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.")
                module_vars[df_name] = pd.DataFrame()
                continue

            latest_file = max(candidates, key=os.path.getctime)
            module_vars[df_name] = pd.read_parquet(latest_file)
            logging.info(f"Successfully loaded {df_name} from {latest_file}")
    except Exception as e:
        logging.error(f"Error initializing data: {str(e)}")
        for df_name in tables:
            module_vars[df_name] = pd.DataFrame()

    return timestamp
|
|
|
|
|
# Populate the module-level DataFrames at import time and keep the name of the
# run folder they were loaded from (None when loading failed).
current_timestamp = initialize_data()
|
|
|
|
|
|
|
def normalize_api_base(api_base: str) -> str:
    """Normalize an API base URL.

    Trailing slashes are removed, then a single trailing ``/v1`` or ``/api``
    path segment is removed.

    The previous version always cut three characters, which is right for
    ``/v1`` but left a dangling ``/`` for the four-character ``/api`` suffix
    (``http://host/api`` became ``http://host/``).

    Args:
        api_base: base URL, e.g. ``http://localhost:11434/v1/``.

    Returns:
        The URL without the trailing slash(es) and suffix.
    """
    api_base = api_base.rstrip('/')
    for suffix in ('/v1', '/api'):
        if api_base.endswith(suffix):
            # Slice by the suffix's own length; the suffixes differ in size.
            return api_base[:-len(suffix)]
    return api_base
|
|
|
def is_ollama_api(base_url: str, timeout: float = 5.0) -> bool:
    """Check whether *base_url* answers like an Ollama server.

    Sends ``GET <normalized base>/api/tags`` and treats HTTP 200 as "yes".

    Args:
        base_url: API base URL; it is passed through ``normalize_api_base``.
        timeout: seconds to wait for the server.  Without a timeout an
            unresponsive host would block the calling UI handler
            indefinitely; a timeout is reported as ``False``.

    Returns:
        ``True`` on HTTP 200, ``False`` on any other status or request error.
    """
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags", timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        return False
|
|
|
def get_ollama_models(base_url: str, timeout: float = 10.0) -> List[str]:
    """Fetch the model names listed by an Ollama server.

    Sends ``GET <normalized base>/api/tags`` and returns the ``name`` of each
    entry in the response's ``models`` list.

    Args:
        base_url: API base URL; it is passed through ``normalize_api_base``.
        timeout: seconds to wait for the server, so an unresponsive host
            cannot hang the caller.

    Returns:
        The model names, or ``[]`` (with the error logged) when the request
        fails or times out.
    """
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags", timeout=timeout)
        response.raise_for_status()
        models = response.json().get('models', [])
        return [model['name'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching Ollama models: {str(e)}")
        return []
|
|
|
def get_openai_compatible_models(base_url: str, timeout: float = 10.0) -> List[str]:
    """Fetch the model ids listed by an OpenAI-compatible server.

    Sends ``GET <normalized base>/v1/models`` and returns the ``id`` of each
    entry in the response's ``data`` list.

    Args:
        base_url: API base URL; it is passed through ``normalize_api_base``.
        timeout: seconds to wait for the server, so an unresponsive host
            cannot hang the caller.

    Returns:
        The model ids, or ``[]`` (with the error logged) when the request
        fails or times out.
    """
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/v1/models", timeout=timeout)
        response.raise_for_status()
        models = response.json().get('data', [])
        return [model['id'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching OpenAI-compatible models: {str(e)}")
        return []
|
|
|
def get_local_models(base_url: str) -> List[str]:
    """List the models served at *base_url*.

    The Ollama listing is used when the server passes ``is_ollama_api``;
    otherwise the OpenAI-compatible listing is used.
    """
    fetch = get_ollama_models if is_ollama_api(base_url) else get_openai_compatible_models
    return fetch(base_url)
|
|
|
def get_model_params(base_url: str, model_name: str, timeout: float = 10.0) -> dict:
    """Fetch a model's parameters from an Ollama server.

    When *base_url* passes ``is_ollama_api``, posts ``{"name": model_name}``
    to ``<normalized base>/api/show`` and returns the response's
    ``parameters`` entry.

    Args:
        base_url: API base URL; it is passed through ``normalize_api_base``.
        model_name: name of the model to describe.
        timeout: seconds to wait for the server, so an unresponsive host
            cannot hang the caller.

    Returns:
        The ``parameters`` value, or ``{}`` when the server is not Ollama,
        the key is absent, or the request fails (the error is logged).
    """
    if is_ollama_api(base_url):
        try:
            response = requests.post(
                f"{normalize_api_base(base_url)}/api/show",
                json={"name": model_name},
                timeout=timeout,
            )
            response.raise_for_status()
            model_info = response.json()
            return model_info.get('parameters', {})
        except requests.RequestException as e:
            logger.error(f"Error fetching Ollama model parameters: {str(e)}")
    return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_indexing(request: IndexingRequest):
    """Post *request* to ``<API_BASE_URL>/v1/index``.

    Returns:
        ``(message, run_button_update, status_button_update)``.  On success
        the message is the response's ``message`` field, the run button is
        disabled and the status button enabled.  On a request error the
        message is ``"Error: ..."`` and the button states are reversed.
    """
    endpoint = f"{API_BASE_URL}/v1/index"

    try:
        reply = requests.post(endpoint, json=request.dict())
        reply.raise_for_status()
        message = reply.json()['message']
    except requests.RequestException as exc:
        logger.error(f"Error starting indexing: {str(exc)}")
        return f"Error: {str(exc)}", gr.update(interactive=True), gr.update(interactive=False)

    return message, gr.update(interactive=False), gr.update(interactive=True)
|
|
|
def check_indexing_status():
    """Query ``<API_BASE_URL>/v1/index_status``.

    Returns:
        ``(status, logs)`` where *logs* is the response's ``logs`` list
        joined by newlines, or ``("Error", message)`` on a request error.
    """
    endpoint = f"{API_BASE_URL}/v1/index_status"
    try:
        reply = requests.get(endpoint)
        reply.raise_for_status()
        payload = reply.json()
        status = payload['status']
        log_text = "\n".join(payload['logs'])
        return status, log_text
    except requests.RequestException as exc:
        logger.error(f"Error checking indexing status: {str(exc)}")
        return "Error", f"Failed to check indexing status: {str(exc)}"
|
|
|
def start_prompt_tuning(request: PromptTuneRequest):
    """Post *request* to ``<API_BASE_URL>/v1/prompt_tune``.

    Returns:
        ``(message, run_button_update)``.  On success the message is the
        response's ``message`` field and the button is disabled.  On a
        request error the message is ``"Error: ..."`` and the button stays
        enabled.
    """
    endpoint = f"{API_BASE_URL}/v1/prompt_tune"

    try:
        reply = requests.post(endpoint, json=request.dict())
        reply.raise_for_status()
        message = reply.json()['message']
    except requests.RequestException as exc:
        logger.error(f"Error starting prompt tuning: {str(exc)}")
        return f"Error: {str(exc)}", gr.update(interactive=True)

    return message, gr.update(interactive=False)
|
|
|
def check_prompt_tuning_status():
    """Query ``<API_BASE_URL>/v1/prompt_tune_status``.

    Returns:
        ``(status, logs)`` where *logs* is the response's ``logs`` list
        joined by newlines, or ``("Error", message)`` on a request error.
    """
    endpoint = f"{API_BASE_URL}/v1/prompt_tune_status"
    try:
        reply = requests.get(endpoint)
        reply.raise_for_status()
        payload = reply.json()
        status = payload['status']
        log_text = "\n".join(payload['logs'])
        return status, log_text
    except requests.RequestException as exc:
        logger.error(f"Error checking prompt tuning status: {str(exc)}")
        return "Error", f"Failed to check prompt tuning status: {str(exc)}"
|
|
|
def update_model_params(model_name):
    """Return a Gradio update showing *model_name*'s parameters as JSON.

    ``get_model_params`` requires the server base URL as its first argument.
    The previous call passed only the model name, which raised ``TypeError``.
    The module-level ``LLM_API_BASE`` is used as the server.
    """
    params = get_model_params(LLM_API_BASE, model_name)
    return gr.update(value=json.dumps(params, indent=2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...) in create_interface().
# NOTE(review): #visualization-plot, #vis-controls-row, "#vis-controls-row > *",
# #vis-status and #chat-input-row are each declared twice, so later
# declarations of the same property override earlier ones. "max-height: auto"
# under #log-container does not look like a valid CSS value -- confirm.
css = """
html, body {
margin: 0;
padding: 0;
height: 100vh;
overflow: hidden;
}

.gradio-container {
margin: 0 !important;
padding: 0 !important;
width: 100vw !important;
max-width: 100vw !important;
height: 100vh !important;
max-height: 100vh !important;
overflow: auto;
display: flex;
flex-direction: column;
}

#main-container {
flex: 1;
display: flex;
overflow: hidden;
}

#left-column, #right-column {
height: 100%;
overflow-y: auto;
padding: 10px;
}

#left-column {
flex: 1;
}

#right-column {
flex: 2;
display: flex;
flex-direction: column;
}

#chat-container {
flex: 0 0 auto; /* Don't allow this to grow */
height: 100%;
display: flex;
flex-direction: column;
overflow: hidden;
border: 1px solid var(--color-accent);
border-radius: 8px;
padding: 10px;
overflow-y: auto;
}

#chatbot {
overflow-y: hidden;
height: 100%;
}

#chat-input-row {
margin-top: 10px;
}

#visualization-plot {
width: 100%;
aspect-ratio: 1 / 1;
max-height: 600px; /* Adjust this value as needed */
}

#vis-controls-row {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 10px;
}

#vis-controls-row > * {
flex: 1;
margin: 0 5px;
}

#vis-status {
margin-top: 10px;
}

/* Chat input styling */
#chat-input-row {
display: flex;
flex-direction: column;
}

#chat-input-row > div {
width: 100% !important;
}

#chat-input-row input[type="text"] {
width: 100% !important;
}

/* Adjust padding for all containers */
.gr-box, .gr-form, .gr-panel {
padding: 10px !important;
}

/* Ensure all textboxes and textareas have full height */
.gr-textbox, .gr-textarea {
height: auto !important;
min-height: 100px !important;
}

/* Ensure all dropdowns have full width */
.gr-dropdown {
width: 100% !important;
}

:root {
--color-background: #2C3639;
--color-foreground: #3F4E4F;
--color-accent: #A27B5C;
--color-text: #DCD7C9;
}

body, .gradio-container {
background-color: var(--color-background);
color: var(--color-text);
}

.gr-button {
background-color: var(--color-accent);
color: var(--color-text);
}

.gr-input, .gr-textarea, .gr-dropdown {
background-color: var(--color-foreground);
color: var(--color-text);
border: 1px solid var(--color-accent);
}

.gr-panel {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
}

.gr-box {
border-radius: 8px;
margin-bottom: 10px;
background-color: var(--color-foreground);
}

.gr-padded {
padding: 10px;
}

.gr-form {
background-color: var(--color-foreground);
}

.gr-input-label, .gr-radio-label {
color: var(--color-text);
}

.gr-checkbox-label {
color: var(--color-text);
}

.gr-markdown {
color: var(--color-text);
}

.gr-accordion {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
}

.gr-accordion-header {
background-color: var(--color-accent);
color: var(--color-text);
}

#visualization-container {
display: flex;
flex-direction: column;
border: 2px solid var(--color-accent);
border-radius: 8px;
margin-top: 20px;
padding: 10px;
background-color: var(--color-foreground);
height: calc(100vh - 300px); /* Adjust this value as needed */
}

#visualization-plot {
width: 100%;
height: 100%;
}

#vis-controls-row {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 10px;
}

#vis-controls-row > * {
flex: 1;
margin: 0 5px;
}

#vis-status {
margin-top: 10px;
}

#log-container {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
border-radius: 8px;
padding: 10px;
margin-top: 20px;
max-height: auto;
overflow-y: auto;
}

.setting-accordion .label-wrap {
cursor: pointer;
}

.setting-accordion .icon {
transition: transform 0.3s ease;
}

.setting-accordion[open] .icon {
transform: rotate(90deg);
}

.gr-form.gr-box {
border: none !important;
background: none !important;
}

.model-params {
border-top: 1px solid var(--color-accent);
margin-top: 10px;
padding-top: 10px;
}
"""
|
|
|
|
|
def create_interface():
    """Build the "GraphRAG Indexer" Gradio app and wire up its events.

    The app has three tabs: Indexing, Prompt Tuning and Data Management.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).

    Fixes over the previous version:
        * The run buttons built their requests from ``component.value``.
          On a Gradio component that attribute holds the value given at
          construction time, so whatever the user typed or selected was
          ignored.  The current values are now delivered through the
          listeners' ``inputs``.
        * ``load_settings()`` returns ``None`` for a base URL that neither
          the environment nor the config file defines, which crashed
          ``normalize_api_base``.  The module-level defaults are now used
          as a fallback.

    NOTE(review): the container nesting below was reconstructed; the
    indentation of the reviewed source was not recoverable.  Confirm the
    layout against the running UI.
    """
    settings = load_settings()
    llm_api_base = normalize_api_base(settings['api_base'] or LLM_API_BASE)
    embeddings_api_base = normalize_api_base(settings['embeddings_api_base'] or EMBEDDINGS_API_BASE)

    with gr.Blocks(theme=gr.themes.Base(), css=css) as demo:
        gr.Markdown("# GraphRAG Indexer")

        with gr.Tabs():
            with gr.TabItem("Indexing"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Configuration")

                        with gr.Row():
                            llm_name = gr.Dropdown(label="LLM Model", choices=[], value=settings['llm_model'], allow_custom_value=True)
                            refresh_llm_btn = gr.Button("🔄", size='sm', scale=0)

                        with gr.Row():
                            embed_name = gr.Dropdown(label="Embedding Model", choices=[], value=settings['embedding_model'], allow_custom_value=True)
                            refresh_embed_btn = gr.Button("🔄", size='sm', scale=0)

                        save_config_button = gr.Button("Save Configuration", variant="primary")
                        config_status = gr.Textbox(label="Configuration Status", lines=2)

                        with gr.Row():
                            with gr.Column(scale=1):
                                root_dir = gr.Textbox(label="Root Directory (Edit in .env file)", value=f"{ROOT_DIR}")
                                with gr.Group():
                                    verbose = gr.Checkbox(label="Verbose", interactive=True, value=True)
                                    nocache = gr.Checkbox(label="No Cache", interactive=True, value=True)

                                with gr.Accordion("Advanced Options", open=True):
                                    resume = gr.Textbox(label="Resume Timestamp (optional)")
                                    reporter = gr.Dropdown(
                                        label="Reporter",
                                        choices=["rich", "print", "none"],
                                        value="rich",
                                        interactive=True
                                    )
                                    emit_formats = gr.CheckboxGroup(
                                        label="Emit Formats",
                                        choices=["json", "csv", "parquet"],
                                        value=["parquet"],
                                        interactive=True
                                    )
                                    custom_args = gr.Textbox(label="Custom CLI Arguments", placeholder="--arg1 value1 --arg2 value2")

                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Output")
                        index_output = gr.Textbox(label="Output", lines=10)
                        index_status = gr.Textbox(label="Status", lines=2)

                        run_index_button = gr.Button("Run Indexing", variant="primary")
                        check_status_button = gr.Button("Check Indexing Status")

            with gr.TabItem("Prompt Tuning"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Configuration")

                        pt_root = gr.Textbox(label="Root Directory", value=f"{ROOT_DIR}", interactive=True)
                        pt_domain = gr.Textbox(label="Domain (optional)")
                        pt_method = gr.Dropdown(
                            label="Method",
                            choices=["random", "top", "all"],
                            value="random",
                            interactive=True
                        )
                        pt_limit = gr.Number(label="Limit", value=15, precision=0, interactive=True)
                        pt_language = gr.Textbox(label="Language (optional)")
                        pt_max_tokens = gr.Number(label="Max Tokens", value=2000, precision=0, interactive=True)
                        pt_chunk_size = gr.Number(label="Chunk Size", value=200, precision=0, interactive=True)
                        pt_no_entity_types = gr.Checkbox(label="No Entity Types", value=False)
                        pt_output_dir = gr.Textbox(label="Output Directory", value=f"{ROOT_DIR}/prompts", interactive=True)
                        save_pt_config_button = gr.Button("Save Prompt Tuning Configuration", variant="primary")

                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Output")
                        pt_output = gr.Textbox(label="Output", lines=10)
                        pt_status = gr.Textbox(label="Status", lines=10)

                        run_pt_button = gr.Button("Run Prompt Tuning", variant="primary")
                        check_pt_status_button = gr.Button("Check Prompt Tuning Status")

            with gr.TabItem("Data Management"):
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Accordion("File Upload", open=True):
                            file_upload = gr.File(label="Upload File", file_types=[".txt", ".csv", ".parquet"])
                            upload_btn = gr.Button("Upload File", variant="primary")
                            upload_output = gr.Textbox(label="Upload Status", visible=True)

                        with gr.Accordion("File Management", open=True):
                            file_list = gr.Dropdown(label="Select File", choices=[], interactive=True)
                            refresh_btn = gr.Button("Refresh File List", variant="secondary")

                            file_content = gr.TextArea(label="File Content", lines=10)

                            with gr.Row():
                                delete_btn = gr.Button("Delete Selected File", variant="stop")
                                save_btn = gr.Button("Save Changes", variant="primary")

                            operation_status = gr.Textbox(label="Operation Status", visible=True)

                    with gr.Column(scale=1):
                        with gr.Accordion("Output Folders", open=True):
                            output_folder_list = gr.Dropdown(label="Select Output Folder", choices=[], interactive=True)
                            refresh_output_btn = gr.Button("Refresh Output Folders", variant="secondary")
                            folder_content_list = gr.Dropdown(label="Folder Contents", choices=[], interactive=True, multiselect=False)

                            file_info = gr.Textbox(label="File Info", lines=3)
                            output_content = gr.TextArea(label="File Content", lines=10)

        # --- Model list refresh -------------------------------------------
        def refresh_llm_models():
            """Reload the LLM dropdown choices from the LLM server."""
            models = get_local_models(llm_api_base)
            return gr.update(choices=models)

        def refresh_embed_models():
            """Reload the embedding dropdown choices from the embeddings server."""
            models = get_local_models(embeddings_api_base)
            return gr.update(choices=models)

        refresh_llm_btn.click(
            refresh_llm_models,
            outputs=[llm_name]
        )

        refresh_embed_btn.click(
            refresh_embed_models,
            outputs=[embed_name]
        )

        demo.load(refresh_llm_models, outputs=[llm_name])
        demo.load(refresh_embed_models, outputs=[embed_name])

        # --- Indexing ------------------------------------------------------
        def run_indexing(llm_model, embed_model, root, verbose_flag, nocache_flag,
                         resume_value, reporter_name, emit_list, extra_args):
            """Build an IndexingRequest from the current UI values and submit it."""
            request = IndexingRequest(
                llm_model=llm_model,
                embed_model=embed_model,
                llm_api_base=llm_api_base,
                embed_api_base=embeddings_api_base,
                root=root,
                verbose=verbose_flag,
                nocache=nocache_flag,
                resume=resume_value if resume_value else None,
                reporter=reporter_name,
                emit=list(emit_list or []),
                custom_args=extra_args if extra_args else None
            )
            return start_indexing(request)

        run_index_button.click(
            run_indexing,
            inputs=[llm_name, embed_name, root_dir, verbose, nocache, resume, reporter, emit_formats, custom_args],
            outputs=[index_output, run_index_button, check_status_button]
        )

        check_status_button.click(
            check_indexing_status,
            outputs=[index_status, index_output]
        )

        # --- Prompt tuning -------------------------------------------------
        def run_prompt_tuning(root, domain, method, limit, language, max_tokens,
                              chunk_size, no_entity_types, output_dir):
            """Build a PromptTuneRequest from the current UI values and submit it."""
            fields = {
                "root": root,
                "domain": domain if domain else None,
                "method": method,
                "language": language if language else None,
                "no_entity_types": no_entity_types,
                "output": output_dir,
            }
            # A cleared Number field arrives as None; omit it so the model's
            # own default applies instead of failing in int().
            numeric = {"limit": limit, "max_tokens": max_tokens, "chunk_size": chunk_size}
            for key, value in numeric.items():
                if value is not None:
                    fields[key] = int(value)

            request = PromptTuneRequest(**fields)
            result, button_update = start_prompt_tuning(request)
            return result, button_update, gr.update(value=f"Request: {request.dict()}")

        run_pt_button.click(
            run_prompt_tuning,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_output, run_pt_button, pt_status]
        )

        check_pt_status_button.click(
            check_prompt_tuning_status,
            outputs=[pt_status, pt_output]
        )

        # Echo configuration edits into the status box.
        pt_root.change(lambda x: gr.update(value=f"Root Directory changed to: {x}"), inputs=[pt_root], outputs=[pt_status])
        pt_limit.change(lambda x: gr.update(value=f"Limit changed to: {x}"), inputs=[pt_limit], outputs=[pt_status])
        pt_max_tokens.change(lambda x: gr.update(value=f"Max Tokens changed to: {x}"), inputs=[pt_max_tokens], outputs=[pt_status])
        pt_chunk_size.change(lambda x: gr.update(value=f"Chunk Size changed to: {x}"), inputs=[pt_chunk_size], outputs=[pt_status])
        pt_output_dir.change(lambda x: gr.update(value=f"Output Directory changed to: {x}"), inputs=[pt_output_dir], outputs=[pt_status])

        # --- Data management -----------------------------------------------
        upload_btn.click(
            upload_file,
            inputs=[file_upload],
            outputs=[upload_output, file_list, operation_status]
        )

        refresh_btn.click(
            update_file_list,
            outputs=[file_list]
        )

        refresh_output_btn.click(
            update_output_folder_list,
            outputs=[output_folder_list]
        )

        file_list.change(
            update_file_content,
            inputs=[file_list],
            outputs=[file_content]
        )

        # NOTE(review): operation_status is listed twice in the two listeners
        # below, so the status message and the drained logs target the same
        # textbox -- confirm which one is meant to be shown.
        delete_btn.click(
            delete_file,
            inputs=[file_list],
            outputs=[operation_status, file_list, operation_status]
        )

        save_btn.click(
            save_file_content,
            inputs=[file_list, file_content],
            outputs=[operation_status, operation_status]
        )

        output_folder_list.change(
            update_folder_content_list,
            inputs=[output_folder_list],
            outputs=[folder_content_list]
        )

        folder_content_list.change(
            handle_content_selection,
            inputs=[output_folder_list, folder_content_list],
            outputs=[folder_content_list, file_info, output_content]
        )

        # --- Configuration persistence -------------------------------------
        save_config_button.click(
            update_env_file,
            inputs=[llm_name, embed_name],
            outputs=[config_status]
        )

        save_pt_config_button.click(
            save_prompt_tuning_config,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_status]
        )

        # Populate the file and folder dropdowns when the page loads.
        demo.load(update_file_list, outputs=[file_list])
        demo.load(update_output_folder_list, outputs=[output_folder_list])

    return demo
|
|
|
def update_env_file(llm_model, embed_model):
    """Persist the chosen model names to ``<ROOT_DIR>/.env`` and reload it.

    Writes ``LLM_MODEL`` and ``EMBEDDINGS_MODEL`` with ``set_key`` and then
    re-reads the file with ``override=True`` so the process environment
    picks up the new values.

    Returns:
        A confirmation message naming both values.
    """
    env_path = os.path.join(ROOT_DIR, '.env')

    updates = (
        ('LLM_MODEL', llm_model),
        ('EMBEDDINGS_MODEL', embed_model),
    )
    for key, value in updates:
        set_key(env_path, key, value)

    load_dotenv(env_path, override=True)

    return f"Environment updated: LLM_MODEL={llm_model}, EMBEDDINGS_MODEL={embed_model}"
|
|
|
def save_prompt_tuning_config(root, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output_dir):
    """Write the prompt-tuning options to ``<ROOT_DIR>/prompt_tuning_config.yaml``.

    The values are stored under a single top-level ``prompt_tuning`` key.

    Returns:
        A confirmation message containing the file path.
    """
    prompt_tuning = dict(
        root=root,
        domain=domain,
        method=method,
        limit=limit,
        language=language,
        max_tokens=max_tokens,
        chunk_size=chunk_size,
        no_entity_types=no_entity_types,
        output=output_dir,
    )
    config = {'prompt_tuning': prompt_tuning}

    config_path = os.path.join(ROOT_DIR, 'prompt_tuning_config.yaml')
    with open(config_path, 'w') as stream:
        yaml.dump(config, stream)

    return f"Prompt Tuning configuration saved to {config_path}"
|
|
|
# Build the app at import time so `demo` is available as a module attribute.
demo = create_interface()

# Start the server on port 7860 only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(server_port=7860)
|
|
|
|