Update README.md
Browse files
README.md
CHANGED
@@ -82,8 +82,12 @@ start = time.time()
|
|
82 |
for i in range(200):
|
83 |
next_token = model(input_ids).logits[:, -1].argmax(-1)
|
84 |
generated_token_ids.append(next_token.item())
|
|
|
|
|
85 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=-1)
|
86 |
-
|
|
|
|
|
87 |
break
|
88 |
|
89 |
print(tokenizer.decode(generated_token_ids))
|
|
|
82 |
for i in range(200):
|
83 |
next_token = model(input_ids).logits[:, -1].argmax(-1)
|
84 |
generated_token_ids.append(next_token.item())
|
85 |
+
|
86 |
+
print(next_token.item())
|
87 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=-1)
|
88 |
+
|
89 |
+
# 32041 is the token id of <nexa_end>
|
90 |
+
if next_token.item() == 32041:
|
91 |
break
|
92 |
|
93 |
print(tokenizer.decode(generated_token_ids))
|