token_log_probs: simplify
Browse files- completions.py +2 -2
completions.py
CHANGED
|
@@ -66,9 +66,9 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
|
|
| 66 |
# B x T x V
|
| 67 |
log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
|
| 68 |
# T - 1
|
| 69 |
-
token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
|
| 70 |
-
# T - 1
|
| 71 |
tokens: torch.Tensor = input_ids[0][1:]
|
|
|
|
|
|
|
| 72 |
return list(zip(tokens.tolist(), token_log_probs.tolist()))
|
| 73 |
|
| 74 |
#%%
|
|
|
|
| 66 |
# B x T x V
|
| 67 |
log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
|
| 68 |
# T - 1
|
|
|
|
|
|
|
| 69 |
tokens: torch.Tensor = input_ids[0][1:]
|
| 70 |
+
# T - 1
|
| 71 |
+
token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), tokens]
|
| 72 |
return list(zip(tokens.tolist(), token_log_probs.tolist()))
|
| 73 |
|
| 74 |
#%%
|