nikshep01 committed
Commit 571a4b7 · verified · 1 Parent(s): b4068de

Update app.py

Files changed (1):
  1. app.py +19 -7
app.py CHANGED
@@ -3,15 +3,17 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
 
+# Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
 model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
-# Move model to GPU if available
+
+# Move model to GPU if available, otherwise use CPU
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
+        stop_ids = [29, 0]  # Define stop token IDs
         for stop_id in stop_ids:
             if input_ids[0][-1] == stop_id:
                 return True
@@ -21,11 +23,16 @@ def predict(message, history):
     history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
     stop = StopOnTokens()
 
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
-                for item in history_transformer_format])
+    # Format the messages for the model
+    messages = "".join([f"\n<human>:{item[0]}\n<bot>:{item[1]}" for item in history_transformer_format])
 
-    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    # Tokenize the input and move it to the correct device (GPU/CPU)
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+
+    # Create a streamer for output token generation
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    # Define generation parameters
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
@@ -36,20 +43,25 @@
         temperature=1.0,
         num_beams=1,
         stopping_criteria=StoppingCriteriaList([stop])
-        )
+    )
+
+    # Run the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
+    # Yield generated tokens as they are produced
     partial_message = ""
     for new_token in streamer:
-        if new_token != '<':
+        if new_token != '<':  # Ignore special tokens
             partial_message += new_token
             yield partial_message
 
+# Gradio interface to interact with the model
 gr.ChatInterface(predict).launch()
 
 
 
+
 # import gradio as gr
 # from huggingface_hub import InferenceClient
 
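
For reference, the core pattern in the updated predict() is to run model.generate on a background thread while the main thread iterates a TextIteratorStreamer, which blocks until newly decoded tokens arrive. Below is a minimal standalone sketch of that pattern, not part of the commit; the model name ("gpt2") and the max_new_tokens value are illustrative stand-ins chosen to keep the sketch lightweight.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative small model; the commit itself uses
# togethercomputer/RedPajama-INCITE-Chat-3B-v1.
tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok(["\n<human>: Hello!\n<bot>:"], return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() must run on its own thread, because iterating the streamer
# blocks the consumer until the producer pushes more decoded text.
# dict(inputs, ...) works because the tokenizer output is a mapping,
# mirroring the dict(model_inputs, ...) call in the commit.
Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32)).start()

for text_chunk in streamer:
    print(text_chunk, end="", flush=True)

gr.ChatInterface treats a generator like predict as a streaming callback: each yield of the growing partial_message re-renders the current bot reply, which is why the loop yields the cumulative string rather than each token alone.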