bobpopboom committed
Commit 0e0341f · verified · 1 Parent(s): eec50a8

might be done now?

Files changed (1)
  1. app.py +42 -46
app.py CHANGED
@@ -8,79 +8,75 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model_id = "thrishala/mental_health_chatbot"

 try:
-    # Load model with appropriate device_map and settings
+    # Load model with appropriate settings
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto",  # Use "auto" for device_map instead of device name
+        device_map="auto",
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
         max_memory={0: "15GiB"} if torch.cuda.is_available() else None,
         offload_folder="offload",
-    ).eval()  # Set model to evaluation mode
+    ).eval()

     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.pad_token = tokenizer.eos_token  # Set padding token if missing
-
-    # Perform a dummy generation to initialize model (if needed)
-    dummy_input = tokenizer("This is a test.", return_tensors="pt").to(model.device)
-    model.generate(
-        input_ids=dummy_input.input_ids,
-        max_new_tokens=1,
-        return_dict_in_generate=True  # Correct parameter name
-    )
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.model_max_length = 4096  # Set to model's actual context length

 except Exception as e:
     print(f"Error loading model: {e}")
     exit()

-def generate_text(prompt, max_new_tokens=128):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
-
-    with torch.no_grad():
-        output = model.generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            eos_token_id=tokenizer.eos_token_id,
-            return_dict_in_generate=True  # Correct parameter name
-        )
-
-    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
-    return generated_text
-
 def generate_text_streaming(prompt, max_new_tokens=128):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
-
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=4096  # Match model's context length
+    ).to(model.device)
+
+    generated_tokens = []
     with torch.no_grad():
         for _ in range(max_new_tokens):
-            output = model.generate(
-                input_ids=input_ids,
+            outputs = model.generate(
+                **inputs,
                 max_new_tokens=1,
                 do_sample=False,
                 eos_token_id=tokenizer.eos_token_id,
-                return_dict_in_generate=True  # Correct parameter name
+                return_dict_in_generate=True
             )
-
-            # Get the last generated token
-            generated_token_id = output.sequences[0, -1]
-            generated_token = tokenizer.decode([generated_token_id], skip_special_tokens=True)
-            yield generated_token
-
-            # Append new token to input_ids
-            input_ids = torch.cat([input_ids, output.sequences[:, -1:]], dim=-1)
-
-            if generated_token_id == tokenizer.eos_token_id:
+
+            new_token = outputs.sequences[0, -1]
+            generated_tokens.append(new_token)
+
+            # Update inputs for next iteration
+            inputs = {
+                "input_ids": torch.cat([inputs["input_ids"], new_token.unsqueeze(0).unsqueeze(0)], dim=-1),
+                "attention_mask": torch.cat([inputs["attention_mask"], torch.ones(1, 1, device=model.device)], dim=-1)
+            }
+
+            # Decode the accumulated tokens
+            current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            yield current_text  # Yield the full text so far
+
+            if new_token == tokenizer.eos_token_id:
                 break

 def respond(message, history, system_message, max_tokens):
+    # Build prompt with full history
     prompt = f"{system_message}\n"
     for user_msg, bot_msg in history:
         prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
     prompt += f"User: {message}\nAssistant:"
-
+
+    # Keep track of the full response
+    full_response = ""
+
     try:
-        for token in generate_text_streaming(prompt, max_tokens):
-            yield token
+        for token_chunk in generate_text_streaming(prompt, max_tokens):
+            # Update the full response and yield incremental changes
+            full_response = token_chunk
+            yield full_response
+
     except Exception as e:
         print(f"Error during generation: {e}")
         yield "An error occurred."
@@ -92,7 +88,7 @@ demo = gr.ChatInterface(
         value="You are a friendly and helpful mental health chatbot.",
         label="System message",
     ),
-        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
+        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
     ],
 )
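
For quick sanity checks, the streaming generator introduced in this commit can also be consumed outside Gradio. A minimal usage sketch follows (illustrative only, not part of the commit; the example prompt, the `demo_prompt` name, and the `__main__` guard are assumptions):

# Minimal usage sketch (assumption, not in the commit): each value yielded by
# generate_text_streaming is the full response decoded so far, so a consumer
# only needs to display the latest chunk.
if __name__ == "__main__":
    demo_prompt = (
        "You are a friendly and helpful mental health chatbot.\n"
        "User: I have trouble sleeping lately.\nAssistant:"
    )
    for partial in generate_text_streaming(demo_prompt, max_new_tokens=32):
        print(partial)  # prints the accumulated reply after every new token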