# tgas-theme2-ph2-demo / src/group4_scoring.py
# Page header residue: author manatoboys, "demo2", revision 103de27
import json
import os
import sys
import csv
from dotenv import load_dotenv
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils_groupclassification.agent import submit_tool_outputs
from utils_groupclassification.assistants_utils import (
create_assistant,
create_message,
create_run,
create_thread,
wait_for_run_completion,
)
from src.scoring_utils import judge
from src.myLogger import set_logger
# Module-wide logger (from src.myLogger.set_logger) used by save() and scoring().
logger = set_logger("my_app", level="INFO")
# Load variables from a local .env file into os.environ (python-dotenv).
# NOTE(review): presumably this supplies the OpenAI credentials used by the
# assistant helpers - confirm. It runs after set_logger, so any .env values
# are not visible to the logger setup.
load_dotenv()
# Chat model used for the phase-1 assistant in scoring().
model = "gpt-4o"

# Group 4 checklist labels (Japanese), passed to judge() by scoring().
# Positions matter: save() totals positions 0-4 and 6-9 and skips 5 and 10,
# the two "...の例" ("example of ...") entries.
scoreingList = [
    # --- Renewable energy (positions 0-5) ---
    "再生可能エネルギーを使用している",  # uses renewable energy
    "太陽光発電を行っている",  # generates solar power
    "風力発電を行っている",  # generates wind power
    "バイオマスエネルギーの使用をしている",  # uses biomass energy
    "その他の再生可能エネルギーを使用している",  # uses other renewable energy
    "その他の再生可能エネルギーの例",  # example of other renewable energy
    # --- Energy-use reduction (positions 6-10) ---
    "エネルギー使用量の削減を行っている",  # reduces energy consumption
    "電気自動車の利用",  # uses electric vehicles
    # uses heat pumps for heating / 給油 (possibly meant 給湯, hot water) etc.
    "暖房や給油などでヒートポンプを利用している",
    "その他の方法でエネルギー使用量の削減を行っている",  # reduces energy use by other means
    "その他のエネルギー使用量の削減の例",  # example of other energy-reduction measures
]
def save(company_name, result, urls, path='data/Group4_result.csv'):
    """Append one company's Group 4 scores as a single row of a CSV file.

    Row layout: company_name, 'Group 4', every entry of ``result`` in order,
    the numeric total, then ``urls``.

    Args:
        company_name: Value written to the first column.
        result: Per-item values aligned with ``scoreingList``. Must have at
            least 10 entries; an IndexError is raised otherwise.
        urls: Value written to the last column (scoring() passes a
            comma-joined string).
        path: CSV file to append to. Its parent directory is created if it
            does not exist yet.
    """
    # Positions 5 and 10 are the "...の例" ("example of ...") checklist
    # entries, so they are excluded from the total. Indexing (not slicing)
    # keeps the original IndexError for a too-short result.
    total = sum(result[i] for i in range(10) if i != 5)

    # Create the output directory up front so a missing data/ folder does not
    # crash the run after all the scoring work has been done.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    row = [company_name, 'Group 4', *result, total, urls]
    # Explicit UTF-8 so Japanese company names are written the same way on
    # every platform (the default encoding is locale-dependent).
    # newline='' is what the csv module requires for files it writes.
    with open(path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(row)
    logger.info(f'Result: {company_name}, {urls} saved')
def _load_phase_config(prompt_path, tools_path):
    """Read an assistant prompt (plain text) and its tool definitions (JSON)."""
    with open(prompt_path, "r", encoding="utf-8") as fp:  # search prompt
        assistant_prompt = fp.read()
    logger.info(f"assistant prompt: {assistant_prompt}")
    with open(tools_path, "r", encoding="utf-8") as f:  # tool definitions
        tools = json.load(f)
    return assistant_prompt, tools


def _run_assistant(thread_id, assistant_id):
    """Start a run on the thread and drive it until no tool calls remain.

    Returns:
        The last run object returned by wait_for_run_completion, or None when
        submit_tool_outputs returned None.

    Exits the process with status 1 if the run's first wait ends in "failed".
    """
    run = create_run(thread_id, assistant_id)
    run = wait_for_run_completion(thread_id, run.id)
    logger.info(f"status: {run.status}")
    if run.status == "failed":
        # The Run object exposes the failure as `last_error`; `run.error`
        # would raise AttributeError and hide the real failure.
        logger.error(f"Run_failed: {run.last_error}, Run ID: {run.id}")
        sys.exit(1)
    # Keep answering tool calls while the run asks for them. Without the loop,
    # a second tool-call round would leave this run active while the next
    # phase starts a new run on the same thread.
    while run.status == "requires_action":
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        logger.info(f"Run requires action: {tool_calls}")
        run = submit_tool_outputs(thread_id, run.id, tool_calls)
        if run is None:
            return None
        run = wait_for_run_completion(thread_id, run.id)
        logger.info(f"Run status: {run.status}")
    return run


def scoring(company_name):
    """Score one company on the Group 4 checklist and append it to the CSV.

    Two assistant phases run on a single thread:
      1. prompts/ph1.txt + tools/ph1.json with the module-level ``model``;
         the company name is posted as the user message.
      2. prompts/ph2.txt + tools/ph2.json with "gpt-3.5-turbo", continuing on
         the same thread.
    Then judge() produces the per-item results and URLs, which are written
    via save().

    Args:
        company_name: Company to score; sent verbatim as the user message.

    Returns:
        ``(result, urls)`` as returned by judge(). If tool-output submission
        returns None in either phase, returns
        ``([0] * len(scoreingList), None)`` without saving anything.

    Exits the process with status 1 (sys.exit) when a phase's run fails
    before any tool call is answered.
    """
    logger.info("create thread")
    thread = create_thread()
    logger.info("created thread")

    # Phase 1: the search assistant receives the company name.
    assistant_prompt, tools = _load_phase_config("prompts/ph1.txt", "tools/ph1.json")
    logger.info("create assistant")
    # NOTE(review): a new assistant is created on every call and never deleted
    # here; confirm whether cleanup happens elsewhere.
    assistant = create_assistant(assistant_prompt, model=model, tools=tools)
    logger.info(f"User:{company_name}")
    create_message(thread.id, "user", company_name)
    if _run_assistant(thread.id, assistant.id) is None:
        return [0] * len(scoreingList), None

    # Phase 2: second assistant on the same thread. It deliberately uses a
    # different (hard-coded) model from phase 1.
    assistant_prompt, tools = _load_phase_config("prompts/ph2.txt", "tools/ph2.json")
    assistant1 = create_assistant(assistant_prompt, model="gpt-3.5-turbo", tools=tools)
    if _run_assistant(thread.id, assistant1.id) is None:
        return [0] * len(scoreingList), None

    result, urls = judge(company_name, scoreingList)
    # Guard against judge() returning no URLs (e.g. None) before joining.
    url_list = ', '.join(urls) if urls else ''
    save(company_name, result, url_list)
    return result, urls
if __name__ == "__main__":
    # Manual run: score one sample company and print its per-item results.
    scores, _urls = scoring('清水鋼鐵株式会社')
    print(scores)