# chat-llm / app.py
import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    device = model.device  # device_map="auto" picks the placement; read it back here
    print(f"Model loaded on {device}")
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)  # safer than exit(), which depends on the site module
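# Optional alternative (a sketch, not in the original code): if the GPU cannot hold
# the fp16 weights, transformers can load the model in 4-bit via bitsandbytes:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Meta-Llama-3.1-8B-Instruct",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#   )
#
# This assumes the bitsandbytes package is installed; nothing below depends on it.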
def clean_memory():
    """Background loop: periodically free Python garbage and the CUDA allocator cache."""
    while True:
        gc.collect()
        if device.type == "cuda":  # only CUDA exposes a cache to empty
            torch.cuda.empty_cache()
        time.sleep(1)  # pause between sweeps

cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
def generate_response(message, history, max_tokens, temperature, top_p):
    try:
        system_message = "You are a helpful and friendly AI assistant."
        prompt = system_message + "\n"
        for user_msg, bot_msg in history:
            prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
        prompt += f"User: {message}\nAI:"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # generate() has no stream=True flag: stream via TextIteratorStreamer instead,
        # running generate() in a worker thread while this generator drains the streamer.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids, streamer=streamer, do_sample=True,  # sampling enables temperature/top_p
            max_new_tokens=max(1, min(int(max_tokens), 2048 - input_ids.shape[-1])),  # cap total length at 2048
            temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id,
        )
        threading.Thread(target=model.generate, kwargs=generation_kwargs).start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
    except Exception as e:
        yield f"Error generating response: {e}"
def update_chatbox(history, message, max_tokens, temperature, top_p):
    # gr.Chatbot renders (user_message, bot_message) pairs, not ("User", text) tuples.
    history.append((message, ""))
    for response_chunk in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1] = (message, response_chunk)  # grow the bot half of the current pair
        yield history, ""  # the empty string clears the input textbox
    history[-1] = (message, history[-1][1].strip())
    yield history, ""
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])  # list of (user_message, bot_message) pairs
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
        queue=True,
    )
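    # Optional (not in the original): let Enter in the textbox trigger the same
    # handler as the Send button; Textbox.submit takes the same arguments as click.
    user_input.submit(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
        queue=True,
    )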
if __name__ == "__main__":
    demo.launch(share=False)