# app.py — by Mya-Mya (commit 8066662)
# NOTE(review): the lines above/around here were Hugging Face Spaces page
# chrome ("raw / history blame / 1.89 kB") captured by the scrape; they are
# preserved as comments so the file parses as Python.
import pickle

import numpy as np
import torch
from pandas import DataFrame
from transformers import BertJapaneseTokenizer, BertModel

import frontend
from backend import Backend
# Load the precomputed gadget database from disk. The dict is expected to hold
# (per usage below) "feature_matrix_s" (per-model feature matrices),
# "name_s", and "description_s".
# NOTE(review): pickle.load executes arbitrary code if the file is untrusted —
# fine for a bundled asset, but do not point this at user-supplied files.
with open("./himitsudogu_db.pkl", "rb") as file:
    himitsudogu_db: dict = pickle.load(file)
class HFBackend(Backend):
    """Backend that ranks himitsudogu (secret-gadget) entries by semantic
    similarity between the user's query and each gadget's description,
    using a Japanese Sentence-BERT model.
    """

    # Single source of truth for the model id (was repeated three times).
    _MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"

    def __init__(self):
        super().__init__()
        # Precomputed description feature matrix for this model; rows align
        # with himitsudogu_db["name_s"] / ["description_s"] (see usage below).
        self.feature_matrix = himitsudogu_db["feature_matrix_s"][
            self._MODEL_NAME
        ]
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(self._MODEL_NAME)
        self.model = BertModel.from_pretrained(self._MODEL_NAME)

    def on_submit_button_press(self, query: str) -> DataFrame:
        """Return the 20 gadgets most similar to *query*.

        The result DataFrame has columns 類似度 (similarity), 名前 (name),
        説明 (description), indexed by rank starting at 1.
        """
        # Tokenize the query into PyTorch model inputs.
        tokenized = self.tokenizer(query, return_tensors="pt")
        # Inference only: no_grad skips building the autograd graph,
        # saving time and memory.
        with torch.no_grad():
            model_output = self.model(**tokenized)
        # pooler_output is the sentence-level feature vector of the query.
        query_feature_vector = model_output["pooler_output"][0].detach().numpy()
        # Dot product against every stored description vector.
        cs_s = self.feature_matrix @ query_feature_vector
        # Indices sorted by similarity, highest first; keep the top 20.
        top = np.argsort(cs_s)[::-1][:20]
        # Build the table in one constructor call instead of row-by-row
        # .loc assignment, which reallocates the frame on every insert.
        return DataFrame(
            {
                "類似度": [cs_s[i] for i in top],
                "名前": [himitsudogu_db["name_s"][i] for i in top],
                "説明": [himitsudogu_db["description_s"][i] for i in top],
            },
            index=range(1, len(top) + 1),
        )
frontend.launch_frontend(backend=HFBackend())