|
import datetime |
|
|
|
import gspread |
|
import random |
|
import time |
|
import functools |
|
from gspread.exceptions import SpreadsheetNotFound, APIError |
|
from oauth2client.service_account import ServiceAccountCredentials |
|
import pandas as pd |
|
import json |
|
import gradio as gr |
|
import os |
|
|
|
# Google service-account credentials used to authorize gspread below.
# The private key itself is NOT stored in source ("private_key": None);
# it is injected from the environment further down.
# NOTE(review): the remaining identifiers (project id, key id, client email)
# are still committed here — consider moving the whole credential blob to
# external configuration.
GSERVICE_ACCOUNT_INFO = {
    "type": "service_account",
    "project_id": "txagent",
    "private_key_id": "cc1a12e427917244a93faf6f19e72b589a685e65",
    "private_key": None,  # filled in from the GSheets_Shanghua_PrivateKey env var
    "client_email": "[email protected]",
    "client_id": "108950722202634464257",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/shanghua%40txagent.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com"
}
# Name of the Google Sheet (spreadsheet) used for data collection.
GSHEET_NAME = "TxAgent_data_collection"

# The service-account private key is supplied via an environment variable.
# Escaped "\n" sequences (as stored in many secret managers) are converted
# back into real newlines before being placed into the credentials dict.
GSheet_API_KEY = os.environ.get("GSheets_Shanghua_PrivateKey")
if GSheet_API_KEY is None:
    # Missing key: credential dict keeps private_key=None, so the
    # authorization call below will fail at import time.
    print("GSheet_API_KEY not found in environment variables. Please set it.")
else:
    GSheet_API_KEY = GSheet_API_KEY.replace("\\n", "\n")
    GSERVICE_ACCOUNT_INFO["private_key"] = GSheet_API_KEY
|
|
|
|
|
def exponential_backoff_gspread(max_retries=30, max_backoff_sec=64, base_delay_sec=1, target_exception=APIError):
    """
    Decorator implementing exponential backoff with jitter for gspread API calls.

    Retries the wrapped function when it raises ``target_exception`` (defaults
    to gspread's APIError) and the error text indicates a rate-limit response
    (HTTP 429 with "Quota exceeded" or "Too Many Requests"). Any other error is
    logged and re-raised immediately with its original traceback.

    Args:
        max_retries (int): Maximum number of retry attempts before giving up.
        max_backoff_sec (int): Upper bound on the delay between retries (seconds).
        base_delay_sec (int): Initial delay in seconds; doubles on each retry.
        target_exception (type): Exception type inspected for rate-limit errors.

    Returns:
        callable: A decorator wrapping a function with the retry logic.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except target_exception as e:
                    # Rate-limit detection is string-based because gspread
                    # embeds the HTTP status code in the exception text.
                    error_message = str(e)
                    is_rate_limit_error = "[429]" in error_message and (
                        "Quota exceeded" in error_message or "Too Many Requests" in error_message
                    )

                    if not is_rate_limit_error:
                        # API error unrelated to quota: do not retry.
                        print(f"Non-rate-limit APIError encountered in {func.__name__}: {e}")
                        raise  # bare raise preserves the original traceback

                    retries += 1
                    if retries > max_retries:
                        print(f"Max retries ({max_retries}) exceeded for {func.__name__}. Last error: {e}")
                        raise

                    # Exponential backoff (base * 2^(k-1)) capped at
                    # max_backoff_sec, plus up to 1s of jitter to avoid
                    # synchronized ("thundering herd") retries.
                    backoff_delay = min(max_backoff_sec, base_delay_sec * (2 ** (retries - 1)) + random.uniform(0, 1))
                    print(
                        f"Rate limit hit for {func.__name__} (Attempt {retries}/{max_retries}). "
                        f"Retrying in {backoff_delay:.2f} seconds. Error: {e}"
                    )
                    time.sleep(backoff_delay)
                except Exception as e:
                    # Unexpected failure: surface it unchanged.
                    print(f"An unexpected error occurred in {func.__name__}: {e}")
                    raise
        return wrapper
    return decorator
|
|
|
|
|
|
|
# OAuth scopes: read/write access to Sheets (legacy feeds scope) plus Drive,
# which is needed to create and share spreadsheets.
scope = [
    "https://spreadsheets.google.com/feeds",
    "https://www.googleapis.com/auth/drive",
]

# Module-level gspread client authorized with the service account above.
# NOTE(review): this runs at import time and will raise if the private key
# was not provided via the environment variable handled earlier.
creds = ServiceAccountCredentials.from_json_keyfile_dict(GSERVICE_ACCOUNT_INFO, scope)
client = gspread.authorize(creds)
|
|
|
@exponential_backoff_gspread(max_retries=30, max_backoff_sec=64)
def read_sheet_to_df(custom_sheet_name=None, sheet_index=0):
    """
    Read all data from a Google Sheet into a pandas DataFrame.

    Parameters:
        custom_sheet_name (str): The name of the Google Sheet to open.
            If None, uses GSHEET_NAME.
        sheet_index (int): Index of the worksheet within the spreadsheet
            (default is 0, the first sheet).

    Returns:
        pandas.DataFrame or None: DataFrame containing the sheet data with
        the first row used as headers, or None when the spreadsheet or
        worksheet cannot be found.
    """
    if custom_sheet_name is None:
        custom_sheet_name = GSHEET_NAME

    try:
        spreadsheet = client.open(custom_sheet_name)
    except SpreadsheetNotFound:  # same imported name append_to_sheet uses
        return None

    # Depending on the gspread version, an out-of-range index either raises
    # (IndexError / WorksheetNotFound) or returns None — guard both so we
    # never call methods on None below.
    try:
        worksheet = spreadsheet.get_worksheet(sheet_index)
    except (IndexError, gspread.WorksheetNotFound):
        return None
    if worksheet is None:
        return None

    # First row is treated as the header row by get_all_records().
    data = worksheet.get_all_records()
    return pd.DataFrame(data)
|
|
|
@exponential_backoff_gspread(max_retries=30, max_backoff_sec=64)
def append_to_sheet(user_data=None, custom_row_dict=None, custom_sheet_name=None, add_header_when_create_sheet=False):
    """
    Append a new row to a Google Sheet, creating the spreadsheet if needed.

    If ``custom_row_dict`` is provided, its values are appended aligned with
    the sheet's header row. Otherwise a default row is built from
    ``user_data``: [timestamp, question, final_answer, trace].

    Parameters:
        user_data (dict): Must contain "question", "final_answer" and "trace"
            when custom_row_dict is None.
        custom_row_dict (dict): Optional mapping of column header -> value.
        custom_sheet_name (str): Sheet to open/create; defaults to GSHEET_NAME.
        add_header_when_create_sheet (bool): Write a header row when the sheet
            is newly created or empty.
    """
    if custom_sheet_name is None:
        custom_sheet_name = GSHEET_NAME

    try:
        spreadsheet = client.open(custom_sheet_name)
        is_new = False
    except SpreadsheetNotFound:
        # Create the sheet and grant the maintainers write access so it is
        # reachable outside the service account.
        spreadsheet = client.create(custom_sheet_name)
        spreadsheet.share('[email protected]', perm_type='user', role='writer')
        spreadsheet.share('[email protected]', perm_type='user', role='writer')
        is_new = True

    print("Spreadsheet ID:", spreadsheet.id)

    sheet = spreadsheet.sheet1

    existing_values = sheet.get_all_values()
    # get_all_values() returns [] for a truly empty sheet (and [[]] in some
    # gspread versions); the old `== [[]]` check missed the [] case and the
    # header row was silently skipped.
    is_empty = not existing_values or existing_values == [[]]

    if (is_new or is_empty) and add_header_when_create_sheet:
        # Seed the header row from whichever dict the caller supplied.
        if custom_row_dict is not None:
            headers = list(custom_row_dict.keys())
        else:
            headers = list(user_data.keys())
        sheet.append_row(headers)
    else:
        # Read the existing header row to align custom_row_dict values.
        headers = sheet.row_values(1) if sheet.row_count > 0 else []

    if custom_row_dict is not None:
        # Align values to headers; missing columns become empty strings.
        custom_row = [custom_row_dict.get(header, "") for header in headers]
    else:
        # Default row shape: timestamp, question, answer, reasoning trace.
        custom_row = [str(datetime.datetime.now()), user_data["question"], user_data["final_answer"], user_data["trace"]]

    sheet.append_row(custom_row)
|
|
|
def format_chat(response, tool_database_labels):
    """
    Convert a raw message trace into gradio ChatMessage lists for display.

    Parameters:
        response (list[dict]): Message dicts with a "role" key ("assistant"
            or "tool"), optional "content", and — for assistant messages —
            an optional "tool_calls" JSON string.
        tool_database_labels (dict): Mapping of database label -> list of
            tool names, used to annotate tool-call titles.

    Returns:
        tuple: (final_answer_messages, reasoning_messages, chat_history),
        each a list of gr.ChatMessage objects.
    """
    chat_history = []

    # Tool calls announced by the most recent assistant message; consumed
    # when the following "tool" message(s) arrive.
    last_tool_calls = []

    for msg in response:
        if msg["role"] == "assistant":
            content = msg.get("content", "")

            # NOTE(review): if "tool_calls" is present but None, json.loads
            # will raise — assumes the key is either absent or a JSON string.
            last_tool_calls = json.loads(msg.get("tool_calls", "[]"))

            chat_history.append(
                gr.ChatMessage(role="assistant", content=content)
            )

        elif msg["role"] == "tool":
            # One collapsible entry per pending tool call. NOTE(review): every
            # entry reuses this same tool message's content as its body.
            for i, tool_call in enumerate(last_tool_calls):
                name = tool_call.get("name", "")
                args = tool_call.get("arguments", {})

                database_label = ""
                if name == "Tool_RAG":
                    title = "🧰 Tool RAG"
                else:
                    title = f"🛠️ {name}"
                    # Annotate the title with the database this tool belongs to.
                    for db_label, tool_list in tool_database_labels.items():
                        if name in tool_list:
                            title = f"🛠️ {name}\n(**Info** {db_label} [Click to view])"
                            database_label = " (" + db_label + ")"
                            break

                # Round-trip through JSON to normalize formatting; fall back
                # to the raw string when the content is not valid JSON.
                raw = msg.get("content", "")
                try:
                    parsed = json.loads(raw)
                    pretty = json.dumps(parsed)
                except json.JSONDecodeError:
                    pretty = raw

                chat_history.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=f"Tool Response{database_label}:\n{pretty}",
                        metadata={
                            "title": title,
                            "log": json.dumps(args),
                            "status": 'done'
                        }
                    )
                )

            # Calls have been rendered; don't re-render them for later
            # tool messages.
            last_tool_calls = []

    # Replace the final-answer sentinel in the last message with a
    # display-friendly header.
    if chat_history:
        last_msg = chat_history[-1]
        if isinstance(last_msg.content, str) and "[FinalAnswer]" in last_msg.content:
            last_msg.content = last_msg.content.replace("[FinalAnswer]", "\n**Answer:**\n")

    # Extract just the answer text following the header (or the whole last
    # message if no header was inserted).
    # NOTE(review): raises IndexError when chat_history is empty — assumes
    # response always yields at least one message; confirm with callers.
    final_answer_messages = [gr.ChatMessage(role="assistant", content=chat_history[-1].content.split("\n**Answer:**\n")[-1].strip())]
    assistant_count = sum(1 for msg in chat_history if msg.role == "assistant")
    if assistant_count == 1:
        # Single assistant turn: there is no intermediate reasoning to show.
        reasoning_messages = [gr.ChatMessage(role="assistant", content="No reasoning was conducted.")]
    else:
        reasoning_messages = chat_history.copy()

    return final_answer_messages, reasoning_messages, chat_history