tgas-theme2-ph2-demo / src /scoring_utils.py
yyuri's picture
Upload 4 files
4848895 verified
import json
import os
import pandas as pd
import time
import gspread
from google.oauth2.service_account import Credentials
from src.research_html_scoring import research_html_hybrid, group1_html
from src.research_pdf_scoring import research_pdf_hybrid, group1_pdf
from utils_groupclassification.check_openai import co
from src.myLogger import set_logger
from openai import AzureOpenAI
from openai import OpenAIError
from dotenv import load_dotenv
logger = set_logger("my_app", level="INFO")
load_dotenv()
openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
version = os.environ.get("AZURE_OPENAI_API_VERSION")
api_type = 'gpt-4o'
def _init_client(auth_path):
json_open = open(auth_path, "r")
service_account_key = json.load(json_open)
credentials = Credentials.from_service_account_info(service_account_key)
scoped_credentials = credentials.with_scopes(
[
"https://spreadsheets.google.com/feeds",
"https://www.googleapis.com/auth/drive",
]
)
Client = gspread.authorize(scoped_credentials)
return Client
def result_change(answer, final_answer):
for i in range(0, 5):
# if final_answer = 0, answer = 1, then change final_answer = 1
if final_answer[i] < answer[i]:
final_answer[i] = 1
for i in range(6, 10):
# if final_answer = 0, answer = 1, then change final_answer = 1
if final_answer[i] < answer[i]:
final_answer[i] = 1
if answer[5] == 1:
final_answer[5].append(answer[5])
if answer[10] == 1:
final_answer[5].append(answer[10])
return final_answer
def summarize(sentence, user_prompt):
client = AzureOpenAI(
api_key=openai_api_key,
api_version=version,
azure_endpoint=openai_endpoint
)
messages = [
{"role": "system",
"content": """You are a knowledgeable assistant with expertise in corporate information and financial markets.
When users inquire about the listing status of a company, provide a clear Yes or No answer."""},
{"role": "user",
"content": user_prompt}
]
retries = 0
max_retries = 100
delay = 5
while retries < max_retries:
try:
response = client.chat.completions.create(
messages=messages,
model=api_type,
temperature=0,
)
return response.choices[0].message.content
except OpenAIError as e:
print(f"Error occurred: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
except Exception as e:
print(f"Unexpected error: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
raise RuntimeError("Maximum retries exceeded. Could not get a valid response.")
def summarize5(sentence):
user_prompt = f"""{sentence}は再生可能エネルギーに関する取り組みが書かれています。
太陽光発電、風力発電、バイオマスエネルギーの使用以外の再生可能エネルギーに関する具体的な取り組みについてまとめてください。
重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。"""
return summarize(sentence, user_prompt)
def summarize10(sentence):
user_prompt = f"""{sentence}はエネルギー使用量の削減に関する取り組みが書かれています。
再生可能エネルギー、電気自動車、暖房や給油などでヒートポンプの利用以外のエネルギー使用料削減に関する具体的な取り組みについてまとめてください。
重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。"""
return summarize(sentence, user_prompt)
def judge(company_name, indices) -> list[list]:
base_path = "./gspread"
os.makedirs(base_path, exist_ok=True)
with open("./output.json") as f:
config = json.load(f)
input_urls = []
related_urls = []
for value in config:
input_urls.append(value)
logger.info(f"urls: {input_urls}")
final_answers = [0] * len(indices)
final_answers[5] = ''
final_answers[10] = ''
for url in input_urls:
logger.info(f"company_name: {company_name}, url: {url}")
try:
# urlの周りに余計な文字列がある場合に削除
if url.endswith("'"):
url = url[:-1]
if url.endswith("']"):
url = url[:-2]
# urlの最後がpdfかhtmlかで処理を分ける
if url.endswith(".pdf"):
logger.info(f"pdf: {url}")
# co関数でurl先の情報が会社名と一致するか判定
judge, reason = co(company_name, url)
logger.info(f"judge: {judge}, reason: {reason}")
# 一致する場合はresearch_pdf_hybrid関数を実行
if judge == 1:
logger.info("research_pdf_hybrid")
answer = research_pdf_hybrid(url, company_name, indices)
final_answers = result_change(answer, final_answers)
related_urls.append(url)
# 一致しない場合はreasonを返す
elif judge == 0:
logger.info(f"reason: {reason}")
# answer = reason
# 取得できない場合はurl先の情報が取得できない旨を返す
elif judge == -1:
logger.debug("url先の情報が取得できません")
# answer = "url先の情報が取得できません"
else:
logger.info(f"html: {url}")
# co関数でurl先の情報が会社名と一致するか判定
judge, reason = co(company_name, url)
logger.info(f"judge: {judge}, reason: {reason}")
# 一致する場合はresearch_html_hybrid関数を実行
if judge == 1:
logger.info("research_html_hybrid")
answer = research_html_hybrid(url, company_name, indices)
final_answers = result_change(answer, final_answers)
related_urls.append(url)
# 一致しない場合はreasonを返す
elif judge == 0:
logger.info(f"reason: {reason}")
# 取得できない場合はurl先の情報が取得できない旨を返す
elif judge == -1:
logger.debug("url先の情報が取得できません")
except Exception as e:
logger.error(f"Error: {e}")
if final_answers[4] == 0:
final_answers[5] = ' '
else:
final_answers[5] = summarize5(final_answers[5])
if final_answers[9] == 0:
final_answers[10] = ' '
else:
final_answers[10] = summarize10(final_answers[10])
return final_answers, related_urls
def scoring_group1(company_name):
base_path = "./gspread"
os.makedirs(base_path, exist_ok=True)
with open("./output.json") as f:
config = json.load(f)
input_urls = []
related_urls = []
answers = ''
for value in config:
input_urls.append(value)
logger.info(f"urls: {input_urls}")
for url in input_urls:
logger.info(f"company_name: {company_name}, url: {url}")
try:
# urlの周りに余計な文字列がある場合に削除
if url.endswith("'"):
url = url[:-1]
if url.endswith("']"):
url = url[:-2]
# urlの最後がpdfかhtmlかで処理を分ける
if url.endswith(".pdf"):
logger.info(f"pdf: {url}")
# co関数でurl先の情報が会社名と一致するか判定
judge, reason = co(company_name, url)
logger.info(f"judge: {judge}, reason: {reason}")
# 一致する場合はresearch_pdf_hybrid関数を実行
if judge == 1:
logger.info("research_pdf_hybrid")
answer = group1_pdf(url)
answers += answer
related_urls.append(url)
logger.info(f"anser: {answer}")
# 一致しない場合はreasonを返す
elif judge == 0:
logger.info(f"reason: {reason}")
# 取得できない場合はurl先の情報が取得できない旨を返す
elif judge == -1:
logger.debug("url先の情報が取得できません")
# answer = "url先の情報が取得できません"
else:
logger.info(f"html: {url}")
# co関数でurl先の情報が会社名と一致するか判定
judge, reason = co(company_name, url)
logger.info(f"judge: {judge}, reason: {reason}")
# 一致する場合はresearch_html_hybrid関数を実行
if judge == 1:
logger.info("research_html_hybrid")
answer = group1_html(url)
answers += answer
related_urls.append(url)
logger.info(f"anser: {answer}")
# 一致しない場合はreasonを返す
elif judge == 0:
logger.info(f"reason: {reason}")
# 取得できない場合はurl先の情報が取得できない旨を返す
elif judge == -1:
logger.debug("url先の情報が取得できません")
except Exception as e:
logger.error(f"Error: {e}")
if answers != '':
client = AzureOpenAI(
api_key=openai_api_key,
api_version=version,
azure_endpoint=openai_endpoint
)
messages = [
{"role": "system",
"content": """You are a knowledgeable assistant with expertise in corporate information and financial markets.
When users inquire about the listing status of a company, provide a clear Yes or No answer."""},
{"role": "user",
"content": f"""{answers}は脱炭素に関して企業が言及している内容が書かれています。
重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。"""}
]
response = client.chat.completions.create(
messages=messages,
model=api_type,
temperature=0,
)
result = response.choices[0].message.content
else:
result = ''
return result, related_urls