| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers import TextIteratorStreamer | |
| import threading | |
# Hugging Face Hub id of the checkpoint; used for both the model and its tokenizer.
MODEL_ID = "RWKV-Red-Team/ARWKV-7B-Preview-0.1"

# Load the model once at import time. device_map="auto" lets transformers
# choose where the weights live (GPU(s) when available).
# NOTE(review): float16 weights are intended for GPU; confirm behaviour if the
# model ends up on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,  # presumably the checkpoint ships custom model code
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Send inputs to wherever device_map actually placed the model instead of
# assuming "cuda", which breaks when the model was placed on another device.
device = model.device
def convert_history_to_messages(history):
    """Flatten chat turns into a list of role/content message dicts.

    Args:
        history: Iterable of ``(user_msg, bot_msg)`` pairs. ``bot_msg`` may be
            ``None`` for a turn that has not been answered yet.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, in turn order: every
        user message, each followed by its assistant reply when that reply is
        not ``None``.
    """
    return [
        {"role": role, "content": content}
        for user_text, bot_text in history
        for role, content in (("user", user_text), ("assistant", bot_text))
        # User messages are always kept; assistant replies only when present.
        if role == "user" or content is not None
    ]
def stream_chat(
    prompt,
    history,
    *,
    max_new_tokens=4096,
    temperature=1.5,
    top_p=0.2,
    top_k=0,
):
    """Stream the model's reply to ``prompt`` given the prior conversation.

    Generation runs on a background thread that feeds a
    ``TextIteratorStreamer``. This generator yields the updated chat history
    after every decoded text chunk, so a UI can render the reply incrementally.

    Args:
        prompt: The new user message.
        history: Prior turns as ``(user_msg, bot_msg)`` pairs; ``bot_msg`` may
            be ``None`` (see ``convert_history_to_messages``).
        max_new_tokens, temperature, top_p, top_k: Sampling settings forwarded
            unchanged to ``model.generate`` (sampling is always enabled). The
            defaults are the values this function always used.

    Yields:
        ``history + [(prompt, partial_response)]`` after each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here after the partial output has been streamed.
    """
    messages = convert_history_to_messages(history)
    messages.append({"role": "user", "content": prompt})
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # skip_prompt: only the newly generated text is streamed back.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )

    failures = []

    def _generate():
        # The streamer's iterator blocks until it receives an end-of-stream
        # signal. If generate() raised without one, the loop below would hang
        # forever, so end the stream explicitly and hand the error back.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised on the consumer thread below
            failures.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    response = ""
    for chunk in streamer:
        response += chunk
        yield history + [(prompt, response)]

    thread.join()
    if failures:
        raise failures[0]
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with LLM", height=750)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        """Clear the textbox and append the message as a new, unanswered turn."""
        # The Clear button sets the chatbot value to None; guard in case that
        # None is passed back here on the next submit.
        return "", (history or []) + [[user_message, None]]

    def bot(history):
        """Stream the model's answer to the last (unanswered) turn."""
        # The last entry is the pending [prompt, None] turn added by `user`;
        # everything before it is context. The list is left unmodified.
        prompt = history[-1][0]
        yield from stream_chat(prompt, history[:-1])

    # First record the user's turn (unqueued, so it shows immediately), then
    # stream the reply into the chatbot.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# NOTE(review): server_name="0.0.0.0" listens on all network interfaces,
# exposing the chat UI beyond localhost; confirm this is intended.
demo.queue().launch(server_name="0.0.0.0")