|
|
|
import torch |
|
|
|
def generate_text(model, tokenizer, prompt, max_length=50, device=None):
    """Greedily generate text from ``prompt`` one token at a time.

    Runs the model autoregressively, always picking the argmax token
    (no sampling), until ``max_length`` new tokens have been generated
    or the tokenizer's EOS token is produced. Assumes a single prompt
    (batch size 1).

    Args:
        model: Callable that maps a ``(1, seq_len)`` id tensor to logits,
            either as a raw ``(1, seq_len, vocab)`` tensor or an object
            with a ``.logits`` attribute (HF-style output).
        tokenizer: Object providing ``encode``/``decode`` and an
            ``eos_token_id`` attribute (may be ``None``).
        prompt: Input text to continue.
        max_length: Maximum number of NEW tokens to generate (the prompt
            tokens are not counted against this budget).
        device: Torch device for generation. ``None`` (default) picks
            ``'cuda'`` when available, else ``'cpu'`` — previously a
            hard 'cuda' default crashed on CPU-only machines.

    Returns:
        The decoded string: prompt plus generated continuation, with
        special tokens stripped.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Accept both a raw logits tensor and an HF-style output object.
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            next_token_logits = logits[:, -1, :]
            # keepdim preserves the batch dimension: shape (1, 1), so the
            # concatenation below is well-formed (equivalent to the old
            # .unsqueeze(0) for batch size 1, but correct in general).
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop once the model emits EOS (skip the check when the
            # tokenizer defines no EOS token).
            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
|