bobpopboom committed on
Commit 803024c · verified · 1 Parent(s): 81ab351

idk anymore is more a vibe

Files changed (1): app.py +18 -18
app.py CHANGED
@@ -12,14 +12,14 @@ model_id = "thrishala/mental_health_chatbot"
 try:
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map=device,  # Use the determined device
+        device_map=device,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
-        max_memory={device: "15GB"},  # Use device-specific memory management
+        max_memory={device: "15GB"},
         offload_folder="offload",
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.model_max_length = 512  # Set maximum length
+    tokenizer.model_max_length = 512

 except Exception as e:
     print(f"Error loading model: {e}")
@@ -32,11 +32,12 @@ def generate_text(prompt, max_new_tokens=128):
     output = model.generate(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
-        do_sample=False,  # Or True for sampling
+        do_sample=False,
         eos_token_id=tokenizer.eos_token_id,
+        return_dict=True,  # Explicitly set return_dict=True
     )

-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)  # Decode from sequences
     return generated_text

 def generate_text_streaming(prompt, max_new_tokens=128):
@@ -46,20 +47,20 @@ def generate_text_streaming(prompt, max_new_tokens=128):
     for i in range(max_new_tokens):
         output = model.generate(
             input_ids=input_ids,
-            max_new_tokens=1,  # Generate only 1 new token at a time
-            do_sample=False,  # Or True for sampling
+            max_new_tokens=1,
+            do_sample=False,
             eos_token_id=tokenizer.eos_token_id,
-            return_dict=True,  #Return a dictionary
-            output_scores=True  #Return the scores
+            return_dict=True,
+            output_scores=True,
         )

-        generated_token = tokenizer.decode(output.logits[0][-1].argmax(), skip_special_tokens=True)  #Decode the last token only
-        yield generated_token  #Yield the last token
+        generated_token = tokenizer.decode(output.logits[0][-1].argmax(), skip_special_tokens=True)
+        yield generated_token

-        input_ids = torch.cat([input_ids, output.sequences[:, -1:]], dim=-1)  #Append the new token to the input
+        input_ids = torch.cat([input_ids, output.sequences[:, -1:]], dim=-1)

-        if output.sequences[0][-1] == tokenizer.eos_token_id:  #Check if the end of sequence token was generated
-            break  #Break the loop
+        if output.sequences[0][-1] == tokenizer.eos_token_id:
+            break

 def respond(message, history, system_message, max_tokens):
     prompt = f"{system_message}\n"
@@ -68,9 +69,8 @@ def respond(message, history, system_message, max_tokens):
     prompt += f"User: {message}\nAssistant:"

     try:
-        for token in generate_text_streaming(prompt, max_tokens):  #Iterate over the generator
-            yield token  #Yield each token individually
-
+        for token in generate_text_streaming(prompt, max_tokens):
+            yield token  # Yield each token individually
     except Exception as e:
         print(f"Error during generation: {e}")
         yield "An error occurred."
@@ -82,7 +82,7 @@ demo = gr.ChatInterface(
             value="You are a friendly and helpful mental health chatbot.",
             label="System message",
         ),
-        gr.Slider(minimum=1, maximum=128, value=32, step=10, label="Max new tokens"),
+        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
     ],
 )
 
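Note on the new `return_dict=True` argument: in the Transformers `generate()` API, the flag that makes `generate()` return a structured output (with a `.sequences` field, plus `.scores` when `output_scores=True`) is `return_dict_in_generate=True`; plain `return_dict=True` is a forward-pass argument, so `output.sequences` and `output.logits` in the updated code will most likely raise AttributeError at runtime. Below is a minimal sketch of the non-streaming path under that assumption, reusing the `model` and `tokenizer` already defined in app.py; the prompt-tokenization line is assumed from the surrounding file and is not part of this diff.

def generate_text(prompt, max_new_tokens=128):
    # Assumed from app.py context: tokenize the prompt and move it to the model's device.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,  # structured output with .sequences
        output_scores=True,            # per-step scores available as output.scores
    )
    # output.sequences includes the prompt tokens, so decode only the new ones.
    new_tokens = output.sequences[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)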
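A further hedged observation on the streaming loop: with `output_scores=True` (and `return_dict_in_generate=True`), the per-step logits are exposed as `output.scores`, a tuple with one entry per generated token, not as `output.logits`, so a greedy next token would come from `output.scores[-1].argmax(-1)`. If the manual one-token-at-a-time loop keeps misbehaving, a common alternative (not what this commit does) is Transformers' `TextIteratorStreamer`, sketched here with the same `model` and `tokenizer` assumed from app.py.

from threading import Thread
from transformers import TextIteratorStreamer

def generate_text_streaming(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt avoids re-emitting the prompt; skip_special_tokens drops EOS etc.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks until it finishes, so run it in a background thread
    # and read decoded text chunks from the streamer as they arrive.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for text_chunk in streamer:
        yield text_chunk

If respond() streams this way, it would also need to yield a growing accumulated string rather than individual chunks, since gr.ChatInterface replaces the displayed message with each yielded value; that is a separate choice from what this commit changes.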