# gem-1o/utils/text_generation.py (commit d18eb09: GEM_1o_Aug trained)
import torch


def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedily generate up to `max_length` new tokens continuing `prompt`."""
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Models that return a ModelOutput expose logits as `.logits`;
            # otherwise treat the output as a raw logits tensor.
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            # Greedy decoding: take the highest-scoring token at the last position.
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Stop as soon as the end-of-sequence token is produced.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
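

# Example usage, given as a minimal sketch: any Hugging Face causal LM and its
# tokenizer can be passed to generate_text. The "gpt2" checkpoint below is only
# an illustrative stand-in, not the GEM_1o model itself.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    print(generate_text(model, tokenizer, "Once upon a time,", max_length=40, device=device))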