Spaces:
Sleeping
Sleeping
from logging import getLogger | |
import requests | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import Chroma | |
from langchain_openai import OpenAIEmbeddings | |
from openai import OpenAIError | |
import os | |
import re | |
import json | |
import pandas as pd | |
import chromadb | |
from rank_bm25 import BM25Okapi | |
from janome.tokenizer import Tokenizer | |
from collections import OrderedDict | |
from openai import OpenAI | |
import time | |
logger = getLogger(__name__) | |
# Tokenizerの初期化 | |
t = Tokenizer() | |
# 文書用のTokenizerの定義 | |
def tokenize(text): | |
return [token.surface for token in t.tokenize(text)] | |
# クエリ用のTokenizerの定義 | |
def query_tokenize(text): | |
return [token.surface for token in t.tokenize(text) if token.part_of_speech.split(',')[0] in ["名詞", "動詞", "形容詞"]] | |
def normalize_text(s, sep_token = " \n "): | |
s = re.sub(r'\s+', ' ', s).strip() | |
s = re.sub(r". ,","",s) | |
s = s.replace("..",".") | |
s = s.replace(". .",".") | |
s = s.replace("\n", "") | |
s = s.strip() | |
return s | |
def generate_answer_(reference, system_prompt, json_schema, max_retries=100, delay=5): | |
api_key = os.getenv("OPENAI_API_KEY") | |
client = OpenAI( | |
api_key=api_key, | |
) | |
retries = 0 | |
while retries < max_retries: | |
try: | |
response = client.chat.completions.create( | |
model="gpt-3.5-turbo", | |
messages=[ | |
{ | |
"role": "system", | |
"content": system_prompt, | |
}, | |
{ | |
"role": "user", | |
"content": reference, | |
}, | |
], | |
functions=[{"name": "generate_queries", "parameters": json_schema}], | |
function_call={"name": "generate_queries"}, | |
temperature=0.0, | |
top_p=0.0, | |
) | |
print("completion end") | |
output = response.choices[0].message.function_call.arguments | |
time.sleep(1) | |
return output | |
except OpenAIError as e: | |
print(f"Error occurred: {e}. Retrying in {delay} seconds...") | |
retries += 1 | |
time.sleep(delay) | |
except Exception as e: | |
print(f"Unexpected error: {e}. Retrying in {delay} seconds...") | |
retries += 1 | |
time.sleep(delay) | |
raise RuntimeError("Maximum retries exceeded. Could not get a valid response.") | |
def find_context(pdf_url): | |
url = pdf_url | |
# リクエストをきちんと送るためのもの | |
headers = { | |
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" | |
} | |
result = requests.get(url, headers=headers) | |
if result.status_code == "403": | |
return "NO" | |
if result.status_code == "200": | |
result.encoding = result.apparent_encoding | |
if result.encoding == "": | |
result.encoding = "utf-8" | |
finename = ( | |
# "out/" + | |
# url.replace("/", "_").replace(":", "+") + "IR.pdf" | |
"./out/tmp.pdf" | |
) | |
try: | |
print("Downloading...") | |
if os.path.exists(finename): | |
print("Already Downloaded") | |
else: | |
with open(finename, "wb") as file: | |
for chunk in result.iter_content(1000000000): | |
file.write(chunk) | |
print("Downloaded") | |
except Exception as e: | |
print(e) | |
loader = PyPDFLoader(finename) | |
documents = loader.load() | |
text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=2000, chunk_overlap=50, separators=["\n", "\n\n"] | |
) | |
df = pd.DataFrame(columns=["source", "chunk_no", "content"]) | |
# テキストを読み込んで分割し、DataFrameに格納 | |
chunk = text_splitter.split_documents(documents) | |
filename = pdf_url | |
for i, c in enumerate(chunk): | |
# print(f"chunk {c.page_content}") | |
df = pd.concat([df, pd.DataFrame({"ID": str(i) ,"source": filename, "chunk_no": i, "content": c.page_content}, index=[0])]) | |
# DataFrameを表示 | |
df["content"] = df["content"].apply(lambda x: normalize_text(x)) | |
df = df.reset_index(drop=True) | |
# contentとmetadataをリスト型で定義 | |
content_list = df["content"].tolist() | |
metadata_list = df[["ID", "source"]].to_dict(orient="records") | |
# Embeddingの定義 | |
print("Embedding") | |
embeddings = OpenAIEmbeddings( | |
api_key=os.getenv("OPENAI_API_KEY"), | |
) | |
print("Embedding done") | |
# VectorStoreの定義 | |
print("VectorStore") | |
client = chromadb.PersistentClient() | |
db = Chroma( | |
collection_name="langchain_store", | |
embedding_function=embeddings, | |
client=client, | |
) | |
print("VectorStore done") | |
ids = df["ID"].tolist() | |
# Vectorstoreにデータを登録 | |
print("add_texts") | |
db.add_texts( | |
texts=content_list, | |
metadatas=metadata_list, | |
embeddings=embeddings, | |
ids=ids, | |
) | |
print("add_texts done") | |
# クエリと関連する文を5件検索してRAGでansを生成 | |
n = 10 | |
# 文書を単語リストに分割 | |
tokenized_documents = [tokenize(doc) for doc in content_list] | |
print("tokenized_documents done") | |
# BM25 | |
bm25 = BM25Okapi(tokenized_documents) | |
print("BM25 done") | |
# まず初めに脱炭素に言及しているか調べる | |
query = '脱炭素に関連する情報をまとめてください。' | |
vector_top = db.similarity_search(query=query, k=n) | |
vector_rank_list = [{"content": doc.page_content, "vector_rank": i+1} for i, doc in enumerate(vector_top)] | |
df_vector = pd.DataFrame(vector_rank_list) | |
# キーワード検索 | |
tokenized_query = query_tokenize(query) | |
keyword_top = bm25.get_top_n(tokenized_query, content_list, n=n) | |
keyword_rank_list = [{"content":doc, "keyword_rank":i+1} for i, doc in enumerate(keyword_top)] | |
df_keyword = pd.DataFrame(keyword_rank_list) | |
# 順位を結合して表示 | |
df_rank = pd.merge(df_vector, df_keyword, on="content", how="left") | |
df_rank["hybrid_score"] = 1/(df_rank["vector_rank"]+60) + 1/(df_rank["keyword_rank"]+60) | |
df_rank = pd.merge(df_rank, df, on="content", how="left") | |
df_rank.sort_values(by="hybrid_score", ascending=False).head() | |
# 結果をCSVファイルとして保存 | |
df_rank_sorted = df_rank.sort_values(by="hybrid_score", ascending=False) | |
# df_rank_sorted.to_csv("search_rank_results.csv", index=False) | |
search_results = df_rank_sorted["content"].tolist() | |
relevent_texts = ".\n".join([doc for doc in search_results]) | |
context = relevent_texts[:10000] | |
return context, client | |
def research_pdf_hybrid( | |
pdf_url, | |
company_name, | |
question_bank | |
): | |
result_list = [] | |
context, client = find_context(pdf_url) | |
for i in range(0, 5): | |
question = question_bank[i] | |
logger.info(f'Question: {question}') | |
q_prompt = f""" | |
次の資料に、{company_name}が{question}という情報はありますか、具体的な引用を含めて答えてください。 | |
また、明記されている場合のみ1と判断してください。 | |
[出力フォーマット] | |
{{ | |
judge: 0 or 1({question}であれば1,そうでなければ0) | |
reason: "どうしてそう判断したのか具体的に説明してください。" | |
}} | |
===example1=== | |
[内容] | |
{company_name}は{question} | |
[出力] | |
{{ | |
judge: 1 | |
reason: "{company_name}は{question}と明記されているため" | |
}} | |
===example2=== | |
[内容] | |
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。 | |
[出力] | |
{{ | |
judge: 0 | |
reason: "特に{question}と明記されていないため" | |
}} | |
""" | |
json_schema = { | |
"type": "object", | |
"properties": { | |
"judge": { | |
"type": "integer", | |
"description": f"{question}であれば1,そうでなければ0" | |
}, | |
"reason": { | |
"type": "string", | |
"description": "どうしてそう判断したのか" | |
} | |
}, | |
"required": ["judge", "reason"] | |
} | |
if i == 4: | |
q_prompt = f""" | |
次の資料に、{company_name}が太陽光発電、風力発電、バイオマスエネルギー以外の再生可能エネルギーを使用している | |
という情報はありますか、具体的な引用を含めて答えてください。明記されている場合のみ1と判断してください。 | |
[出力フォーマット] | |
{{ | |
judge: 0 or 1(上記以外の再生可能エネルギーを使用していれば1,そうでなければ0) | |
reason: "記事に書かれている文章・具体的な取り組みを引用してください。" | |
}}""" | |
ret = generate_answer_(context, q_prompt, json_schema) | |
js = json.loads(ret) | |
logger.info(f"judge output:{js}") | |
result_list.append(js["judge"]) | |
if i == 4: | |
result_list.append(js["reason"]) | |
# その他の具体的な内容の部分を除くため | |
for i in range(6, 10): | |
question = question_bank[i] | |
logger.info(f'Question: {question}') | |
q_prompt = f""" | |
次の資料に、{company_name}が{question}という情報はありますか、具体的な引用を含めて答えてください。 | |
また、明記されている場合のみ1と判断してください。 | |
[出力フォーマット] | |
{{ | |
judge: 0 or 1({question}であれば1,そうでなければ0) | |
reason: "どうしてそう判断したのか具体的に説明してください。" | |
}} | |
===example1=== | |
[内容] | |
{company_name}は{question} | |
[出力] | |
{{ | |
judge: 1 | |
reason: "{company_name}は{question}と明記されているため" | |
}} | |
===example2=== | |
[内容] | |
{company_name}は脱炭素・低炭素に向けたさまざまな取り組みを行っています。 | |
[出力] | |
{{ | |
judge: 0 | |
reason: "特に{question}と明記されていないため" | |
}} | |
""" | |
if i == 9: | |
q_prompt = f""" | |
次の資料に、{company_name}が再生可能エネルギー、電気自動車、暖房や給油などでヒートポンプを利用している以外の方法で | |
エネルギー使用量の削減に取り組んでいるという情報はありますか、具体的な引用を含めて答えてください。 | |
また、明記されている場合のみ1と判断してください。 | |
[出力フォーマット] | |
{{ | |
judge: 0 or 1(上記以外の方法でエネルギー使用量の削減を行っていれば1,そうでなければ0) | |
reason: "記事に書かれている文章・具体的な取り組みを引用してください。" | |
}} | |
""" | |
json_schema = { | |
"type": "object", | |
"properties": { | |
"judge": { | |
"type": "integer", | |
"description": f"{question}であれば1,そうでなければ0" | |
}, | |
"reason": { | |
"type": "string", | |
"description": "どうしてそう判断したのか" | |
} | |
}, | |
"required": ["judge", "reason"] | |
} | |
ret = generate_answer_(context, q_prompt, json_schema) | |
js = json.loads(ret) | |
logger.info(f"judge output:{js}") | |
result_list.append(js["judge"]) | |
if i == 9: | |
result_list.append(js["reason"]) | |
try: | |
client.delete_collection("langchain_store") | |
print("Collection deleted successfully.") | |
except Exception as e: | |
print(f"An error occurred during collection deletion: {e}") | |
return result_list | |
def group1_pdf(pdf_url): | |
context, client = find_context(pdf_url) | |
q_prompt = """次の資料の脱炭素に関連する情報をまとめてください""" | |
json_schema = { | |
"type": "object", | |
"properties": { | |
"answer": { | |
"type": "string", | |
"description": "与えられた資料の脱炭素に関連する情報をまとめてください" | |
}, | |
}, | |
"required": ["answer"] | |
} | |
ret = generate_answer_(context, q_prompt, json_schema) | |
js = json.loads(ret) | |
logger.info(f"judge output:{js}") | |
try: | |
client.delete_collection("langchain_store") | |
except Exception as e: | |
logger.error(f"An error occurred during collection deletion: {e}") | |
return js["answer"] | |