Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| from typing import Dict | |
| from typing import List | |
| from typing import Tuple | |
| from typing import Union | |
| from pathlib import Path | |
| from src.logger import LoggerFactory | |
| from src.prompt_concat import GetManualTestSamples, CreateTestDataset | |
| from src.utils import decode_csv_to_json, load_json, save_to_json | |
| from threading import Thread | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| GenerationConfig, | |
| TextIteratorStreamer, | |
| ) | |
| from typing import List | |
| import gradio as gr | |
| import logging | |
| import os | |
| import shutil | |
| import torch | |
| import warnings | |
| import random | |
| import spaces | |
# Module-wide logger for this demo.
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)

# Silence the noisy "TypedStorage is deprecated" UserWarning.
warnings.filterwarnings("ignore", message="TypedStorage is deprecated", category=UserWarning)

# Both locations can be overridden through environment variables; the tokenizer
# falls back to the model location when TOKENIZER_PATH is unset.
MODEL_PATH = os.environ.get("MODEL_PATH", "IndexTeam/Index-1.9B-Character")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Directory holding role corpora (.csv), role data (.json) and derived files.
character_path = "./character"
| def _resolve_path(path: Union[str, Path]) -> Path: | |
| return Path(path).expanduser().resolve() | |
| # logger = LoggerFactory.create_logger(name="test", level=logging.INFO) | |
| # warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') | |
| # config_data = load_json("config/config.json") | |
| # model_path = config_data["huggingface_local_path"] | |
| # character_path = "./character" | |
| # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) | |
| # model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto", | |
| # trust_remote_code=True) | |
def generate_with_question(question, role_name, role_file_path):
    """Run retrieval for the current dialogue of a role.

    Args:
        question: chat history as a sequence of turn pairs (each pair is a
            sequence of strings); every turn is joined with newlines into one
            retrieval query.
        role_name: display name of the role.
        role_file_path: file stem of the role; the role data is read from
            ``./character/<role_file_path>.json``.

    NOTE(review): the retrieved samples are presumably written by
    GetManualTestSamples to ``./character/<role_file_path>_rag.json`` (the file
    create_datasets reads back) — confirm against src.prompt_concat.
    """
    turns = ("\n".join(pair) for pair in question)
    query_text = "\n".join(turns)
    sampler = GetManualTestSamples(
        role_name=role_name,
        role_data_path=f"./character/{role_file_path}.json",
        save_samples_dir="./character",
        save_samples_path=role_file_path + "_rag.json",
        prompt_path="./prompt/dataset_character.txt",
        max_seq_len=4000,
    )
    sampler.get_qa_samples_by_query(
        questions_query=query_text,
        keep_retrieve_results_flag=True,
    )
def create_datasets(role_name, role_file_path):
    """Build the prompt test set for a role and save it as JSON.

    Loads samples via CreateTestDataset from ``./character/<role_file_path>_rag.json``
    and writes them to ``./character/<role_file_path>_测试问题.json``, the file
    hf_gen reads its prompt from.
    """
    rag_file = os.path.join("./character", role_file_path + "_rag.json")
    # NOTE(review): the RAG file is passed as both role_samples_path and
    # role_data_path; confirm this is intended for CreateTestDataset.
    builder = CreateTestDataset(
        role_name=role_name,
        role_samples_path=rag_file,
        role_data_path=rag_file,
        prompt_path="./prompt/dataset_character.txt",
    )
    samples = list(builder.load_samples())
    save_to_json(samples, f"./character/{role_file_path}_测试问题.json")
def hf_gen(dialog: List, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Build the role-play prompt for *dialog* and stream the model's reply.

    Args:
        dialog: chat history as ``[[q_1, a_1], ..., [q_n, a_n]]``.
        role_name: display name of the role.
        role_file_path: file stem of the role's files under ``character_path``.
        top_k, top_p, temperature, repetition_penalty: sampling parameters
            forwarded to ``model.generate``.
        max_dec_len: maximum number of new tokens to generate.

    Yields:
        str: the reply generated so far (cumulative, without the prompt).

    Raises:
        Exception: any error raised by ``model.generate`` in the worker thread
            is re-raised here once the stream has been closed.

    NOTE(review): ``spaces`` is imported at module level but nothing is
    decorated with ``@spaces.GPU``; if this Space runs on ZeroGPU hardware the
    GPU-using entry point probably needs that decorator — confirm.
    """
    # Refresh the retrieval samples and the prompt file for this dialogue.
    generate_with_question(dialog, role_name, role_file_path)
    create_datasets(role_name, role_file_path)
    json_data = load_json(f"{character_path}/{role_file_path}_测试问题.json")[0]
    text = json_data["input_text"]

    inputs = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available():
        model.to("cuda")
    # Put the inputs on the device the model's weights actually live on.
    inputs = inputs.to(model.device)

    # skip_prompt=True makes the streamer drop the echoed prompt tokens itself.
    # The old approach sliced the decoded text by len(text), which breaks
    # whenever the decoded prompt differs from the raw prompt string
    # (special tokens, whitespace clean-up). skip_special_tokens keeps EOS and
    # similar markers out of the chat.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        do_sample=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        repetition_penalty=float(repetition_penalty),
        max_new_tokens=int(max_dec_len),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    errors = []

    def _generate():
        # Runs in the worker thread. On failure, close the stream so the
        # consumer loop below terminates instead of waiting forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    answer = ""
    for new_text in streamer:
        answer += new_text
        yield answer
    thread.join()
    if errors:
        raise errors[0]
def generate(chat_history: List, query, role_name, role_desc, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Generate a reply after the "submit" button is pressed.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]; all QA records so far.
        query (str): the user's query for the current round.
        role_name (str): role to answer as; "三三" selects the built-in role.
        role_desc (str): role description from the UI (currently unused here).
        role_file_path (str | None): file stem of an uploaded role, or None if none was created.
        top_k (int): number of highest-probability tokens kept for sampling.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): penalty applied to already-generated tokens.
        max_dec_len (int): the maximum number of tokens to generate.

    Yields:
        tuple: (update clearing the input box, chat_history + the QA pair of the current round).

    Raises:
        gr.Error: if the query is empty, or no role file is available for a non-built-in role.
    """
    # gr.Error instead of assert: asserts are stripped under `python -O`, and
    # Gradio shows a gr.Error message to the user instead of a generic failure.
    if not query or not query.strip():
        raise gr.Error("Input must not be empty!!!")
    # The built-in role "三三" uses the fixed file stem "三三".
    if role_name == "三三":
        role_file_path = "三三"
    if not role_file_path:
        raise gr.Error("Please create a role first (upload a .csv corpus and click 生成角色!), or enter 三三 as the role name.")
    # apply chat template
    chat_history.append([f"user:{query}", ""])
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history
def regenerate(chat_history: List, role_name, role_description, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Re-generate the answer to the last round's query.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]; all QA records so far.
        role_name (str): role to answer as; "三三" selects the built-in role.
        role_description (str): role description from the UI (currently unused here).
        role_file_path (str | None): file stem of an uploaded role, or None if none was created.
        top_k (int): number of highest-probability tokens kept for sampling.
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): penalty applied to already-generated tokens.
        max_dec_len (int): the maximum number of tokens to generate.

    Yields:
        tuple: (update clearing the input box, chat_history with the last answer replaced).

    Raises:
        gr.Error: if the history is empty, or no role file is available for a non-built-in role.
    """
    # gr.Error instead of assert: asserts are stripped under `python -O`.
    if not chat_history:
        raise gr.Error("History is empty. Nothing to regenerate!!")
    # Built-in role "三三" uses the fixed file stem "三三".
    if role_name == "三三":
        role_file_path = "三三"
    if not role_file_path:
        raise gr.Error("Please create a role first (upload a .csv corpus and click 生成角色!), or enter 三三 as the role name.")
    # Blank out the previous answer so it is not part of the retrieval query.
    if len(chat_history[-1]) > 1:
        chat_history[-1][1] = ""
    # apply chat template
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history
def clear_history():
    """Reset the chat.

    Also asks PyTorch to release its unused cached CUDA memory.

    Returns:
        List: a fresh, empty chat history.
    """
    torch.cuda.empty_cache()
    return list()
def delete_current_user(user_role_path):
    """Delete every file generated for an uploaded role.

    Removes ``<stem>.csv``, ``<stem>.json``, ``<stem>_rag.json`` and
    ``<stem>_测试问题.json`` under ``character_path``.

    Behavior:
        - Each file is handled independently: a missing file is skipped.
        - Any other OS error is logged, and the remaining files are still
          attempted. Previously a single try around the loop aborted on the
          first missing file.
        - Does nothing when *user_role_path* is empty or None.
    """
    if not user_role_path:
        return
    for suffix in (".csv", ".json", "_rag.json", "_测试问题.json"):
        file_path = os.path.join(character_path, user_role_path + suffix)
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            logger.warning("Failed to delete %s: %s", file_path, e)
# Launch the Gradio demo: sampling controls and role upload on the left, chat on the right.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""# Index-1.9B RolePlay Gradio Demo""")
    with gr.Row():
        with gr.Column(scale=1):
            top_k = gr.Slider(0, 10, value=5, step=1, label="top_k")
            # step was 0.8, which only allowed the values 0 and 0.8.
            top_p = gr.Slider(0, 1, value=0.8, step=0.05, label="top_p")
            temperature = gr.Slider(0.1, 2.0, value=0.85, step=0.1, label="temperature")
            repetition_penalty = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="repetition_penalty")
            max_dec_len = gr.Slider(1, 4096, value=512, step=1, label="max_dec_len")
            file_input = gr.File(label="上传角色对话语料(.csv)")
            role_description = gr.Textbox(label="您创建的角色描述", placeholder="输入角色描述", lines=2)
            upload_button = gr.Button("生成角色!")
            # File stem of the role created by the last upload (None until then).
            new_path = gr.State()

            def generate_file(file_obj, role_info):
                """Copy an uploaded role corpus into character_path and convert it to role JSON.

                Args:
                    file_obj: the value of the gr.File component — either a file
                        path or a tempfile-like object with ``.name``, depending
                        on the Gradio version.
                    role_info: the role description typed by the user.

                Returns:
                    str: the unique file stem of the created role.

                Raises:
                    gr.Error: if no file was uploaded.
                """
                import secrets

                if file_obj is None:
                    raise gr.Error("Please upload a role corpus (.csv) first.")
                src_path = getattr(file_obj, "name", file_obj)
                role_name = os.path.basename(src_path).split(".")[0]
                # Random hex suffix keeps repeated uploads of the same file apart.
                # It avoids shell/glob/URL-special characters in file names and
                # does not reseed the global `random` generator.
                role_stem = role_name + secrets.token_hex(5)
                csv_path = os.path.join(character_path, role_stem + ".csv")
                shutil.copy(src_path, csv_path)
                json_path = os.path.join(character_path, role_stem) + ".json"
                decode_csv_to_json(csv_path, role_name, role_info, json_path)
                gr.Info(f"{role_name}生成成功")
                return role_stem

            upload_button.click(generate_file, inputs=[file_input, role_description], outputs=new_path)
        with gr.Column(scale=10):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400, label='Index-1.9B')
            with gr.Row():
                role_name = gr.Textbox(label="对话的角色名字", value="三三", placeholder="如果您没有创建角色,可以直接输入三三。如果已经创建好了对应的角色,请在这里输入角色的名称!", lines=2)
                user_input = gr.Textbox(label="用户问题", placeholder="输入你的问题!", lines=2)
            with gr.Row():
                submit = gr.Button("🚀 Submit")
                clear = gr.Button("🧹 Clear")
                regen = gr.Button("🔄 Regenerate")
            submit.click(generate,
                         inputs=[chatbot, user_input, role_name, role_description, new_path, top_k, top_p,
                                 temperature, repetition_penalty, max_dec_len],
                         outputs=[user_input, chatbot])
            regen.click(regenerate,
                        inputs=[chatbot, role_name, role_description, new_path, top_k, top_p, temperature,
                                repetition_penalty, max_dec_len],
                        outputs=[user_input, chatbot])
            clear.click(clear_history, inputs=[], outputs=[chatbot])
demo.queue().launch()