SantiagoTesla commited on
Commit
aa7bb26
·
1 Parent(s): 260e134

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -19,37 +19,38 @@ def chatbot(input):
19
 
20
 
21
  tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
22
-
23
- # mpt-7b is trained to add "<|endoftext|>" at the end of generations
24
- stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
25
-
26
- # define custom stopping criteria object
27
- class StopOnTokens(StoppingCriteria):
28
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
29
- for stop_id in stop_token_ids:
30
- if input_ids[0][-1] == stop_id:
31
- return True
32
- return False
33
-
34
- stopping_criteria = StoppingCriteriaList([StopOnTokens()])
35
-
36
- generate_text = transformers.pipeline(
37
- model=model, tokenizer=tokenizer,
38
- return_full_text=True, # langchain expects the full text
39
- task='text-generation',
40
- device=device,
41
- # we pass model parameters here too
42
- stopping_criteria=stopping_criteria, # without this model will ramble
43
- temperature=0.1, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
44
- top_p=0.15, # select from top tokens whose probability add up to 15%
45
- top_k=0, # select from top 0 tokens (because zero, relies on top_p)
46
- max_new_tokens=64, # max number of tokens to generate in the output
47
- repetition_penalty=1.1 # without this output begins repeating
48
- )
49
-
50
- res = generate_text(input)
51
- output = res[0]["generated_text"]
52
- return output
 
53
 
54
  inputs = gr.inputs.Textbox(lines=7, label="Chat with AI")
55
  outputs = gr.outputs.Textbox(label="Reply")
 
19
 
20
 
21
  tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
22
+
23
+ for i in range(50):
24
+ # mpt-7b is trained to add "<|endoftext|>" at the end of generations
25
+ stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
26
+
27
+ # define custom stopping criteria object
28
+ class StopOnTokens(StoppingCriteria):
29
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
30
+ for stop_id in stop_token_ids:
31
+ if input_ids[0][-1] == stop_id:
32
+ return True
33
+ return False
34
+
35
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
36
+
37
+ generate_text = transformers.pipeline(
38
+ model=model, tokenizer=tokenizer,
39
+ return_full_text=True, # langchain expects the full text
40
+ task='text-generation',
41
+ device=device,
42
+ # we pass model parameters here too
43
+ stopping_criteria=stopping_criteria, # without this model will ramble
44
+ temperature=0.1, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
45
+ top_p=0.15, # select from top tokens whose probability add up to 15%
46
+ top_k=0, # select from top 0 tokens (because zero, relies on top_p)
47
+ max_new_tokens=64, # max number of tokens to generate in the output
48
+ repetition_penalty=1.1 # without this output begins repeating
49
+ )
50
+
51
+ res = generate_text(input)
52
+ output = res[0]["generated_text"]
53
+ return output
54
 
55
  inputs = gr.inputs.Textbox(lines=7, label="Chat with AI")
56
  outputs = gr.outputs.Textbox(label="Reply")