tgas-theme2-ph2-demo / src /research_html_scoring.py
yyuri's picture
Upload 4 files
4848895 verified
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
import json
import requests
from langchain.text_splitter import RecursiveCharacterTextSplitter
from bs4 import BeautifulSoup
from pydantic import BaseModel
import os
import re
import pandas as pd
import chromadb
from rank_bm25 import BM25Okapi
from janome.tokenizer import Tokenizer
from openai import OpenAI
from openai import OpenAIError
from src.myLogger import set_logger
import time
logger = set_logger("my_app", level="INFO")
load_dotenv(".env.dev")
class Document(BaseModel):
page_content: str
metadata: dict = {}
# Tokenizerの初期化
t = Tokenizer()
# 文書用のTokenizerの定義
def tokenize(text):
return [token.surface for token in t.tokenize(text)]
# クエリ用のTokenizerの定義
def query_tokenize(text):
return [token.surface for token in t.tokenize(text) if token.part_of_speech.split(',')[0] in ["名詞", "動詞", "形容詞"]]
def normalize_text(s, sep_token = " \n "):
s = re.sub(r'\s+', ' ', s).strip()
s = re.sub(r". ,","",s)
s = s.replace("..",".")
s = s.replace(". .",".")
s = s.replace("\n", "")
s = s.strip()
return s
def generate_answer_(reference, system_prompt, json_schema, max_retries=100, delay=5):
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(
api_key=api_key,
)
retries = 0
while retries < max_retries:
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": reference,
},
],
functions=[{"name": "generate_queries", "parameters": json_schema}],
function_call={"name": "generate_queries"},
temperature=0.0,
top_p=0.0,
)
output = response.choices[0].message.function_call.arguments
return output # Return successfully if no exception occurs
except OpenAIError as e:
print(f"Error occurred: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
except Exception as e:
print(f"Unexpected error: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
raise RuntimeError("Maximum retries exceeded. Could not get a valid response.")
def find_context(pdf_url):
url = pdf_url
# リクエストをきちんと送るためのもの
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
}
print(f"requesting {url}")
result = requests.get(url, headers=headers, timeout=5)
print(result.status_code)
result.encoding = result.apparent_encoding
soup = BeautifulSoup(result.text, "html.parser")
paragraphs = soup.find_all("p")
full_text = ' '.join([paragraph.get_text() for paragraph in paragraphs])
# 連結したテキストを1つのDocumentオブジェクトとして扱う
documents = [Document(page_content=full_text, metadata={})]
# documents = [Document(page_content=paragraph.get_text(), metadata={}) for paragraph in paragraphs]
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=50
)
df = pd.DataFrame(columns=["source", "chunk_no", "content"])
# テキストを読み込んで分割し、DataFrameに格納
chunk = text_splitter.split_documents(documents)
filename = pdf_url
for i, c in enumerate(chunk):
# print(f"chunk {c.page_content}")
df = pd.concat([df, pd.DataFrame({"ID": str(i),"source": filename, "chunk_no": i, "content": c.page_content}, index=[0])])
# DataFrameを表示
df["content"] = df["content"].apply(lambda x: normalize_text(x))
df = df.reset_index(drop=True)
# contentとmetadataをリスト型で定義
content_list = df["content"].tolist()
metadata_list = df[["ID", "source"]].to_dict(orient="records")
# Embeddingの定義
embeddings = OpenAIEmbeddings(
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
# VectorStoreの定義
client = chromadb.PersistentClient()
db = Chroma(
collection_name="langchain_store",
embedding_function=embeddings,
client=client,
)
ids = df["ID"].tolist()
# Vectorstoreにデータを登録
db.add_texts(
texts=content_list,
metadatas=metadata_list,
embeddings=embeddings,
ids=ids,
)
# クエリと関連する文を5件検索してRAGでansを生成
n = 10
# 文書を単語リストに分割
tokenized_documents = [tokenize(doc) for doc in content_list]
# BM25
bm25 = BM25Okapi(tokenized_documents)
# まず初めに脱炭素に言及しているか調べる
query = '脱炭素・省エネ・カーボンニュートラルに関連する情報'
vector_top = db.similarity_search(query=query, k=n)
vector_rank_list = [{"content": doc.page_content, "vector_rank": i+1} for i, doc in enumerate(vector_top)]
df_vector = pd.DataFrame(vector_rank_list)
# キーワード検索
tokenized_query = query_tokenize(query)
keyword_top = bm25.get_top_n(tokenized_query, content_list, n=n)
keyword_rank_list = [{"content":doc, "keyword_rank":i+1} for i, doc in enumerate(keyword_top)]
df_keyword = pd.DataFrame(keyword_rank_list)
# 順位を結合して表示
df_rank = pd.merge(df_vector, df_keyword, on="content", how="left")
df_rank["hybrid_score"] = 1/(df_rank["vector_rank"]+60) + 1/(df_rank["keyword_rank"]+60)
df_rank = pd.merge(df_rank, df, on="content", how="left")
df_rank.sort_values(by="hybrid_score", ascending=False).head()
# 結果をCSVファイルとして保存
df_rank_sorted = df_rank.sort_values(by="hybrid_score", ascending=False)
# df_rank_sorted.to_csv("search_rank_results.csv", index=False)
search_results = df_rank_sorted["content"].tolist()
relevent_texts = ".\n".join([doc for doc in search_results])
context = relevent_texts[:10000]
return context, client
def research_html_hybrid(
pdf_url,
company_name,
question_bank,
):
result_list = []
context, client = find_context(pdf_url)
for i in range(0, 5):
question = question_bank[i]
logger.info(f'Question: {question}')
q_prompt = f"""
次の資料に、{company_name}{question}という情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1({question}であれば1,そうでなければ0)
reason: "どうしてそう判断したのか具体的に説明してください。"
}}
===example1===
[内容]
{company_name}{question}
[出力]
{{
judge: 1
reason: "{company_name}{question}と明記されているため"
}}
===example2===
[内容]
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。
[出力]
{{
judge: 0
reason: "特に{question}と明記されていないため"
}}
"""
json_schema = {
"type": "object",
"properties": {
"judge": {
"type": "integer",
"description": f"{question}であれば1,そうでなければ0"
},
"reason": {
"type": "string",
"description": "どうしてそう判断したのか"
}
},
"required": ["judge", "reason"]
}
if i == 4:
q_prompt = f"""
次の資料に、{company_name}が太陽光発電、風力発電、バイオマスエネルギー以外の再生可能エネルギーを使用している
という情報はありますか、具体的な引用を含めて答えてください。明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1(上記以外の再生可能エネルギーを使用していれば1,そうでなければ0)
reason: "記事に書かれている文章・具体的な取り組みを引用してください。"
}}"""
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
result_list.append(js["judge"])
if i == 4:
result_list.append(js["reason"])
for i in range(6, 10):
question = question_bank[i]
logger.info(f'Question: {question}')
q_prompt = f"""
次の資料に、{company_name}{question}という情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1({question}であれば1,そうでなければ0)
reason: "どうしてそう判断したのか具体的に説明してください。"
}}
===example1===
[内容]
{company_name}{question}
[出力]
{{
judge: 1
reason: "{company_name}{question}と明記されているため"
}}
===example2===
[内容]
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。
[出力]
{{
judge: 0
reason: "特に{question}と明記されていないため"
}}
"""
json_schema = {
"type": "object",
"properties": {
"judge": {
"type": "integer",
"description": f"{question}であれば1,そうでなければ0"
},
"reason": {
"type": "string",
"description": "どうしてそう判断したのか"
}
},
"required": ["judge", "reason"]
}
if i == 9:
q_prompt = f"""
次の資料に、{company_name}が再生可能エネルギー、電気自動車、暖房や給油などでヒートポンプを利用している以外の方法で
エネルギー使用量の削減に取り組んでいるという情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1(上記以外の方法でエネルギー使用量の削減を行っていれば1,そうでなければ0)
reason: "記事に書かれている文章・具体的な取り組みを引用してください。"
}}
"""
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
result_list.append(js["judge"])
if i == 9:
result_list.append(js["reason"])
try:
client.delete_collection("langchain_store")
except Exception as e:
logger.error(f"An error occurred during collection deletion: {e}")
return result_list
def group1_html(pdf_url):
context, client = find_context(pdf_url)
q_prompt = """次の資料の脱炭素に関連する情報をまとめてください"""
json_schema = {
"type": "object",
"properties": {
"answer": {
"type": "string",
"description": "与えられた資料の脱炭素に関連する情報をまとめてください"
},
},
"required": ["answer"]
}
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
try:
client.delete_collection("langchain_store")
except Exception as e:
logger.error(f"An error occurred during collection deletion: {e}")
return js["answer"]