import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Load the model in fp16 and let accelerate place it across available devices.
model = AutoModelForCausalLM.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("RWKV-Red-Team/ARWKV-7B-Preview-0.1")
# With device_map="auto", ask the model where its first shard lives instead of
# hard-coding "cuda"; inputs must be moved there before generation.
device = model.device


def convert_history_to_messages(history):
    """Convert Gradio's [user, bot] pair history into chat-template messages."""
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def stream_chat(prompt, history):
    """Run generation on a background thread and yield the growing response."""
    messages = convert_history_to_messages(history)
    messages.append({"role": "user", "content": prompt})
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # TextIteratorStreamer turns generate() into an iterator of decoded chunks.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.5,
        top_p=0.2,  # tight nucleus; the high temperature only reshuffles within it
        top_k=0,    # disable top-k so top-p alone bounds the candidate set
    )

    # generate() blocks, so run it on a worker thread and consume the streamer here.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield history + [(prompt, response)]
    thread.join()


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with LLM", height=750)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        # Clear the textbox and append the new turn with a pending bot reply.
        return "", history + [[user_message, None]]

    def bot(history):
        # Stream the reply for the last user turn, rebuilding the final row
        # from the rest of the history on every yielded chunk.
        prompt = history[-1][0]
        history[-1][1] = ""
        for updated_history in stream_chat(prompt, history[:-1]):
            yield updated_history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch(server_name="0.0.0.0")
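
# Usage note (the file name below is an assumption, not from the original):
#   python app.py
# Gradio binds to 0.0.0.0 and serves on its default port (7860) unless
# server_port or the GRADIO_SERVER_PORT environment variable overrides it.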