tgas-theme2-ph2-demo / src /research_pdf_scoring.py
yyuri's picture
Upload 4 files
4848895 verified
from logging import getLogger
import requests
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from openai import OpenAIError
import os
import re
import json
import pandas as pd
import chromadb
from rank_bm25 import BM25Okapi
from janome.tokenizer import Tokenizer
from collections import OrderedDict
from openai import OpenAI
import time
logger = getLogger(__name__)
# Tokenizerの初期化
t = Tokenizer()
# 文書用のTokenizerの定義
def tokenize(text):
return [token.surface for token in t.tokenize(text)]
# クエリ用のTokenizerの定義
def query_tokenize(text):
return [token.surface for token in t.tokenize(text) if token.part_of_speech.split(',')[0] in ["名詞", "動詞", "形容詞"]]
def normalize_text(s, sep_token = " \n "):
s = re.sub(r'\s+', ' ', s).strip()
s = re.sub(r". ,","",s)
s = s.replace("..",".")
s = s.replace(". .",".")
s = s.replace("\n", "")
s = s.strip()
return s
def generate_answer_(reference, system_prompt, json_schema, max_retries=100, delay=5):
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(
api_key=api_key,
)
retries = 0
while retries < max_retries:
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": reference,
},
],
functions=[{"name": "generate_queries", "parameters": json_schema}],
function_call={"name": "generate_queries"},
temperature=0.0,
top_p=0.0,
)
print("completion end")
output = response.choices[0].message.function_call.arguments
time.sleep(1)
return output
except OpenAIError as e:
print(f"Error occurred: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
except Exception as e:
print(f"Unexpected error: {e}. Retrying in {delay} seconds...")
retries += 1
time.sleep(delay)
raise RuntimeError("Maximum retries exceeded. Could not get a valid response.")
def find_context(pdf_url):
url = pdf_url
# リクエストをきちんと送るためのもの
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
}
result = requests.get(url, headers=headers)
if result.status_code == "403":
return "NO"
if result.status_code == "200":
result.encoding = result.apparent_encoding
if result.encoding == "":
result.encoding = "utf-8"
finename = (
# "out/" +
# url.replace("/", "_").replace(":", "+") + "IR.pdf"
"./out/tmp.pdf"
)
try:
print("Downloading...")
if os.path.exists(finename):
print("Already Downloaded")
else:
with open(finename, "wb") as file:
for chunk in result.iter_content(1000000000):
file.write(chunk)
print("Downloaded")
except Exception as e:
print(e)
loader = PyPDFLoader(finename)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=50, separators=["\n", "\n\n"]
)
df = pd.DataFrame(columns=["source", "chunk_no", "content"])
# テキストを読み込んで分割し、DataFrameに格納
chunk = text_splitter.split_documents(documents)
filename = pdf_url
for i, c in enumerate(chunk):
# print(f"chunk {c.page_content}")
df = pd.concat([df, pd.DataFrame({"ID": str(i) ,"source": filename, "chunk_no": i, "content": c.page_content}, index=[0])])
# DataFrameを表示
df["content"] = df["content"].apply(lambda x: normalize_text(x))
df = df.reset_index(drop=True)
# contentとmetadataをリスト型で定義
content_list = df["content"].tolist()
metadata_list = df[["ID", "source"]].to_dict(orient="records")
# Embeddingの定義
print("Embedding")
embeddings = OpenAIEmbeddings(
api_key=os.getenv("OPENAI_API_KEY"),
)
print("Embedding done")
# VectorStoreの定義
print("VectorStore")
client = chromadb.PersistentClient()
db = Chroma(
collection_name="langchain_store",
embedding_function=embeddings,
client=client,
)
print("VectorStore done")
ids = df["ID"].tolist()
# Vectorstoreにデータを登録
print("add_texts")
db.add_texts(
texts=content_list,
metadatas=metadata_list,
embeddings=embeddings,
ids=ids,
)
print("add_texts done")
# クエリと関連する文を5件検索してRAGでansを生成
n = 10
# 文書を単語リストに分割
tokenized_documents = [tokenize(doc) for doc in content_list]
print("tokenized_documents done")
# BM25
bm25 = BM25Okapi(tokenized_documents)
print("BM25 done")
# まず初めに脱炭素に言及しているか調べる
query = '脱炭素に関連する情報をまとめてください。'
vector_top = db.similarity_search(query=query, k=n)
vector_rank_list = [{"content": doc.page_content, "vector_rank": i+1} for i, doc in enumerate(vector_top)]
df_vector = pd.DataFrame(vector_rank_list)
# キーワード検索
tokenized_query = query_tokenize(query)
keyword_top = bm25.get_top_n(tokenized_query, content_list, n=n)
keyword_rank_list = [{"content":doc, "keyword_rank":i+1} for i, doc in enumerate(keyword_top)]
df_keyword = pd.DataFrame(keyword_rank_list)
# 順位を結合して表示
df_rank = pd.merge(df_vector, df_keyword, on="content", how="left")
df_rank["hybrid_score"] = 1/(df_rank["vector_rank"]+60) + 1/(df_rank["keyword_rank"]+60)
df_rank = pd.merge(df_rank, df, on="content", how="left")
df_rank.sort_values(by="hybrid_score", ascending=False).head()
# 結果をCSVファイルとして保存
df_rank_sorted = df_rank.sort_values(by="hybrid_score", ascending=False)
# df_rank_sorted.to_csv("search_rank_results.csv", index=False)
search_results = df_rank_sorted["content"].tolist()
relevent_texts = ".\n".join([doc for doc in search_results])
context = relevent_texts[:10000]
return context, client
def research_pdf_hybrid(
pdf_url,
company_name,
question_bank
):
result_list = []
context, client = find_context(pdf_url)
for i in range(0, 5):
question = question_bank[i]
logger.info(f'Question: {question}')
q_prompt = f"""
次の資料に、{company_name}{question}という情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1({question}であれば1,そうでなければ0)
reason: "どうしてそう判断したのか具体的に説明してください。"
}}
===example1===
[内容]
{company_name}{question}
[出力]
{{
judge: 1
reason: "{company_name}{question}と明記されているため"
}}
===example2===
[内容]
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。
[出力]
{{
judge: 0
reason: "特に{question}と明記されていないため"
}}
"""
json_schema = {
"type": "object",
"properties": {
"judge": {
"type": "integer",
"description": f"{question}であれば1,そうでなければ0"
},
"reason": {
"type": "string",
"description": "どうしてそう判断したのか"
}
},
"required": ["judge", "reason"]
}
if i == 4:
q_prompt = f"""
次の資料に、{company_name}が太陽光発電、風力発電、バイオマスエネルギー以外の再生可能エネルギーを使用している
という情報はありますか、具体的な引用を含めて答えてください。明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1(上記以外の再生可能エネルギーを使用していれば1,そうでなければ0)
reason: "記事に書かれている文章・具体的な取り組みを引用してください。"
}}"""
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
result_list.append(js["judge"])
if i == 4:
result_list.append(js["reason"])
# その他の具体的な内容の部分を除くため
for i in range(6, 10):
question = question_bank[i]
logger.info(f'Question: {question}')
q_prompt = f"""
次の資料に、{company_name}{question}という情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1({question}であれば1,そうでなければ0)
reason: "どうしてそう判断したのか具体的に説明してください。"
}}
===example1===
[内容]
{company_name}{question}
[出力]
{{
judge: 1
reason: "{company_name}{question}と明記されているため"
}}
===example2===
[内容]
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。
[出力]
{{
judge: 0
reason: "特に{question}と明記されていないため"
}}
"""
if i == 9:
q_prompt = f"""
次の資料に、{company_name}が再生可能エネルギー、電気自動車、暖房や給油などでヒートポンプを利用している以外の方法で
エネルギー使用量の削減に取り組んでいるという情報はありますか、具体的な引用を含めて答えてください。
また、明記されている場合のみ1と判断してください。
[出力フォーマット]
{{
judge: 0 or 1(上記以外の方法でエネルギー使用量の削減を行っていれば1,そうでなければ0)
reason: "記事に書かれている文章・具体的な取り組みを引用してください。"
}}
"""
json_schema = {
"type": "object",
"properties": {
"judge": {
"type": "integer",
"description": f"{question}であれば1,そうでなければ0"
},
"reason": {
"type": "string",
"description": "どうしてそう判断したのか"
}
},
"required": ["judge", "reason"]
}
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
result_list.append(js["judge"])
if i == 9:
result_list.append(js["reason"])
try:
client.delete_collection("langchain_store")
print("Collection deleted successfully.")
except Exception as e:
print(f"An error occurred during collection deletion: {e}")
return result_list
def group1_pdf(pdf_url):
context, client = find_context(pdf_url)
q_prompt = """次の資料の脱炭素に関連する情報をまとめてください"""
json_schema = {
"type": "object",
"properties": {
"answer": {
"type": "string",
"description": "与えられた資料の脱炭素に関連する情報をまとめてください"
},
},
"required": ["answer"]
}
ret = generate_answer_(context, q_prompt, json_schema)
js = json.loads(ret)
logger.info(f"judge output:{js}")
try:
client.delete_collection("langchain_store")
except Exception as e:
logger.error(f"An error occurred during collection deletion: {e}")
return js["answer"]