nikshep01 committed (verified)
Commit b4068de · 1 Parent(s): 910badc

Update app.py

Files changed (1):
  app.py (+16, -36)
app.py CHANGED
@@ -2,71 +2,51 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
-from queue import Empty
-
-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("thrishala/mental_health_chatbot")
-model = AutoModelForCausalLM.from_pretrained("thrishala/mental_health_chatbot", torch_dtype=torch.float16)
 
+tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
 # Move model to GPU if available
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]  # Token IDs for stopping criteria
+        stop_ids = [29, 0]
         for stop_id in stop_ids:
             if input_ids[0][-1] == stop_id:
                 return True
         return False
 
 def predict(message, history):
-    # Prepare the input history in the expected format for the model
     history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
     stop = StopOnTokens()
 
-    # Concatenate conversation history
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) for item in history_transformer_format])
-
-    # Tokenize and prepare model inputs
-    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
-
-    # Create streamer with longer timeout
-    streamer = TextIteratorStreamer(tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
-
-    # Define generation parameters
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
+                        for item in history_transformer_format])
+
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
-        max_new_tokens=512,  # Reduced to avoid memory issues
+        max_new_tokens=1024,
         do_sample=True,
-        top_p=0.85,  # Adjusted for faster generation
-        top_k=500,  # Adjusted for faster generation
+        top_p=0.95,
+        top_k=1000,
         temperature=1.0,
         num_beams=1,
         stopping_criteria=StoppingCriteriaList([stop])
     )
-
-    # Run the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # Yield generated tokens
     partial_message = ""
-    try:
-        for new_token in streamer:
-            if new_token != '<':
-                partial_message += new_token
-                yield partial_message
-    except Empty:
-        yield "Error: No tokens generated or generation timeout."
-
-# Gradio interface to run the chatbot
-gr.ChatInterface(predict).launch()
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
+
+gr.ChatInterface(predict).launch()
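For context on the edited file: app.py uses the usual transformers streaming pattern for chat demos, in which model.generate runs in a background Thread and feeds a TextIteratorStreamer, while the Gradio predict generator iterates the streamer and yields the growing reply. One side effect of this commit is that the try/except around the streaming loop is gone: with a finite timeout (10 s here), the streamer's iterator raises queue.Empty if no token arrives in time, and that exception will now propagate out of predict instead of becoming an error message. Below is a minimal, self-contained sketch of the same pattern with the guard kept; the tiny model id, prompt, and error string are illustrative stand-ins, not part of the app.

from queue import Empty
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative stand-in model; the app itself loads togethercomputer/RedPajama-INCITE-Chat-3B-v1.
model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt: str):
    model_inputs = tokenizer([prompt], return_tensors="pt")
    # The streamer decodes tokens into an internal queue; iterating it blocks for at
    # most `timeout` seconds per chunk and then re-raises queue.Empty.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=64, do_sample=True)
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    try:
        for new_token in streamer:
            partial_message += new_token
            yield partial_message
    except Empty:
        # No token arrived within the timeout window; surface whatever was produced.
        yield partial_message or "Generation timed out before any tokens were produced."

for chunk in stream_reply("\n<human>:Hello\n<bot>:"):
    print(chunk)

In the Gradio app the same loop sits inside predict, so whether the guard is worth keeping mostly depends on how long the first token can take on the target hardware.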