import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

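# Load the tokenizer and model once at startup; device_map="auto" places the
# weights on a GPU when one is available and falls back to CPU otherwise.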
try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.float16, device_map="auto")

    device = model.device  # Get device automatically
    print(f"Model loaded on {device}")

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)


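# Background housekeeping: a daemon thread that runs Python's garbage collector
# once a second and empties the CUDA cache when running on a GPU.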
def clean_memory():
    while True:
        gc.collect()
        if device.type == 'cuda':  # Check device type explicitly
            torch.cuda.empty_cache()
        time.sleep(1)

cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()


def generate_response(message, history, max_tokens, temperature, top_p):
    try:
        system_message = "You are a helpful and friendly AI assistant."
        # Plain "User:/AI:" prompt; Llama 3.1 Instruct normally uses tokenizer.apply_chat_template.
        prompt = system_message + "\n" + "".join(
            f"User: {u}\nAI: {a}\n" if a else f"User: {u}\n" for u, a in history
        ) + f"User: {message}\nAI:"

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # generate() has no stream argument; stream tokens via a TextIteratorStreamer while generation runs in a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=int(max_tokens),  # cap new tokens to prevent excessive generation
            do_sample=True,                  # needed for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    except Exception as e:
        yield f"Error generating response: {e}"


def update_chatbox(history, message, max_tokens, temperature, top_p):
    response = ""
    for response_chunk in generate_response(message, history, max_tokens, temperature, top_p):
        response = response_chunk
        # Show the partially generated reply in the chat window while it streams.
        yield history + [(message, response)], ""

    history.append((message, response.strip()))
    yield history, ""


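# Gradio UI: sampling-control sliders, a chat window, a message box, and a send
# button. chat_history keeps the per-session conversation as (user, AI) pairs.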
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])
    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")

    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")

    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
        queue=True,
    )

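# Launch locally; set share=True to expose a temporary public Gradio link.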
if __name__ == "__main__":
    demo.launch(share=False)