Threatthriver committed on
Commit 411e510 · verified · 1 Parent(s): 12b152d

Update app.py

Files changed (1)
  1. app.py +11 -17
app.py CHANGED
@@ -6,56 +6,51 @@ import time
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from tqdm import tqdm
 
-# Load the tokenizer and model (lightweight model as per your suggestion)
 try:
     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
-    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.float16) # Use float16 for lower VRAM usage
+    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.float16, device_map="auto")
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(device)
+    device = model.device #Get device automatically
     print(f"Model loaded on {device}")
+
 except Exception as e:
     print(f"Error loading model: {e}")
     exit(1)
 
 
-# Function to clean up memory
 def clean_memory():
     while True:
         gc.collect()
-        if device == "cuda":
+        if device.type == 'cuda': #Check device type explicitly
             torch.cuda.empty_cache()
         time.sleep(1)
 
-# Start memory cleanup in a background thread
 cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
 cleanup_thread.start()
 
+
 def generate_response(message, history, max_tokens, temperature, top_p):
     try:
-        # Add system message for better control
         system_message = "You are a helpful and friendly AI assistant."
         prompt = system_message + "\n" + "".join([f"{speaker}: {text}\n" for speaker, text in history] + [f"User: {message}\n"])
 
-
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-
-        #Streaming response
+
         generated_text = ""
         with torch.no_grad():
-            for token_id in tqdm(model.generate(input_ids, max_length=input_ids.shape[-1] + max_tokens, temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id, stream=True)):
+            for token_id in tqdm(model.generate(input_ids, max_length=min(input_ids.shape[-1] + max_tokens, 2048), temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id, stream=True)): # Added max length to prevent excessive generation
                 generated_text = tokenizer.decode(token_id, skip_special_tokens=True)
                 yield generated_text
 
     except Exception as e:
         yield f"Error generating response: {e}"
 
+
 def update_chatbox(history, message, max_tokens, temperature, top_p):
     history.append(("User", message))
     for response_chunk in generate_response(message, history, max_tokens, temperature, top_p):
-        yield history, response_chunk #yield allows streaming updates
 
-    #Append final response after generation complete
+        yield history, response_chunk
 
     response = response_chunk.strip()
     history.append(("AI", response))
     yield history, ""
@@ -63,7 +58,7 @@ def update_chatbox(history, message, max_tokens, temperature, top_p):
 
 with gr.Blocks(css=".gradio-container {border: none;}") as demo:
     chat_history = gr.State([])
-    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens") #Reduced max tokens
+    max_tokens = gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
 
@@ -76,8 +71,7 @@ with gr.Blocks(css=".gradio-container {border: none;}") as demo:
         inputs=[chat_history, user_input, max_tokens, temperature, top_p],
         outputs=[chatbot, user_input],
         queue=True,
-        live=True #For streaming updates
     )
 
 if __name__ == "__main__":
-    demo.launch(share=False) #share=False, because share=True is not supported on Hugging Face spaces
+    demo.launch(share=False)
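
Note on the streaming loop this commit keeps: `model.generate()` in transformers does not accept a `stream=True` argument, and it returns the finished sequence rather than yielding tokens, so iterating over its result will not stream text token by token. The usual pattern is to pass a `TextIteratorStreamer` to `generate()` and consume it while generation runs in a background thread. Below is a minimal sketch under that assumption, reusing the `model` and `tokenizer` loaded above; the `stream_reply` helper name is illustrative and not part of the commit.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt, max_tokens=128, temperature=0.7, top_p=0.9):
    # Tokenize the prompt and move it to the model's device.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # The streamer yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,            # bound the new tokens instead of total length
        do_sample=True,                       # temperature/top_p only take effect when sampling
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks until it finishes, so run it in a background thread
    # and consume the streamer incrementally in the foreground.
    Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text

Since update_chatbox() already consumes generate_response() as a generator, a helper like this could be yielded from without changing the Gradio wiring.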