import torch
import torch.nn.functional as F


def get_last_attn(attn_map):
    # attn_map is a list of per-layer attention tensors of shape
    # (batch, heads, seq_len, seq_len); keep only the last query position,
    # giving (batch, heads, 1, seq_len) for each layer.
    for i, layer in enumerate(attn_map):
        attn_map[i] = layer[:, :, -1, :].unsqueeze(2)
    return attn_map
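

# Hedged usage sketch (added for illustration; not part of the original Space code):
# get_last_attn is typically fed the per-layer attention maps returned by a
# Hugging Face model called with output_attentions=True, converted to a list
# (it assigns back into the sequence). The toy shapes and the helper name
# _demo_get_last_attn are assumptions made here for illustration only.
def _demo_get_last_attn():
    batch, heads, seq_len, n_layers = 1, 12, 8, 2
    fake_attn = [torch.rand(batch, heads, seq_len, seq_len) for _ in range(n_layers)]
    last_only = get_last_attn(fake_attn)
    assert last_only[0].shape == (batch, heads, 1, seq_len)
    return last_only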


def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    # Sample a next-token id from 1-D next-token logits.
    # Scale logits by temperature (must be > 0); 1.0 leaves them unchanged.
    logits = logits / temperature
    # Top-k sampling: keep the k most likely tokens and sample among them.
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # ensure top_k <= vocab size
        values, indices = torch.topk(logits, top_k)
        probs = F.softmax(values, dim=-1)
        next_token_id = indices[torch.multinomial(probs, 1)]
        return next_token_id
    # top_p is accepted but nucleus sampling is not implemented in this snippet;
    # without top_k the function falls back to greedy decoding.
    return logits.argmax(dim=-1).squeeze()
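

# Hedged usage sketch (added for illustration; not part of the original Space code):
# sample_token expects the 1-D logits for the next-token position, e.g.
# model_output.logits[0, -1, :] from a causal LM. Random logits are used here so
# the example runs without loading a model; the vocabulary size is an assumption.
if __name__ == "__main__":
    vocab_size = 50257  # GPT-2-sized vocabulary, chosen only for illustration
    fake_logits = torch.randn(vocab_size)

    # Top-k sampling with temperature: returns a 1-element tensor of token ids.
    sampled_id = sample_token(fake_logits, top_k=50, temperature=0.8)

    # Greedy decoding when top_k is None: returns a 0-dim tensor (the argmax id).
    greedy_id = sample_token(fake_logits)

    print(int(sampled_id), int(greedy_id))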