gemma-2-9b-it-1e-cot_lora / custom_functions.py
aolans's picture
Update custom_functions.py
4a49db9 verified
"""
このスクリプトは、自然言語処理モデルを使用してユーザーからの入力に対する応答を生成するためのユーティリティ関数を提供します。
特に、Chain of Thought (CoT) 技術を活用し、モデルが論理的かつ詳細な思考過程を踏まえた応答を生成できるよう設計されています。
作成者: aolans
日付: 2024/12/17
"""
def generate_cot_one( model, tokenizer, input: str ) -> str:
"""
inputに対して、CoTプロンプトを作成の後推論することで回答精度を上げている
引数:
model : モデル
tokenizer : トークナイザ
input (str): ユーザーからの入力文字列。
戻り値:
output_str (str): 応答内容(思考の過程を含む)
"""
# CoT用systemプロンプト
prompt = f"""
### 指示:
あなたは優秀で論理的なアシスタントです。
ユーザーから指定が無い限り、日本語で記載してください。
1. まずは<Thought></Thought>タグの中であなたの思考の過程を抜けがないように記載していきます。ユーザーからの指示に対して、ステップごとに詳細を詰めていきます。ここでは最終的にユーザーに提供するべき情報がすべて記載されるべきです。思考を進めていく中で、定期的に過去の思考過程を見直し、内容や方向性の修正を行います。
2. <Output></Output>タグの中に最終的にユーザーに提供する出力を記載します。ユーザーは<Output></Output>の中だけ見ることになるので、回答として必要となる情報はすべて記載されるべきです。 タグは必ず閉じなければなりません。
### 入力:
{input}
### 思考:
<Thought>"""
# *** 推論 ***
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**input_ids, max_new_tokens=1024, do_sample=False, repetition_penalty=1.2,)
prediction = tokenizer.decode(outputs[0][input_ids.input_ids.size(1):], skip_special_tokens=True)
# *** 不要文字除去 ***
# <Thought></Thought>除去
output_str = remove_tags( prediction, "Thought" )
# <Output></Output>除去
output_str = remove_tags( output_str, "Output" )
# 「### 応答:」を変換
output_str = output_str.replace( "### 応答:", "【回答】\n" )
return output_str
def generate_cot_two( model, tokenizer, input: str ) -> str:
"""
inputに対するCoTプロンプトを適用し回答精度を上げる
「思考過程」と「回答」の2回に分けて推論する。
思考の過程と応用内容を別々に取得可能
引数:
model : モデル
tokenizer : トークナイザ
input (str): ユーザーからの入力文字列。
戻り値:
thought (str): 思考の過程
answer (str): 応答内容
"""
# ***************** 思考過程作成 **********************
prompt_1 = f"""
### 指示:
あなたは優秀で論理的なアシスタントです。
<Thought></Thought>タグの中であなたの思考の過程を抜けがないように記載していきます。
ユーザーからの指示に対して、ステップごとに詳細を詰めていきます。
ここでは最終的にユーザーに提供するべき情報がすべて記載されるべきです。
思考を進めていく中で、定期的に過去の思考過程を見直し、内容や方向性の修正を行います。
入力に対する直接の回答は避けて、思考の過程のみを記載してください。
日本語で記載しなさい。
### 入力:
{input}
### 思考:
<Thought>"""
# *** 推論(思考過程) ***
input_ids_1 = tokenizer(prompt_1, return_tensors="pt").to(model.device)
outputs_1 = model.generate(**input_ids_1, max_new_tokens=1024, do_sample=False, repetition_penalty=1.2,)
thought = tokenizer.decode(outputs_1[0][input_ids_1.input_ids.size(1):], skip_special_tokens=True)
# ***************** 応答作成 **********************
prompt_2 = f"""
### 指示:
あなたは優秀で論理的なアシスタントです。
<Thought></Thought>タグの内容を参考にして、<Output></Output>タグの中に最終的にユーザーに提供する出力を記載します。
ユーザーは<Output></Output>の中だけ見ることになるので、回答として必要となる情報はすべて記載されるべきです。
タグは必ず閉じなければなりません。
ユーザーから指定が無い限り、基本的に日本語で記載しなさい。
### 入力:
{input}
<out>
### 思考:
{thought}
### 応答:
<Output>"""
# *** 推論(応答) ***
input_ids_2 = tokenizer(prompt_2, return_tensors="pt").to(model.device)
outputs_2 = model.generate(**input_ids_2, max_new_tokens=1024, do_sample=False, repetition_penalty=1.2,)
answers = tokenizer.decode(outputs_2[0][input_ids_2.input_ids.size(1):], skip_special_tokens=True)
# *** 不要文字除去 ***
# <Thought></Thought>除去
thought = remove_tags( thought, "Thought" )
# <Output></Output>除去
answers = remove_tags( answers, "Output" )
output = tokenizer.decode(outputs_2[0], skip_special_tokens=True).split('<out>')[-1]
# <Thought></Thought>除去
output = remove_tags( output, "Thought" )
# <Output></Output>除去
output = remove_tags( output, "Output" )
# 「### 思考:」を変換
output = output.replace( "### 思考:", "【思考の過程】" )
# 「### 応答:」を変換
output = output.replace( "### 応答:", "【回答】\n" )
return output, answers, thought
def generate_simple( model, tokenizer, input: str ) -> str:
"""
単純なプロンプトで推論。 通常はこちらを選択
引数:
model : モデル
tokenizer : トークナイザ
input (str): ユーザーからの入力文字列。
戻り値:
output (str): 応答内容
"""
prompt = f"""### 指示
{input}
### 回答
"""
# *** 推論 ***
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**input_ids, max_new_tokens=1024, do_sample=False, repetition_penalty=1.2,)
prediction = tokenizer.decode(outputs[0][input_ids.input_ids.size(1):], skip_special_tokens=True)
return prediction
def remove_tags( str_in_tag: str, tag_name: str ) -> str:
"""
出力文字列から<Thought>および<Output>タグを削除する関数です。
引数:
str_in_tag (str): タグを削除したい文字列。
tag_name (str): タグの名前(例: <Output><\Output> ⇒ Output )
戻り値:
str_removed (str): タグが削除された文字列。
"""
str_removed = str_in_tag.replace(f"\n<{tag_name}>\n", "\n").replace(f"\n<{tag_name}>", "\n").replace(f"<{tag_name}>\n", "\n").replace(f"<{tag_name}>", "")
str_removed = str_removed.replace(f"\n</{tag_name}>\n", "\n").replace(f"\n</{tag_name}>", "\n").replace(f"</{tag_name}>\n", "\n").replace(f"</{tag_name}>", "\n")
return str_removed