bobpopboom committed
Commit 3c7c10f · verified · 1 Parent(s): 2876cd4

Update app.py

Files changed (1): app.py (+18 -35)
app.py CHANGED

@@ -5,69 +5,55 @@ import torch
 model_id = "thrishala/mental_health_chatbot"
 
 try:
-    # Load model with int8 quantization for CPU
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="cpu",
-        torch_dtype=torch.float16, # Use float16 for reduced memory
-        low_cpu_mem_usage=True, # Enable memory optimization
-        max_memory={"cpu": "15GB"}, # Limit memory usage
-        offload_folder="offload", # Enable disk offloading
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        max_memory={"cpu": "15GB"},
+        offload_folder="offload",
     )
 
-    # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.model_max_length = 256 # Set maximum length
 
-    # Create pipeline with optimizations
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
         torch_dtype=torch.float16,
-        num_return_sequences=1, # Only generate one response
-        do_sample=True, # Enable sampling since we're using temperature and top_p
-        truncation=True, # Explicitly enable truncation
-        max_new_tokens=128 # Use only max_new_tokens
+        num_return_sequences=1,
+        do_sample=True,
+        truncation=True,
+        max_new_tokens=128
     )
 
 except Exception as e:
     print(f"Error loading model: {e}")
     exit()
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message, # You can use this for initial instructions
-    max_tokens,
-    temperature,
-    top_p,
-):
-    # 2. Construct the Prompt (Crucial!)
-    prompt = f"{system_message}\n"
+def respond(message, history, system_message, max_tokens, temperature, top_p):
+    prompt = f"{system_message}\n"
     for user_msg, bot_msg in history:
         prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
     prompt += f"User: {message}\nAssistant:"
-
-    # 3. Generate with the Pipeline
+
     try:
         response = pipe(
             prompt,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
         )[0]["generated_text"]
-            prompt,
-            do_sample=True,
-        #Extract the bot's reply (adjust if your model format is different)
+
         bot_response = response.split("Assistant:")[-1].strip()
         yield bot_response
-
     except Exception as e:
         print(f"Error during generation: {e}")
-        yield "An error occurred during generation." #Handle generation errors.
-
+        yield "An error occurred during generation."
 
-# 4. Gradio Interface (No changes needed here)
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
@@ -78,13 +64,10 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
+            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)",
         ),
     ],
+    chatbot=gr.Chatbot(type="messages"), # Updated to new format
 )
 
 if __name__ == "__main__":
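
One follow-up note on the added chatbot=gr.Chatbot(type="messages") argument: in recent Gradio releases the "messages" format passes history to the callback as a list of {"role", "content"} dicts rather than (user, bot) tuples, so the tuple-unpacking loop kept in respond would need a matching adjustment. A minimal sketch of such a prompt builder, assuming the dict-based history (the build_prompt helper name is illustrative and not part of app.py):

    def build_prompt(system_message: str, history: list, message: str) -> str:
        # Assumes Gradio "messages"-style history entries,
        # e.g. {"role": "user", "content": "Hi"} (assumption, not from the diff).
        prompt = f"{system_message}\n"
        for turn in history:
            speaker = "User" if turn["role"] == "user" else "Assistant"
            prompt += f"{speaker}: {turn['content']}\n"
        prompt += f"User: {message}\nAssistant:"
        return prompt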