Spaces:
Sleeping
Sleeping
import json | |
import os | |
import pandas as pd | |
import time | |
import gspread | |
from google.oauth2.service_account import Credentials | |
from src.research_html_scoring import research_html_hybrid, group1_html | |
from src.research_pdf_scoring import research_pdf_hybrid, group1_pdf | |
from utils_groupclassification.check_openai import co | |
from src.myLogger import set_logger | |
from openai import AzureOpenAI | |
from openai import OpenAIError | |
from dotenv import load_dotenv | |
# Logger used by every function in this module.
logger = set_logger("my_app", level="INFO")

# Runs before the Azure OpenAI settings below are read from the environment,
# so anything it loads is visible to those lookups.
load_dotenv()

# Azure OpenAI connection settings; each one is None when its variable is unset.
openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
openai_api_key = os.getenv("AZURE_OPENAI_API_KEY")
version = os.getenv("AZURE_OPENAI_API_VERSION")

# Passed as the `model=` argument of every chat-completion request below.
api_type = "gpt-4o"
def _init_client(auth_path):
    """Build an authorized gspread client from a service-account key file.

    Args:
        auth_path: Path to the service-account JSON key file.

    Returns:
        The client object returned by ``gspread.authorize``.

    Raises:
        OSError: If the key file cannot be opened.
        json.JSONDecodeError: If the key file does not contain valid JSON.
    """
    # The context manager closes the key file even when JSON parsing fails,
    # so the file handle is never leaked.
    with open(auth_path, "r") as key_file:
        service_account_key = json.load(key_file)

    credentials = Credentials.from_service_account_info(service_account_key)
    scoped_credentials = credentials.with_scopes(
        [
            "https://spreadsheets.google.com/feeds",
            "https://www.googleapis.com/auth/drive",
        ]
    )
    return gspread.authorize(scoped_credentials)
def _accumulate_detail(store, value):
    """Add ``value`` to a detail accumulator that is either a list or a string.

    Lists are extended in place; any other accumulator is treated as text and
    a new string with ``value`` appended is returned.
    """
    if isinstance(store, list):
        store.append(value)
        return store
    return f"{store}{value}"


def result_change(answer, final_answer):
    """Merge the result for one document into the running ``final_answer``.

    Slots 0-4 and 6-9 are compared numerically: whenever the new ``answer``
    is larger than the stored value, the stored value is raised to 1.
    Slots 5 and 10 are accumulators; each receives the value from the same
    slot of ``answer`` (slot 5 -> slot 5, slot 10 -> slot 10).

    Args:
        answer: Indexable result for a single document; must have at least
            11 entries.
        final_answer: Mutable list with at least 11 entries. It is updated
            in place. Slots 5 and 10 may hold either a list or a string.

    Returns:
        The same ``final_answer`` list, after the update.
    """
    # Slots 5 and 10 are skipped here because they are not numeric flags.
    for i in (*range(0, 5), *range(6, 10)):
        if final_answer[i] < answer[i]:
            final_answer[i] = 1

    # NOTE(review): the accumulators are gated on the slot value itself being
    # equal to 1. If slots 5 and 10 of `answer` actually carry free text, the
    # gate probably belongs on the flags in slots 4 and 9 -- confirm against
    # the functions that produce `answer`.
    if answer[5] == 1:
        final_answer[5] = _accumulate_detail(final_answer[5], answer[5])
    if answer[10] == 1:
        # Slot 10 has its own accumulator; it must not spill into slot 5.
        final_answer[10] = _accumulate_detail(final_answer[10], answer[10])
    return final_answer
def summarize(sentence, user_prompt):
    """Send ``user_prompt`` to the Azure OpenAI chat model and return its reply.

    The request is retried on any exception, waiting a fixed delay between
    attempts, until the maximum number of attempts is reached.

    Args:
        sentence: Text the prompt was built from. It is not used inside this
            function; the parameter is kept so existing callers keep working.
        user_prompt: Complete user message sent to the model.

    Returns:
        The message content of the first choice in the model response.

    Raises:
        RuntimeError: If every attempt fails. The last underlying error is
            attached as the exception's cause.
    """
    client = AzureOpenAI(
        api_key=openai_api_key,
        api_version=version,
        azure_endpoint=openai_endpoint,
    )
    messages = [
        {
            "role": "system",
            "content": (
                "You are a knowledgeable assistant with expertise in corporate information and financial markets.\n"
                "When users inquire about the listing status of a company, provide a clear Yes or No answer."
            ),
        },
        {"role": "user", "content": user_prompt},
    ]

    max_retries = 100
    delay = 5
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                messages=messages,
                model=api_type,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            # Keep a reference: the name bound by `except` is cleared when
            # the handler exits, and the final error should carry its cause.
            last_error = e
            label = "Error occurred" if isinstance(e, OpenAIError) else "Unexpected error"
            if attempt == max_retries:
                # No point sleeping when there is nothing left to retry.
                logger.error(f"{label}: {e}. No retries left.")
                break
            logger.error(f"{label}: {e}. Retrying in {delay} seconds...")
            time.sleep(delay)

    raise RuntimeError(
        "Maximum retries exceeded. Could not get a valid response."
    ) from last_error
def summarize5(sentence):
    """Summarize renewable-energy initiatives found in ``sentence``.

    The prompt asks the model for concrete renewable-energy initiatives other
    than the use of solar, wind and biomass energy, with duplicates removed,
    condensed into a few sentences.

    Args:
        sentence: Collected text describing renewable-energy initiatives.

    Returns:
        Whatever ``summarize`` returns for the assembled prompt.
    """
    prompt_lines = [
        f"{sentence}は再生可能エネルギーに関する取り組みが書かれています。",
        "太陽光発電、風力発電、バイオマスエネルギーの使用以外の再生可能エネルギーに関する具体的な取り組みについてまとめてください。",
        "重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。",
    ]
    return summarize(sentence, "\n".join(prompt_lines))
def summarize10(sentence):
    """Summarize energy-use-reduction initiatives found in ``sentence``.

    The prompt asks the model for concrete initiatives that reduce energy
    consumption, other than renewable energy, electric vehicles, and heat
    pumps for heating or hot-water supply, with duplicates removed, condensed
    into a few sentences.

    Args:
        sentence: Collected text describing energy-reduction initiatives.

    Returns:
        Whatever ``summarize`` returns for the assembled prompt.
    """
    # Two wrong-kanji typos are corrected in the second prompt line so the
    # instruction matches the first line and the intended exclusion:
    # "使用料" (usage fee) -> "使用量" (usage amount), and
    # "給油" (refuelling) -> "給湯" (hot-water supply).
    user_prompt = (
        f"{sentence}はエネルギー使用量の削減に関する取り組みが書かれています。\n"
        "再生可能エネルギー、電気自動車、暖房や給湯などでヒートポンプの利用以外のエネルギー使用量削減に関する具体的な取り組みについてまとめてください。\n"
        "重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。"
    )
    return summarize(sentence, user_prompt)
def judge(company_name, indices) -> tuple[list, list]:
    """Score every URL listed in ``./output.json`` for ``company_name``.

    For each URL, ``co`` decides whether the page belongs to the company.
    Matching PDFs are scored with ``research_pdf_hybrid`` and every other
    matching URL with ``research_html_hybrid``; each result is folded into a
    running answer list with ``result_change``. Slots 5 and 10 of that list
    are finally replaced by a summary from ``summarize5`` / ``summarize10``,
    or by a single space when the flag in slot 4 / slot 9 is still 0.

    Args:
        company_name: Company name passed to ``co`` and the research functions.
        indices: Sequence of scoring items. Only its length is used here (to
            size the answer list) before it is forwarded to the research
            functions; it needs at least 11 entries because slots 5 and 10
            are written.

    Returns:
        A ``(final_answers, related_urls)`` tuple: the merged answer list and
        the URLs that were scored successfully.
    """
    base_path = "./gspread"
    os.makedirs(base_path, exist_ok=True)

    with open("./output.json") as f:
        input_urls = list(json.load(f))
    related_urls = []
    logger.info(f"urls: {input_urls}")

    final_answers = [0] * len(indices)
    final_answers[5] = ''
    final_answers[10] = ''

    for url in input_urls:
        logger.info(f"company_name: {company_name}, url: {url}")
        try:
            # Drop stray quote/bracket characters left around the URL.
            url = url.removesuffix("'").removesuffix("']")

            # PDFs and everything else differ only in the research function.
            if url.endswith(".pdf"):
                kind, research = "pdf", research_pdf_hybrid
            else:
                kind, research = "html", research_html_hybrid
            logger.info(f"{kind}: {url}")

            # co() checks whether the content behind the URL matches the company.
            verdict, reason = co(company_name, url)
            logger.info(f"judge: {verdict}, reason: {reason}")

            if verdict == 1:
                # Match: score the document and merge the result.
                logger.info(f"research_{kind}_hybrid")
                answer = research(url, company_name, indices)
                final_answers = result_change(answer, final_answers)
                related_urls.append(url)
            elif verdict == 0:
                # No match: only record the reason.
                logger.info(f"reason: {reason}")
            elif verdict == -1:
                # The content behind the URL could not be retrieved.
                logger.debug("url先の情報が取得できません")
        except Exception as e:
            # Best effort: one failing URL must not abort the whole run.
            logger.error(f"Error: {e}")

    if final_answers[4] == 0:
        final_answers[5] = ' '
    else:
        final_answers[5] = summarize5(final_answers[5])

    if final_answers[9] == 0:
        final_answers[10] = ' '
    else:
        final_answers[10] = summarize10(final_answers[10])

    return final_answers, related_urls
def scoring_group1(company_name):
    """Collect and summarize decarbonization statements for ``company_name``.

    Every URL listed in ``./output.json`` is checked with ``co``. For URLs
    that match the company, text is extracted with ``group1_pdf`` (URLs ending
    in ``.pdf``) or ``group1_html`` (everything else) and concatenated. The
    combined text is then summarized through ``summarize``, so this request
    uses the same client setup and retry handling as the other summaries.

    Args:
        company_name: Company name passed to ``co``.

    Returns:
        A ``(result, related_urls)`` tuple: the summary text (an empty string
        when nothing was extracted) and the URLs that contributed text.

    Raises:
        RuntimeError: Propagated from ``summarize`` when the model request
            keeps failing.
    """
    base_path = "./gspread"
    os.makedirs(base_path, exist_ok=True)

    with open("./output.json") as f:
        input_urls = list(json.load(f))
    related_urls = []
    answers = ''
    logger.info(f"urls: {input_urls}")

    for url in input_urls:
        logger.info(f"company_name: {company_name}, url: {url}")
        try:
            # Drop stray quote/bracket characters left around the URL.
            url = url.removesuffix("'").removesuffix("']")

            # PDFs and everything else differ only in the extraction function.
            if url.endswith(".pdf"):
                kind, extract = "pdf", group1_pdf
            else:
                kind, extract = "html", group1_html
            logger.info(f"{kind}: {url}")

            # co() checks whether the content behind the URL matches the company.
            verdict, reason = co(company_name, url)
            logger.info(f"judge: {verdict}, reason: {reason}")

            if verdict == 1:
                # Match: extract the text and add it to the combined answer.
                logger.info(f"research_{kind}_hybrid")
                answer = extract(url)
                answers += answer
                related_urls.append(url)
                logger.info(f"anser: {answer}")
            elif verdict == 0:
                # No match: only record the reason.
                logger.info(f"reason: {reason}")
            elif verdict == -1:
                # The content behind the URL could not be retrieved.
                logger.debug("url先の情報が取得できません")
        except Exception as e:
            # Best effort: one failing URL must not abort the whole run.
            logger.error(f"Error: {e}")

    # Nothing was extracted, so there is nothing to send to the model.
    if answers == '':
        return '', related_urls

    user_prompt = (
        f"{answers}は脱炭素に関して企業が言及している内容が書かれています。\n"
        "重複を取り除き、意味が通るように企業が行っている取り組みを数センテンスでまとめてください。"
    )
    return summarize(answers, user_prompt), related_urls