import torch


def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedily decode up to max_length new tokens from a causal LM."""
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Hugging Face causal LMs return a ModelOutput object; the raw
            # scores live in .logits, shaped (batch, seq_len, vocab_size).
            next_token_logits = outputs.logits[:, -1, :]
            # Greedy decoding: take the highest-scoring token, then add a
            # sequence dimension so it can be concatenated onto input_ids.
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Stop as soon as the model emits the end-of-sequence token.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
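
# A minimal usage sketch, assuming a Hugging Face transformers causal LM.
# The checkpoint name "gpt2" is an illustrative choice, not taken from the
# source; any causal LM checkpoint with an eos_token_id would work.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
    print(generate_text(model, tokenizer, 'Once upon a time', device=device))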