import csv
import json
import os
import sys

from dotenv import load_dotenv

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils_groupclassification.agent import submit_tool_outputs
from utils_groupclassification.assistants_utils import (
    create_assistant,
    create_message,
    create_run,
    create_thread,
    wait_for_run_completion,
)
from src.scoring_utils import judge
from src.myLogger import set_logger

logger = set_logger("my_app", level="INFO")

load_dotenv()
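# Assumption: load_dotenv() supplies the API credentials (e.g. OPENAI_API_KEY)
# that the assistant helper utilities read from the environment.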
model = 'gpt-4o'

# Scoring criteria (Japanese labels kept as-is; English gloss in comments).
scoringList = [
    '再生可能エネルギーを使用している',                  # uses renewable energy
    '太陽光発電を行っている',                            # generates solar power
    '風力発電を行っている',                              # generates wind power
    'バイオマスエネルギーの使用をしている',              # uses biomass energy
    'その他の再生可能エネルギーを使用している',          # uses other renewable energy
    'その他の再生可能エネルギーの例',                    # examples of other renewable energy
    'エネルギー使用量の削減を行っている',                # reduces energy consumption
    '電気自動車の利用',                                  # uses electric vehicles
    '暖房や給油などでヒートポンプを利用している',        # uses heat pumps (e.g. for heating and hot water)
    'その他の方法でエネルギー使用量の削減を行っている',  # reduces energy use by other means
    'その他のエネルギー使用量の削減の例',                # examples of other energy-use reductions
]
def save(company_name, result, urls):
    """Append the scoring result for one company to the Group 4 CSV."""
    path = 'data/Group4_result.csv'
    # Indices 5 and 10 correspond to the "examples" items, so they are
    # excluded from the total score.
    total = sum(result[0:5]) + sum(result[6:10])
    with open(path, mode='a', newline='', encoding='utf-8') as file:
        row = [company_name, 'Group 4'] + result + [total, urls]
        writer = csv.writer(file)
        writer.writerow(row)
    logger.info(f'Result: {company_name}, {urls} saved')
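# Each appended row has the shape constructed in save() above:
#   [company_name, 'Group 4', result[0], ..., result[10], total, urls]
# i.e. two label columns, the 11 criterion values, their (partial) sum, and the source URLs.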
def scoring(company_name):
    """Run the two-phase assistant pipeline for one company and save its scores."""
    logger.info("create thread")
    thread = create_thread()
    logger.info("created thread")

    with open("prompts/ph1.txt", "r", encoding="utf-8") as fp:  # phase-1 (search) prompt
        assistant_prompt = fp.read()
    logger.info(f"assistant prompt: {assistant_prompt}")
    with open("tools/ph1.json", "r", encoding="utf-8") as f:  # phase-1 tool definitions
        tools = json.load(f)

    logger.info("create assistant")
    assistant = create_assistant(  # create the phase-1 assistant
        assistant_prompt,
        model=model,
        tools=tools,
    )

    logger.info(f"User: {company_name}")
    create_message(thread.id, "user", company_name)
    run = create_run(thread.id, assistant.id)
    run = wait_for_run_completion(thread.id, run.id)
    logger.info(f"status: {run.status}")

    if run.status == "failed":
        logger.error(f"Run failed: {run.last_error}, Run ID: {run.id}")
        # result = [0] * len(scoringList)
        # return result, None
        sys.exit(1)
    elif run.status == "requires_action":
        logger.info(
            f"Run requires action: {run.required_action.submit_tool_outputs.tool_calls}"
        )
        run = submit_tool_outputs(
            thread.id,
            run.id,
            run.required_action.submit_tool_outputs.tool_calls,
        )
        if run is None:
            result = [0] * len(scoringList)
            return result, None
        run = wait_for_run_completion(thread.id, run.id)
        logger.info(f"Run status: {run.status}")
with open("prompts/ph2.txt", "r") as fp: # search prompt | |
assistant_prompt = fp.read() | |
logger.info(f"assistant prompt: {assistant_prompt}") | |
with open("tools/ph2.json", "r") as f: # tool prompt | |
tools = json.load(f) | |
assistant1 = create_assistant( # create assistant | |
assistant_prompt, | |
model="gpt-3.5-turbo", | |
tools=tools, | |
) | |
run = create_run(thread.id, assistant1.id) | |
run = wait_for_run_completion(thread.id, run.id) | |
logger.info(f"status: {run.status}") | |
if run.status == "failed": | |
logger.error(f"Run_failed: {run.error}, Run ID: {run.id}") | |
sys.exit(1) | |
elif run.status == "requires_action": | |
logger.info( | |
f"Run requires action: {run.required_action.submit_tool_outputs.tool_calls}" | |
) | |
run = submit_tool_outputs( | |
thread.id, | |
run.id, | |
run.required_action.submit_tool_outputs.tool_calls, | |
) | |
if run is None: | |
result = [0] * len(scoreingList) | |
return result, None | |
run = wait_for_run_completion(thread.id, run.id) | |
logger.info(f"Run status: {run.status}") | |
    result, urls = judge(company_name, scoringList)
    url_list = ', '.join(urls)
    save(company_name, result, url_list)
    return result, urls
if __name__ == '__main__':
    result, url_list = scoring('清水鋼鐵株式会社')
    print(result)