Update README.md
README.md
CHANGED
````diff
@@ -75,9 +75,11 @@ Kul att du tycker det!
 ...
 ```
 
-The procedure to generate text
+The procedure to generate text in chat format:
 
 ```python
+from transformers import StoppingCriteriaList, StoppingCriteria
+
 prompt = """
 <|endoftext|><s>
 User:
@@ -86,17 +88,28 @@ Varför är träd fina?
 Bot:
 """.strip()
 
+# (Optional) - define a stopping criterion
+# We ideally want the model to stop generating once the response from the Bot is generated
+class StopOnTokenCriteria(StoppingCriteria):
+    def __init__(self, stop_token_id):
+        self.stop_token_id = stop_token_id
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return input_ids[0, -1] == self.stop_token_id
+
+stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
 input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
 
 generated_token_ids = model.generate(
     inputs=input_ids,
-    max_new_tokens=
+    max_new_tokens=128,
     do_sample=True,
     temperature=0.6,
     top_p=1,
+    stopping_criteria=StoppingCriteriaList([stop_on_token_criteria])
 )[0]
 
-generated_text = tokenizer.decode(generated_token_ids)
+generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]):-1])
 ```
 
 Generating text using the `generate` method is done as follows:
@@ -171,7 +184,7 @@ Following Mitchell et al. (2018), we provide a model card for GPT-SW3.
 
 - Conversational
 - Familjeliv (https://www.familjeliv.se/)
-- Flashback (https://flashback.
+- Flashback (https://flashback.org/)
 - Datasets collected through Parlai (see Appendix in data paper for complete list) (https://github.com/facebookresearch/ParlAI)
 - Pushshift.io Reddit dataset, developed in Baumgartner et al. (2020) and processed in Roller et al. (2021)
 
````
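The `StopOnTokenCriteria` added above fires as soon as the most recently generated token equals a chosen stop token id (here the tokenizer's BOS token, which marks the start of the next turn in this chat format), so generation halts right after the Bot's reply. A minimal sketch of that same logic in plain Python, with a toy loop standing in for `model.generate` (the token ids and `STOP_ID` below are made up for illustration):

```python
# Hypothetical stop token id; in the diff it is tokenizer.bos_token_id.
STOP_ID = 2

def should_stop(token_ids, stop_token_id=STOP_ID):
    # Mirror of StopOnTokenCriteria.__call__: stop when the most
    # recently generated token equals the stop token id.
    return token_ids[-1] == stop_token_id

def generate(prompt_ids, next_tokens):
    # Toy stand-in for model.generate: appends candidate tokens one
    # by one and halts right after the stop token is emitted.
    ids = list(prompt_ids)
    for tok in next_tokens:
        ids.append(tok)
        if should_stop(ids):
            break
    return ids

prompt_ids = [10, 11, 12]
# Simulated model output: reply tokens, then the stop token, then junk
out = generate(prompt_ids, [20, 21, 22, STOP_ID, 99, 98])
print(out)  # [10, 11, 12, 20, 21, 22, 2] - the trailing 99, 98 are never generated
```

Without such a criterion, `generate` would keep sampling until `max_new_tokens` is exhausted, hallucinating further User/Bot turns past the reply.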
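The new decode line slices the output before decoding: `generated_token_ids[len(input_ids[0]):-1]` drops the echoed prompt from the front and the trailing stop token from the end, so only the Bot's reply is decoded. On plain lists of token ids (made-up values for illustration) the slice behaves as:

```python
prompt_ids = [10, 11, 12]                          # tokenized prompt
generated_token_ids = [10, 11, 12, 20, 21, 22, 2]  # prompt + reply + stop token

# Same slice as in the diff: skip the first len(prompt_ids) tokens
# (the echoed prompt) and drop the last token (the stop token).
reply_ids = generated_token_ids[len(prompt_ids):-1]
print(reply_ids)  # [20, 21, 22]
```

The `-1` end index only makes sense together with the stopping criterion above it: it assumes the last generated token is the stop token, which holds exactly when generation was halted by `StopOnTokenCriteria` rather than by `max_new_tokens`.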
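The `generate` call samples with `temperature=0.6` and `top_p=1`. Conceptually, temperature divides the logits before the softmax (values below 1 sharpen the distribution toward likely tokens), and top-p keeps only the smallest set of tokens whose cumulative probability reaches `top_p` (so `top_p=1` keeps everything). A conceptual sketch of one sampling step under these settings (not the Transformers implementation, and the logits are made up):

```python
import math
import random

def sample_next_token(logits, temperature=0.6, top_p=1.0, rng=random.Random(0)):
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Nucleus (top-p) filtering: keep the most probable tokens until
    # their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    # Sample from the renormalized nucleus.
    kept_probs = [probs[i] for i in kept]
    r = rng.random() * sum(kept_probs)
    acc = 0.0
    for i, p in zip(kept, kept_probs):
        acc += p
        if r <= acc:
            return i
    return kept[-1]

# With a tight nucleus and one dominant logit, sampling is effectively greedy:
print(sample_next_token([0.1, 3.0, 0.2], top_p=0.5))  # 1
```

With `top_p=1`, as in the diff, the nucleus covers the whole vocabulary and only the temperature shapes the sampling distribution.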