Threatthriver committed on
Commit 12b152d · verified · 1 Parent(s): 88c5a1e

Update app.py

Files changed (1)
  1. app.py +47 -42
app.py CHANGED
@@ -4,75 +4,80 @@ import gc
 import threading
 import time
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from tqdm import tqdm

 # Load the tokenizer and model (lightweight model as per your suggestion)
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+try:
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.float16) # Use float16 for lower VRAM usage
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+    print(f"Model loaded on {device}")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    exit(1)

-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = model.to(device)

 # Function to clean up memory
 def clean_memory():
     while True:
-        gc.collect() # Free up CPU memory
+        gc.collect()
         if device == "cuda":
-            torch.cuda.empty_cache() # Free up GPU memory
-        time.sleep(1) # Clean every second
+            torch.cuda.empty_cache()
+        time.sleep(1)

 # Start memory cleanup in a background thread
 cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
 cleanup_thread.start()

 def generate_response(message, history, max_tokens, temperature, top_p):
-    """
-    Generates a response from the model.
-    """
-    # Prepare conversation history as input
-    input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt").to(device)
-
-    # Generate the output using the model with no gradient calculations
-    with torch.no_grad():
-        output = model.generate(
-            input_ids,
-            max_length=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-
-    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
-    history.append((message, response))
-    return history, ""
+    try:
+        # Add system message for better control
+        system_message = "You are a helpful and friendly AI assistant."
+        prompt = system_message + "\n" + "".join([f"{speaker}: {text}\n" for speaker, text in history] + [f"User: {message}\n"])
+
+
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+        #Streaming response
+        generated_text = ""
+        with torch.no_grad():
+            for token_id in tqdm(model.generate(input_ids, max_length=input_ids.shape[-1] + max_tokens, temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id, stream=True)):
+                generated_text = tokenizer.decode(token_id, skip_special_tokens=True)
+                yield generated_text
+
+    except Exception as e:
+        yield f"Error generating response: {e}"

 def update_chatbox(history, message, max_tokens, temperature, top_p):
-    """
-    Update the chat history and generate the next AI response.
-    """
-    history.append(("User", message)) # Add user message to history
-    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
-    return history, "" # Return updated history and clear the user input
-
-# Define the Gradio interface with the Blocks context
+    history.append(("User", message))
+    for response_chunk in generate_response(message, history, max_tokens, temperature, top_p):
+        yield history, response_chunk #yield allows streaming updates
+
+    #Append final response after generation complete
+    response = response_chunk.strip()
+    history.append(("AI", response))
+    yield history, ""
+
+
 with gr.Blocks(css=".gradio-container {border: none;}") as demo:
-    chat_history = gr.State([]) # Initialize an empty chat history state
-    max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
+    chat_history = gr.State([])
+    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens") #Reduced max tokens
     temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
     top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")

     chatbot = gr.Chatbot(label="Character-like AI Chat")
-
     user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
     send_button = gr.Button("Send")

-    # When the send button is clicked, update chat history
     send_button.click(
         fn=update_chatbox,
         inputs=[chat_history, user_input, max_tokens, temperature, top_p],
-        outputs=[chatbot, user_input], # Update chatbox and clear user input
-        queue=True # Ensure responses are shown in order
+        outputs=[chatbot, user_input],
+        queue=True,
+        live=True #For streaming updates
     )

-# Launch the Gradio interface
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=False) #share=False, because share=True is not supported on Hugging Face spaces
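
Note on the streaming path: transformers' generate() does not accept a stream=True keyword, and Gradio's .click() event does not accept a live= keyword, so the added streaming code may not run as written. Token streaming is usually wired with transformers.TextIteratorStreamer, with generate() on a worker thread and a generator callback yielding partial output to the Chatbot while queuing is enabled. The sketch below is illustrative only, not code from this commit; it assumes the same model, tokenizer, and device globals defined in app.py, and the helper name stream_reply is made up for the example.

# Illustrative sketch, not part of the commit: streaming via TextIteratorStreamer.
# Assumes the `model`, `tokenizer`, and `device` globals defined in app.py above.
import threading

from transformers import TextIteratorStreamer

def stream_reply(prompt, max_tokens, temperature, top_p):
    """Yield the growing reply text while the model generates it."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it on a worker thread and read text chunks from the streamer.
    worker = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial

def update_chatbox(history, message, max_tokens, temperature, top_p):
    # gr.Chatbot expects (user, bot) pairs; stream the bot half as it grows.
    history = history + [(message, "")]
    for partial in stream_reply(message, max_tokens, temperature, top_p):
        history[-1] = (message, partial)
        yield history, ""  # update the Chatbot, keep the textbox cleared

# Wiring: with queuing enabled, Gradio streams each value yielded by a generator callback,
# so no live= keyword is needed on the event listener.
# send_button.click(fn=update_chatbox,
#                   inputs=[chat_history, user_input, max_tokens, temperature, top_p],
#                   outputs=[chatbot, user_input])
# demo.queue()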