import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
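# Note: the Llama 3.1 weights are gated on the Hub, so loading requires an
# authenticated HF token. In float16 the 8B model needs roughly 16 GB of
# accelerator memory; device_map="auto" (via accelerate) places the weights.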
try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    device = model.device  # get the device automatically
    print(f"Model loaded on {device}")
except Exception as e:
    print(f"Error loading model: {e}")
    exit(1)
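# Background janitor: periodically release unreferenced tensors and cached
# CUDA blocks so long chat sessions are less likely to exhaust GPU memory.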
def clean_memory():
    while True:
        gc.collect()
        if device.type == "cuda":  # check the device type explicitly
            torch.cuda.empty_cache()
        time.sleep(1)

cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
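# Streams the assistant's reply incrementally. The prompt below is plain
# "Speaker: text" concatenation; tokenizer.apply_chat_template would match
# the Llama 3.1 chat format more faithfully, but this is kept simple here.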
def generate_response(message, history, max_tokens, temperature, top_p):
    try:
        system_message = "You are a helpful and friendly AI assistant."
        prompt = system_message + "\n" + "".join([f"{speaker}: {text}\n" for speaker, text in history] + [f"User: {message}\n"])
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        # generate() has no stream= argument; use TextIteratorStreamer and run
        # generation in a background thread, yielding decoded chunks as they arrive.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            max_length=min(input_ids.shape[-1] + max_tokens, 2048),  # cap length to prevent excessive generation
            do_sample=True,  # temperature/top_p only take effect when sampling
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        threading.Thread(target=model.generate, kwargs=generation_kwargs).start()
        generated_text = ""
        for new_text in streamer:  # generate() already runs under no_grad internally
            generated_text += new_text
            yield generated_text
    except Exception as e:
        yield f"Error generating response: {e}"
def update_chatbox(history, message, max_tokens, temperature, top_p):
    prior_turns = list(history)  # prompt context; generate_response appends the new message itself
    history.append(("User", message))
    history.append(("AI", ""))  # placeholder, filled in as tokens stream
    for response_chunk in generate_response(message, prior_turns, max_tokens, temperature, top_p):
        history[-1] = ("AI", response_chunk)
        yield history, ""  # show the partial reply in the chat and clear the textbox
    history[-1] = ("AI", history[-1][1].strip())
    yield history, ""
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
        queue=True,
    )
if __name__ == "__main__":
    demo.launch(share=False)