import os
from threading import Thread

import gradio as gr
import hf_transfer  # noqa: F401 -- only needed so faster Hub downloads work when HF_HUB_ENABLE_HF_TRANSFER is set
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

HF_TOKEN = os.getenv("hftoken")

model = AutoModelForCausalLM.from_pretrained("kubernetes-bad/chargen-v2", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained("kubernetes-bad/chargen-v2", token=HF_TOKEN)

"""
For more information on `huggingface_hub` Inference API support, please check the docs:
https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Rebuild the full conversation in chat-message format.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # Tokenize via the tokenizer's chat template (assumes the model ships one)
    # and move the inputs to the same device as the model.
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(next(model.parameters()).device)

    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)

    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.0,
        "eos_token_id": tokenizer.eos_token_id,
    }

    # Run generation in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response = ""
    for new_token in streamer:
        # Cut the stream off if the model starts a new user turn.
        if new_token and "<|user|>" in new_token:
            new_token = new_token.split("<|user|>")[0]
        if new_token:
            response += new_token
            yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""

# Force the Gradio page into dark mode on load.
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

app = gr.ChatInterface(
    respond,
    js=js_func,
    additional_inputs=[
        gr.Textbox(
            value="You are a bot who generates perfect roleplaying characters.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    app.launch()