import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.float16, device_map="auto"
    )
    device = model.device  # device_map="auto" picks the placement; read it back from the model
    print(f"Model loaded on {device}")
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)
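# Note: if the Space runs out of GPU memory loading the fp16 weights, 4-bit
# quantization is one option. This is a hedged sketch, not part of the original
# app; it assumes the `bitsandbytes` package is installed alongside transformers:
#
#     from transformers import BitsAndBytesConfig
#     quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Meta-Llama-3.1-8B-Instruct", quantization_config=quant_config, device_map="auto"
#     )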
def clean_memory():
    # Background loop: periodically run the garbage collector and, on CUDA,
    # release cached allocator blocks back to the driver.
    while True:
        gc.collect()
        if device.type == "cuda":  # only call empty_cache when actually on a GPU
            torch.cuda.empty_cache()
        time.sleep(1)

cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
def generate_response(message, history, max_tokens, temperature, top_p):
    try:
        system_message = "You are a helpful and friendly AI assistant."
        # history is a list of (user_text, bot_text) pairs, matching gr.Chatbot
        prompt = system_message + "\n" + "".join(f"User: {u}\nAI: {b}\n" for u, b in history) + f"User: {message}\nAI:"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # model.generate has no stream=True flag; stream via TextIteratorStreamer instead,
        # with generation running in a background thread
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(input_ids=input_ids, max_new_tokens=max_tokens,  # cap new tokens, not total length
                                 do_sample=True,  # temperature/top_p are ignored unless sampling is enabled
                                 temperature=temperature, top_p=top_p,
                                 pad_token_id=tokenizer.eos_token_id, streamer=streamer)
        threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()
        generated_text = ""
        for chunk in streamer:
            generated_text += chunk  # accumulate: each chunk is only the newly decoded text
            yield generated_text
    except Exception as e:
        yield f"Error generating response: {e}"
def update_chatbox(history, message, max_tokens, temperature, top_p):
    # Append a new (user, bot) pair and clear the textbox right away; the bot half
    # fills in as chunks stream back. Yield history twice: once for the Chatbot
    # display and once to keep the gr.State in sync across turns.
    history = history + [(message, "")]
    yield history, history, ""
    for response_chunk in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1] = (message, response_chunk)
        yield history, history, ""
    history[-1] = (message, history[-1][1].strip())
    yield history, history, ""
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, chat_history, user_input],  # the state must be an output too, or it never updates
        queue=True,
    )
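    # Optionally, mirror the click handler so pressing Enter also sends the message.
    # Textbox.submit takes the same fn/inputs/outputs; left as a commented sketch:
    #
    #     user_input.submit(fn=update_chatbox,
    #                       inputs=[chat_history, user_input, max_tokens, temperature, top_p],
    #                       outputs=[chatbot, chat_history, user_input], queue=True)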
if __name__ == "__main__":
    demo.launch(share=False)