Text Generation
Transformers
PyTorch
Safetensors
gpt2
conversational
text-generation-inference
Inference Endpoints
timpal0l committed · Commit 630d104 · 1 Parent(s): 40e3e92

Update README.md

Files changed (1)
  1. README.md +17 -4
README.md CHANGED
@@ -75,9 +75,11 @@ Kul att du tycker det!
 ...
 ```
 
-The procedure to generate text is the same as before:
+The procedure to generate text in chat format:
 
 ```python
+from transformers import StoppingCriteriaList, StoppingCriteria
+
 prompt = """
 <|endoftext|><s>
 User:
@@ -86,17 +88,28 @@ Varför är träd fina?
 Bot:
 """.strip()
 
+# (Optional) - define a stopping criterion
+# We ideally want the model to stop generating once the response from the Bot is complete
+class StopOnTokenCriteria(StoppingCriteria):
+    def __init__(self, stop_token_id):
+        self.stop_token_id = stop_token_id
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1] == self.stop_token_id
+
+stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
 input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
 
 generated_token_ids = model.generate(
     inputs=input_ids,
-    max_new_tokens=100,
+    max_new_tokens=128,
     do_sample=True,
     temperature=0.6,
     top_p=1,
+    stopping_criteria=StoppingCriteriaList([stop_on_token_criteria])
 )[0]
 
-generated_text = tokenizer.decode(generated_token_ids)
+generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]):-1])
 ```
 
 Generating text using the `generate` method is done as follows:
@@ -171,7 +184,7 @@ Following Mitchell et al. (2018), we provide a model card for GPT-SW3.
 
 - Conversational
 - Familjeliv (https://www.familjeliv.se/)
-- Flashback (https://flashback.se/)
+- Flashback (https://flashback.org/)
 - Datasets collected through Parlai (see Appendix in data paper for complete list) (https://github.com/facebookresearch/ParlAI)
 - Pushshift.io Reddit dataset, developed in Baumgartner et al. (2020) and processed in Roller et al. (2021)
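
The logic the diff adds can be sketched without loading a model: the stopping criterion fires as soon as the last generated token equals the stop id (the model card uses `tokenizer.bos_token_id`, since the chat format emits `<s>` after each turn), and the slice `[len(input_ids[0]):-1]` keeps only the bot's reply by dropping the prompt tokens at the front and the trailing stop token. A minimal stdlib-only sketch of that flow — the toy token ids and the `generate` loop below are illustrative assumptions, not part of the model card:

```python
# Toy sketch of the stop-on-token logic from the diff; no transformers needed.
# Token id 2 stands in for tokenizer.bos_token_id (<s>); all ids are made up.
STOP_TOKEN_ID = 2

def stop_on_token(generated_ids, stop_token_id):
    # Mirrors StopOnTokenCriteria.__call__: fire when the last token is the stop id.
    return generated_ids[-1] == stop_token_id

def generate(prompt_ids, next_token_fn, max_new_tokens=128):
    # Append tokens until the stop criterion fires or the budget runs out,
    # like model.generate with stopping_criteria + max_new_tokens.
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        ids.append(next_token_fn(ids))
        if stop_on_token(ids, STOP_TOKEN_ID):
            break
    return ids

# Fake "model": emits three reply tokens, then the stop token.
reply = iter([101, 102, 103, STOP_TOKEN_ID])
prompt_ids = [10, 11, 12]
out = generate(prompt_ids, lambda ids: next(reply))

# Same slice as in the README: drop the prompt and the trailing stop token,
# leaving only the Bot's reply.
bot_reply_ids = out[len(prompt_ids):-1]
print(bot_reply_ids)  # [101, 102, 103]
```

Note that without the `:-1`, the decoded text would end with the `<s>` token that triggered the stop; without the `len(input_ids[0]):` offset, it would repeat the whole prompt.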