bobpopboom committed · Commit eec50a8 · verified · 1 Parent(s): c2d3107

r1 i want it better

Files changed (1): app.py (+25 -18)
app.py CHANGED
@@ -2,27 +2,32 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
+# Determine device
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 model_id = "thrishala/mental_health_chatbot"
 
 try:
+    # Load model with appropriate device_map and settings
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map=device,
+        device_map="auto",  # Use "auto" for device_map instead of device name
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
-        max_memory={device: "15GB"},
+        max_memory={0: "15GiB"} if torch.cuda.is_available() else None,
         offload_folder="offload",
-    )
+    ).eval()  # Set model to evaluation mode
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.model_max_length = 512
+    tokenizer.pad_token = tokenizer.eos_token  # Set padding token if missing
 
+    # Perform a dummy generation to initialize model (if needed)
     dummy_input = tokenizer("This is a test.", return_tensors="pt").to(model.device)
-    model.generate(input_ids=dummy_input.input_ids, return_dict=True)  # Dummy call
+    model.generate(
+        input_ids=dummy_input.input_ids,
+        max_new_tokens=1,
+        return_dict_in_generate=True  # Correct parameter name
+    )
 
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -37,32 +42,34 @@ def generate_text(prompt, max_new_tokens=128):
         max_new_tokens=max_new_tokens,
         do_sample=False,
         eos_token_id=tokenizer.eos_token_id,
-        return_dict=True,  # Explicitly set return_dict=True
+        return_dict_in_generate=True  # Correct parameter name
     )
 
-    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)  # Decode from sequences
+    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
     return generated_text
 
 def generate_text_streaming(prompt, max_new_tokens=128):
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
     with torch.no_grad():
-        for i in range(max_new_tokens):
+        for _ in range(max_new_tokens):
             output = model.generate(
                 input_ids=input_ids,
                 max_new_tokens=1,
                 do_sample=False,
                 eos_token_id=tokenizer.eos_token_id,
-                return_dict=True,
-                output_scores=True,
+                return_dict_in_generate=True  # Correct parameter name
             )
 
-            generated_token = tokenizer.decode(output.logits[0][-1].argmax(), skip_special_tokens=True)
+            # Get the last generated token
+            generated_token_id = output.sequences[0, -1]
+            generated_token = tokenizer.decode([generated_token_id], skip_special_tokens=True)
             yield generated_token
 
+            # Append new token to input_ids
             input_ids = torch.cat([input_ids, output.sequences[:, -1:]], dim=-1)
 
-            if output.sequences[0][-1] == tokenizer.eos_token_id:
+            if generated_token_id == tokenizer.eos_token_id:
                 break
 
 def respond(message, history, system_message, max_tokens):
@@ -73,7 +80,7 @@ def respond(message, history, system_message, max_tokens):
 
     try:
         for token in generate_text_streaming(prompt, max_tokens):
-            yield token  # Yield each token individually
+            yield token
     except Exception as e:
         print(f"Error during generation: {e}")
         yield "An error occurred."
@@ -90,4 +97,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
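
A note on the streaming path this commit keeps: generate_text_streaming calls model.generate once per token, so the full (and growing) context is re-processed on every step. If the goal is simply incremental output for the Gradio callback, transformers also provides TextIteratorStreamer, which streams text from a single generate call running in a background thread. The sketch below is an alternative, not part of this commit; it assumes the same model and tokenizer objects already loaded in app.py, and the function name stream_with_streamer is hypothetical.

    # Hedged sketch: one generate() call, streamed via TextIteratorStreamer.
    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_with_streamer(prompt, max_new_tokens=128):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        # generate() blocks until finished, so run it in a worker thread
        # and yield decoded chunks as the streamer produces them.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            yield new_text
        thread.join()

The respond function could consume this generator exactly as it consumes generate_text_streaming today.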