Hugo Flores Garcia committed · Commit 9496f0e
1 Parent(s): 308d855

sampling tricks!

Files changed:
- app.py (+49 -11)
- vampnet/modules/transformer.py (+109 -37)
app.py CHANGED

@@ -97,28 +97,35 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(
+    print(data)
+    _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
+    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-
+        mask_temperature=data[masktemp]*10,
+        sampling_temperature=data[sampletemp],
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
+        top_p=_top_p,
         gen_fn=interface.coarse.generate,
+        seed=_seed,
     )
 
     if use_coarse2fine:
         zv = interface.coarse_to_fine(
             zv,
-
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
             mask=mask,
-            sampling_steps=data[num_steps]
+            sampling_steps=data[num_steps],
+            seed=_seed,
         )
 
     sig = interface.to_signal(zv).cpu()
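The handler above treats 0 as "off": `top_p` and `seed` widget values are coerced to `None` before they reach the sampler, and the `masktemp` slider is rescaled by 10, so the 0-10 slider maps onto a `mask_temperature` of 0-100 (UI default 1.5 -> 15, versus the function's own default of 20.5 seen below). A minimal sketch of the sentinel idiom, with a hypothetical `coerce_optional` helper that is not part of the commit:

    def coerce_optional(value, sentinel=0):
        # Widgets always return a number; treat the sentinel as "disabled".
        return value if value > sentinel else None

    # e.g. coerce_optional(0.9) -> 0.9, coerce_optional(0.0) -> None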
@@ -152,7 +159,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -163,6 +172,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }
 
     # save with yaml
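The `# save with yaml` comment marks where `_data` is presumably serialized next; a hedged sketch of that step (the filename is an assumption, not taken from this diff):

    import yaml

    with open(out_dir / "settings.yaml", "w") as f:
        yaml.dump(_data, f)  # every value in _data above is a plain scalar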
@@ -385,16 +395,28 @@ with gr.Blocks() as demo:
                 value=0.0
             )
 
-
-                label="temperature",
+            masktemp = gr.Slider(
+                label="mask temperature",
                 minimum=0.0,
                 maximum=10.0,
-                value=
+                value=1.5
             )
-
+            sampletemp = gr.Slider(
+                label="sample temperature",
+                minimum=0.1,
+                maximum=2.0,
+                value=1.0
+            )
+
 
 
         with gr.Accordion("sampling settings", open=False):
+            top_p = gr.Slider(
+                label="top p (0.0 = off)",
+                minimum=0.0,
+                maximum=1.0,
+                value=0.0
+            )
             typical_filtering = gr.Checkbox(
                 label="typical filtering ",
                 value=False
@@ -435,6 +457,18 @@ with gr.Blocks() as demo:
                 value=0.0
             )
 
+            use_new_trick = gr.Checkbox(
+                label="new trick",
+                value=False
+            )
+
+            seed = gr.Number(
+                label="seed (0 for random)",
+                value=0,
+                precision=0,
+            )
+
+
 
         # mask settings
         with gr.Column():
@@ -463,7 +497,9 @@ with gr.Blocks() as demo:
     _inputs = {
         input_audio,
         num_steps,
-
+        masktemp,
+        sampletemp,
+        top_p,
         prefix_s, suffix_s,
         rand_mask_intensity,
         periodic_p, periodic_w,
@@ -476,7 +512,9 @@ with gr.Blocks() as demo:
         typical_mass,
         typical_min_tokens,
         beat_mask_width,
-        beat_mask_downbeats
+        beat_mask_downbeats,
+        seed,
+        seed
     }
 
     # connect widgets
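Two notes on `_inputs`: it is a Python set, so the `seed` element added twice in the last hunk is deduplicated and harmless; and passing a set of components to a Gradio event is what makes the callbacks above receive a single dict indexed by component (hence `data[num_steps]`, `data[seed]`, and so on). A self-contained sketch of that pattern, with illustrative widget names:

    import gradio as gr

    with gr.Blocks() as demo:
        steps = gr.Slider(label="steps", minimum=1, maximum=64, value=24)
        seed = gr.Number(label="seed (0 for random)", value=0, precision=0)
        go = gr.Button("go")
        out = gr.Textbox()

        def handler(data):
            # `data` is keyed by the component objects themselves
            return f"steps={data[steps]}, seed={int(data[seed])}"

        go.click(handler, inputs={steps, seed}, outputs=out)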
vampnet/modules/transformer.py CHANGED

@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
+
+    return y
 
 class TransformerStack(nn.Module):
     def __init__(
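The new `t_schedule` is a reversed sigmoid over the sampling steps: temperature starts near its ceiling and decays toward zero, so early steps sample freely and late steps become nearly greedy. Note that `max_temp` and `min_temp` enter only through the sigmoid's midpoint `x0`, not as a rescaling, so the returned values always lie strictly between 0 and 1. A quick check with the defaults:

    import numpy as np

    def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
        x = np.linspace(0, 1, n_steps)
        a = (0.5 - min_temp) / (max_temp - min_temp)
        x = (x * 12) - 6
        x0 = np.log((1 / a - 1) + 1e-5) / k
        return (1 / (1 + np.exp(-k * (x - x0))))[::-1]

    print(np.round(t_schedule(6), 3))
    # [0.998 0.973 0.769 0.231 0.027 0.002]  -- high to low across steps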
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")
 
 
         #####################
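`generate` is now seedable: a non-None `seed` is routed to `at.util.seed` before any sampling happens, making runs reproducible. As a hedged sketch (audiotools' helper seeds the standard RNGs; treat the exact behavior as an assumption, not the library's source):

    import random
    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # also seeds the CUDA RNGs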
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
         #################
         # begin sampling #
         #################
+        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
 
-            # our current temperature
-            logging.debug(f"temperature: {temperature}")
-
             # our current schedule step
             r = scalar_to_batch_tensor(
                 (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel):
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents, r)  # b, prob, seq
             logits = logits.permute(0, 2, 1)  # b, seq, prob
-
-            typical_filter(logits,
-                typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens
-            )
-
+            b = logits.shape[0]
 
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=True, temperature=t_sched[i],
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True
+            )
 
-            # logits2probs
-            probs = torch.softmax(logits, dim=-1)
-            logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-            # sample from logits with multinomial sampling
-            b = probs.shape[0]
-            probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-            sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
-            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
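The inline typical-filter/softmax/multinomial block is gone: one call to the new `sample_from_logits` (defined at the bottom of this file) now performs filtering, temperature scaling, sampling, and confidence extraction in one place. A hedged usage sketch on dummy logits, with the shapes the loop above uses (batch, sequence, vocabulary):

    import torch

    logits = torch.randn(2, 16, 1024)  # b, seq, prob
    tokens, confidences = sample_from_logits(
        logits, sample=True, temperature=0.8,
        top_k=None, top_p=0.9, return_probs=True,
    )
    assert tokens.shape == (2, 16)
    assert confidences.shape == (2, 16)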
@@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel):
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, mask_temperature * (1-r)
             )
 
             # update the mask
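`mask_temperature * (1-r)` anneals the re-masking: `r` climbs from 1/steps to 1 over the loop, so the multiplier decays linearly from nearly `mask_temperature` down to 0, and the choice of which tokens to re-mask shifts from noisy to purely confidence-ranked. Since `mask_by_random_topk`'s body is outside this diff, here is a hedged MaskGIT-style sketch of the role that temperature typically plays (the actual implementation may differ in details):

    import torch

    def mask_by_random_topk_sketch(num_to_mask: int, probs: torch.Tensor,
                                   temperature: float = 1.0) -> torch.Tensor:
        # Gumbel noise scaled by `temperature` perturbs the confidence ranking;
        # at temperature 0 exactly the num_to_mask least-confident positions
        # are re-masked.
        gumbel = -torch.log(-torch.log(torch.rand_like(probs)))
        confidence = torch.log(probs) + temperature * gumbel
        idx = confidence.argsort(dim=-1)  # ascending: least confident first
        mask = torch.zeros_like(probs, dtype=torch.bool)
        mask.scatter_(-1, idx[..., :num_to_mask], True)
        return mask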
@@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
+def sample_from_logits(
+        logits,
+        sample: bool = True,
+        temperature: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
+        typical_filtering: bool = False,
+        typical_mass: float = 0.2,
+        typical_min_tokens: int = 1,
+        return_probs: bool = False
+    ):
+    """Convenience function to sample from a categorical distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+    config: SamplingConfig
+        The set of hyperparameters to be used for sampling
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial sampling, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability <= `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+        )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+
 
 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """