bobpopboom committed on
Commit b84cd4b · verified · 1 Parent(s): fb27a1f

deep seek help plz xD

Files changed (1)
  1. app.py +32 -40
app.py CHANGED
@@ -5,62 +5,47 @@ import torch
 model_id = "thrishala/mental_health_chatbot"
 
 try:
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="cpu",
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-        max_memory={"cpu": "15GB"},
-        offload_folder="offload",
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.model_max_length = 256 # Set maximum length
-
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        torch_dtype=torch.float16,
-        num_return_sequences=1,
-        do_sample=False,
-        truncation=True,
-        max_new_tokens=128
+        load_in_8bit=True,
+        device_map="auto",
+        torch_dtype=torch.float16
     )
-
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 except Exception as e:
     print(f"Error loading model: {e}")
     exit()
 
-def respond(message, history, system_message, max_tokens):
+def respond(
+    message,
+    history,
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+):
+    # Construct the prompt with clear separation
     prompt = f"{system_message}\n"
-
-    # Yield the FULL history FIRST (important!)
-    full_history = [] # Initialize an empty list for the full history
-    for user_msg, bot_msg in reversed(history): # Reversed to append messages correctly
-        full_history.append([user_msg, bot_msg]) # Append the user and bot message to the full history.
+    for user_msg, bot_msg in history:
         prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
-
-    yield full_history # Yield the full history first!
-
-    # THEN yield the new message/response
     prompt += f"User: {message}\nAssistant:"
 
     try:
         response = pipe(
             prompt,
             max_new_tokens=max_tokens,
-            do_sample=False,
-            pad_token_id=tokenizer.eos_token_id
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=tokenizer.eos_token_id, # Use EOS token to stop generation
         )[0]["generated_text"]
-
-        bot_response = response.split("Assistant:")[-1].strip()
 
-        yield [message, bot_response] # Yield the new message/response
-
+        # Extract only the new assistant response after the last Assistant: in the prompt
+        bot_response = response[len(prompt):].split("User:")[0].strip() # Take text after prompt and before next User
+        yield bot_response
     except Exception as e:
         print(f"Error during generation: {e}")
-        yield [message, "An error occurred during generation."]
+        yield "An error occurred."
 
 demo = gr.ChatInterface(
     respond,
@@ -69,10 +54,17 @@ demo = gr.ChatInterface(
             value="You are a friendly and helpful mental health chatbot.",
             label="System message",
         ),
-        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
     ],
-    chatbot=gr.Chatbot(type="messages"), # Updated to new format
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
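
For context on the new loading path: load_in_8bit=True relies on the bitsandbytes package and typically needs a CUDA GPU, and newer transformers releases prefer passing that flag through a BitsAndBytesConfig via quantization_config. The sketch below is not part of the commit; it is a minimal illustration of how the same setup is commonly written today, plus an optional token-streaming helper built on transformers' TextIteratorStreamer. The stream_reply name and its defaults are illustrative, not from app.py.

    # Sketch only, assuming a CUDA GPU with bitsandbytes installed; not the commit's code.
    from threading import Thread

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TextIteratorStreamer,
    )

    model_id = "thrishala/mental_health_chatbot"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",  # spread the model across the available device(s)
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # 8-bit weights via bitsandbytes
        torch_dtype=torch.float16,  # dtype for the modules that stay unquantized
    )

    def stream_reply(prompt, max_new_tokens=128):
        """Yield the growing assistant reply piece by piece (illustrative helper)."""
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks until it finishes, so run it in a background thread
        # and read the decoded text pieces from the streamer as they arrive.
        Thread(
            target=model.generate,
            kwargs=dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer),
        ).start()
        partial = ""
        for piece in streamer:
            partial += piece
            yield partial

Because gr.ChatInterface treats each successive yield from respond as an update to the same reply, yielding the growing partial string from a helper like this inside respond would stream the answer as it is generated, instead of producing a single final yield as the committed version does.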