Spaces:
Running
on
L40S
Running
on
L40S
Update tools/llama/generate.py
Browse files
- tools/llama/generate.py +1 -15
tools/llama/generate.py
CHANGED
|
@@ -154,16 +154,11 @@ def decode_one_token_ar_agent(
|
|
| 154 |
logits = x.logits # [:, -1:]
|
| 155 |
hidden_states = x.hidden_states # [:, -1:]
|
| 156 |
|
| 157 |
-
sampling_kwargs_main = sampling_kwargs.copy()
|
| 158 |
-
sampling_kwargs_main["temperature"] = 0.1
|
| 159 |
-
sampling_kwargs_main["top_p"] = 0.1
|
| 160 |
-
sampling_kwargs_main["repetition_penalty"] = 1.0
|
| 161 |
-
|
| 162 |
codebooks = [
|
| 163 |
sample_agent(
|
| 164 |
logits,
|
| 165 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
| 166 |
-
**sampling_kwargs_main,
|
| 167 |
)[0]
|
| 168 |
]
|
| 169 |
|
|
@@ -194,15 +189,6 @@ def decode_one_token_ar_agent(
|
|
| 194 |
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
| 195 |
)
|
| 196 |
|
| 197 |
-
# for i in range(codebooks.size(1) - 1):
|
| 198 |
-
# codebooks[:, i + 1, :] = torch.masked_fill(
|
| 199 |
-
# codebooks[:, i + 1, :],
|
| 200 |
-
# codebooks[:, :1, :] != semantic_id,
|
| 201 |
-
# CODEBOOK_PAD_TOKEN_ID + i * 1024,
|
| 202 |
-
# )
|
| 203 |
-
|
| 204 |
-
# print(codebooks)
|
| 205 |
-
|
| 206 |
return codebooks
|
| 207 |
|
| 208 |
|
|
|
|
| 154 |
logits = x.logits # [:, -1:]
|
| 155 |
hidden_states = x.hidden_states # [:, -1:]
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
codebooks = [
|
| 158 |
sample_agent(
|
| 159 |
logits,
|
| 160 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
| 161 |
+
**sampling_kwargs,
|
| 162 |
)[0]
|
| 163 |
]
|
| 164 |
|
|
|
|
| 189 |
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
| 190 |
)
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
return codebooks
|
| 193 |
|
| 194 |
|