Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- app.py +240 -0
- custom.html +12 -0
- requirements.txt +2 -0
- test_prompt.jinja2 +22 -0
- utils/dl_utils.py +19 -0
app.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from jinja2 import Template
|
| 3 |
+
from llama_cpp import Llama
|
| 4 |
+
import os
|
| 5 |
+
import configparser
|
| 6 |
+
from utils.dl_utils import dl_guff_model
|
| 7 |
+
|
| 8 |
+
# Ensure the model directory exists (exist_ok avoids the check-then-create race).
os.makedirs("models", exist_ok=True)

# File name of the GGUF model to use.
model_filename = "Llama-3.1-70B-EZO-1.1-it-Q4_K_M.gguf"
model_path = os.path.join("models", model_filename)

# Download the model file if it is not present locally.
if not os.path.exists(model_path):
    dl_guff_model(
        "models",
        f"https://huggingface.co/mmnga/Llama-3.1-70B-EZO-1.1-it-gguf/resolve/main/{model_filename}",
    )
| 20 |
+
# Persist character settings to an INI file.
def save_settings_to_ini(settings, filename='character_settings.ini'):
    """Write *settings* to *filename* under a single [Settings] section.

    List-valued fields are stored joined with newlines so that
    load_settings_from_ini() can split them back. Missing keys are written
    as empty values instead of raising KeyError (the app's built-in default
    settings omit 'dirty_talk_list', which used to crash here on first run).
    """
    config = configparser.ConfigParser()
    config['Settings'] = {
        'name': settings.get('name', ''),
        'gender': settings.get('gender', ''),
        'situation': '\n'.join(settings.get('situation', [])),
        'orders': '\n'.join(settings.get('orders', [])),
        'dirty_talk_list': '\n'.join(settings.get('dirty_talk_list', [])),
        'example_quotes': '\n'.join(settings.get('example_quotes', []))
    }
    with open(filename, 'w', encoding='utf-8') as configfile:
        config.write(configfile)
| 33 |
+
|
| 34 |
+
# Load character settings previously written by save_settings_to_ini().
def load_settings_from_ini(filename='character_settings.ini'):
    """Read settings from *filename*.

    Returns the settings dict, or None when the file, the [Settings]
    section, or any expected key is missing.
    """
    if not os.path.exists(filename):
        return None

    parser = configparser.ConfigParser()
    parser.read(filename, encoding='utf-8')

    if 'Settings' not in parser:
        return None

    section = parser['Settings']
    try:
        loaded = {
            'name': section['name'],
            'gender': section['gender'],
        }
        # Multi-line fields were stored newline-joined; split them back to lists.
        for key in ('situation', 'orders', 'dirty_talk_list', 'example_quotes'):
            loaded[key] = section[key].split('\n')
        return loaded
    except KeyError:
        return None
|
| 57 |
+
|
| 58 |
+
# Thin wrapper around llama-cpp-python's Llama for text generation.
class LlamaCppAdapter:
    def __init__(self, model_path, n_ctx=4096):
        """Load the GGUF model at *model_path*.

        n_gpu_layers=-1 requests that all layers be offloaded to the GPU.
        """
        print(f"モデルの初期化: {model_path}")
        self.llama = Llama(model_path=model_path, n_ctx=n_ctx, n_gpu_layers=-1)

    def generate(self, prompt, max_new_tokens=4096, temperature=0.5, top_p=0.7, top_k=80, stop=None):
        """Generate a completion for *prompt* and return the raw llama.cpp result dict.

        *stop* defaults to ["<END>"]; the original signature used a mutable
        list default, which is the classic shared-default pitfall — a fresh
        list is now created on each call instead.
        """
        if stop is None:
            stop = ["<END>"]
        return self._generate(prompt, max_new_tokens, temperature, top_p, top_k, stop)

    def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, stop: list):
        # repeat_penalty is fixed at the value the app was tuned with.
        return self.llama(
            prompt,
            temperature=temperature,
            max_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            stop=stop,
            repeat_penalty=1.2,
        )
|
| 77 |
+
|
| 78 |
+
# Character maker: owns the model adapter, the chat history and the persona settings.
class CharacterMaker:
    def __init__(self):
        """Load the model and the persisted settings, falling back to defaults."""
        self.llama = LlamaCppAdapter(model_path)
        self.history = []
        self.settings = load_settings_from_ini()
        if not self.settings:
            self.settings = {
                "name": "ナツ",
                "gender": "女性",
                "situation": [
                    "あなたは人工知能アシスタントです。",
                    "ユーザーの日常生活をサポートし、より良い生活を送るお手伝いをします。",
                    "AIアシスタント『ナツ』として、ユーザーの健康と幸福をケアし、様々な質問に答えたり課題解決を手伝ったりします。"
                ],
                "orders": [
                    "丁寧な言葉遣いを心がけてください。",
                    "ユーザーとの対話を通じてサポートを提供します。",
                    "ユーザーのことは『ユーザー様』と呼んでください。"
                ],
                # BUG FIX: this key was missing from the defaults, so
                # save_settings_to_ini() below and _generate_aki() both
                # raised KeyError on a fresh install (no INI file yet).
                "dirty_talk_list": [],
                # NOTE(review): written to the defaults but never read by
                # make_prompt()/save_settings_to_ini() — kept for compatibility.
                "conversation_topics": [
                    "健康管理",
                    "目標設定",
                    "時間管理"
                ],
                "example_quotes": [
                    "ユーザー様の健康と幸福が何より大切です。どのようなサポートが必要でしょうか?",
                    "私はユーザー様の生活をより良いものにするためのアシスタントです。お手伝いできることがありましたらお申し付けください。",
                    "目標達成に向けて一緒に頑張りましょう。具体的な計画を立てるお手伝いをさせていただきます。",
                    "効率的な時間管理のコツをお教えします。まずは1日のスケジュールを確認してみましょう。",
                    "ストレス解消法についてアドバイスいたします。リラックスするための簡単な呼吸法から始めてみませんか?"
                ]
            }
            save_settings_to_ini(self.settings)

    def make(self, input_str: str):
        """Generate one assistant reply to *input_str* and append it to history."""
        prompt = self._generate_aki(input_str)
        print(prompt)
        print("-----------------")
        res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"])
        res_text = res["choices"][0]["text"]
        self.history.append({"user": input_str, "assistant": res_text})
        return res_text

    def make_prompt(self, name: str, gender: str, situation: list, orders: list, dirty_talk_list: list, example_quotes: list, input_str: str):
        """Render the jinja2 prompt template with persona settings and history."""
        # read() replaces the previous readlines()+join round-trip.
        with open('test_prompt.jinja2', 'r', encoding='utf-8') as f:
            template_text = f.read()
        # Each example quote is terminated with <END>, the generation stop token.
        fix_example_quotes = [quote + "<END>" for quote in example_quotes]
        return Template(template_text).render(
            name=name,
            gender=gender,
            situation=situation,
            orders=orders,
            dirty_talk_list=dirty_talk_list,
            example_quotes=fix_example_quotes,
            histories=self.history,
            input_str=input_str,
        )

    def _generate_aki(self, input_str: str):
        """Build the full prompt for *input_str* from the current settings."""
        prompt = self.make_prompt(
            self.settings["name"],
            self.settings["gender"],
            self.settings["situation"],
            self.settings["orders"],
            # .get keeps older INI files without this key working.
            self.settings.get("dirty_talk_list", []),
            self.settings["example_quotes"],
            input_str
        )
        print(prompt)
        return prompt

    def update_settings(self, new_settings):
        """Merge *new_settings* into the current settings and persist them."""
        self.settings.update(new_settings)
        save_settings_to_ini(self.settings)

    def reset(self):
        """Drop the chat history and re-create the model adapter."""
        self.history = []
        self.llama = LlamaCppAdapter(model_path)
|
| 150 |
+
|
| 151 |
+
# Module-level singleton shared by the Gradio callback functions below.
character_maker = CharacterMaker()
|
| 152 |
+
|
| 153 |
+
# Apply the settings form fields to the shared character and persist them.
def update_settings(name, gender, situation, orders, dirty_talk_list, example_quotes):
    """Build a settings dict from the UI textboxes and hand it to the character.

    Multi-line fields are split into lists, dropping blank lines.
    """
    def _lines(text):
        # One list entry per non-empty, stripped line.
        return [line.strip() for line in text.split('\n') if line.strip()]

    new_settings = {
        "name": name,
        "gender": gender,
        "situation": _lines(situation),
        "orders": _lines(orders),
        "dirty_talk_list": _lines(dirty_talk_list),
        "example_quotes": _lines(example_quotes),
    }
    character_maker.update_settings(new_settings)
    return "設定が更新されました。"
|
| 165 |
+
|
| 166 |
+
# Chat callback: produce one assistant reply for the Gradio ChatInterface.
def chat_with_character(message, history):
    """Sync Gradio's (user, assistant) history pairs into the character, then reply."""
    rebuilt = []
    for user_msg, assistant_msg in history:
        rebuilt.append({"user": user_msg, "assistant": assistant_msg})
    character_maker.history = rebuilt
    return character_maker.make(message)
|
| 171 |
+
|
| 172 |
+
# Clear-chat callback: wipe history, reload the model, return an empty chat log.
def clear_chat():
    """Reset the shared character and return the empty list Gradio expects."""
    character_maker.reset()
    return []
|
| 176 |
+
|
| 177 |
+
# Custom CSS: pin the chatbot pane to 60% of the viewport height.
custom_css = """
#chatbot {
    height: 60vh !important;
    overflow-y: auto;
}
"""

# Custom JavaScript (intended to be embedded in HTML).
# NOTE(review): custom_js is defined here but never passed to Gradio or
# otherwise injected anywhere in this file — confirm whether it is dead code.
custom_js = """
<script>
function adjustChatbotHeight() {
    var chatbot = document.querySelector('#chatbot');
    if (chatbot) {
        chatbot.style.height = window.innerHeight * 0.6 + 'px';
    }
}

// ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
window.addEventListener('load', adjustChatbotHeight);
window.addEventListener('resize', adjustChatbotHeight);
</script>
"""
|
| 200 |
+
|
| 201 |
+
# Build the Gradio UI: a chat tab and a settings tab sharing one chatbot pane.
with gr.Blocks(css=custom_css) as iface:
    chatbot = gr.Chatbot(elem_id="chatbot")

    with gr.Tab("チャット"):
        # NOTE(review): retry_btn/undo_btn/clear_btn were removed from
        # gr.ChatInterface in Gradio 4.x — confirm the installed gradio
        # version still accepts these keyword arguments.
        gr.ChatInterface(
            chat_with_character,
            chatbot=chatbot,
            textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
            theme="soft",
            retry_btn="もう一度生成",
            undo_btn="前のメッセージを取り消す",
            clear_btn="チャットをクリア",
        )

    with gr.Tab("設定"):
        # Settings form: each multi-line textbox maps to one list-valued field.
        gr.Markdown("## キャラクター設定")
        name_input = gr.Textbox(label="名前", value=character_maker.settings["name"])
        gender_input = gr.Textbox(label="性別", value=character_maker.settings["gender"])
        situation_input = gr.Textbox(label="状況設定", value="\n".join(character_maker.settings["situation"]), lines=5)
        orders_input = gr.Textbox(label="指示", value="\n".join(character_maker.settings["orders"]), lines=5)
        dirty_talk_input = gr.Textbox(label="淫語リスト", value="\n".join(character_maker.settings["dirty_talk_list"]), lines=5)
        example_quotes_input = gr.Textbox(label="例文", value="\n".join(character_maker.settings["example_quotes"]), lines=5)

        update_button = gr.Button("設定を更新")
        update_output = gr.Textbox(label="更新状態")

        # Wire the button to the module-level update_settings() callback.
        update_button.click(
            update_settings,
            inputs=[name_input, gender_input, situation_input, orders_input, dirty_talk_input, example_quotes_input],
            outputs=[update_output]
        )

# Launch the Gradio app.
if __name__ == "__main__":
    iface.launch(
        share=True,
        allowed_paths=["models"],
        # NOTE(review): custom.html contains a <script> snippet, not an image —
        # passing it as favicon_path looks like a misuse; confirm the intent
        # (it will not inject the script into the page).
        favicon_path="custom.html"
    )
|
custom.html
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<script>
  // Size the #chatbot element to 60% of the viewport height.
  const adjustChatbotHeight = () => {
    const chatbot = document.querySelector('#chatbot');
    if (chatbot) {
      chatbot.style.height = window.innerHeight * 0.6 + 'px';
    }
  };

  // Re-apply on initial load and whenever the window is resized.
  window.addEventListener('load', adjustChatbotHeight);
  window.addEventListener('resize', adjustChatbotHeight);
</script>
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
# Prebuilt llama-cpp-python wheel: CUDA 12.4, CPython 3.10, linux x86_64 ONLY.
# NOTE(review): installation fails on any other Python version or platform —
# confirm the Space's runtime matches cp310/linux/cu124.
https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.81-cu124/llama_cpp_python-0.2.81-cp310-cp310-linux_x86_64.whl
|
test_prompt.jinja2
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
・キャラクター設定
名前:{{ name }}
性別:{{ gender }}

{% for s in situation %}
{{ s }}{% endfor %}

・今回のユーザーのオーダー
{% for o in orders %}
{{ o }}{% endfor %}

・使ってほしい淫語表現
{% for phrase in dirty_talk_list %}
{{ phrase }}{% endfor %}
・キャラクターの発言例
{% for quote in example_quotes %}
{{ quote }}{% endfor %}

{% for turn in histories %}user: {{ turn.user }}
{{ name }}: {{ turn.assistant }}{% endfor %}
user: {{ input_str }}
{{ name }}:
|
utils/dl_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import requests
from tqdm import tqdm


def dl_guff_model(model_dir, url):
    """Download a GGUF model file from *url* into *model_dir* unless present.

    The response is streamed to disk in 1 MiB chunks so multi-gigabyte
    models never have to fit in memory (the previous implementation held
    the whole file in RAM via response.content), and the already-imported
    tqdm now actually renders a progress bar.
    """
    file_name = url.split('/')[-1]
    file_path = os.path.join(model_dir, file_name)

    if os.path.exists(file_path):
        print(f'{file_name} already exists.')
        return

    response = requests.get(url, stream=True, allow_redirects=True)
    if response.status_code != 200:
        print(f'Failed to download {file_name}')
        return

    # content-length may be absent; tqdm then shows an unbounded bar.
    total = int(response.headers.get('content-length', 0))
    with open(file_path, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True, desc=file_name) as bar:
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)
            bar.update(len(chunk))
    print(f'Downloaded {file_name}')
|