Update README.md
README.md (CHANGED)
@@ -18,23 +18,31 @@ Oneirogen can be used to generate novel dream narratives. It can also be used fo
 ## Code
 
 ```py
-from
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-import torch
-
-    bnb_4bit_compute_dtype=torch.bfloat16,
-)
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+import torch
+
+# Stop generation as soon as the stop token appears in the decoded output.
+class CustomStoppingCriteria(StoppingCriteria):
+    def __init__(self, stop_token, tokenizer):
+        self.stop_token = stop_token
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        decoded_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        if self.stop_token in decoded_output:
+            return True
+        return False
+
+tokenizer = AutoTokenizer.from_pretrained("gustavecortal/oneirogen-0.5B")
+model = AutoModelForCausalLM.from_pretrained("gustavecortal/oneirogen-0.5B", torch_dtype=torch.float16)
+model.to("cuda")
+
+stop_token = "END."  # The model was trained with this special end-of-text token.
+stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(stop_token, tokenizer)])
+
+text = "Dream:"  # The model was trained with this prefix.
+
+inputs = tokenizer(text, return_tensors="pt").to("cuda")
+outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=256, top_k=50, top_p=0.95, do_sample=True, temperature=0.9, num_beams=1, repetition_penalty=1.11, stopping_criteria=stopping_criteria)
+print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=False)[0])
 ```
 
 ## Inspiration
@@ -49,6 +57,4 @@ Mail: [email protected]
 
 X: [@gustavecortal](https://x.com/gustavecortal)
 
-Website: [gustavecortal.com](gustavecortal.com)
-
-
+Website: [gustavecortal.com](gustavecortal.com)
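
Not part of the commit above: a minimal post-processing sketch for the generated text, assuming the `Dream:` prefix and `END.` stop token shown in the snippet; the helper name `clean_dream` is hypothetical.

```py
# Minimal sketch (assumption): strip the "Dream:" training prefix and the "END." marker
# from the decoded output so only the narrative itself remains.
def clean_dream(generated_text, prefix="Dream:", stop_token="END."):
    text = generated_text
    if text.startswith(prefix):
        text = text[len(prefix):]
    if stop_token in text:
        text = text.split(stop_token, 1)[0]  # drop the marker and anything after it
    return text.strip()

# Example:
# clean_dream("Dream: I was flying over a city of glass. END.")
# -> "I was flying over a city of glass."
```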