| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from einops import reduce | 
					
					
						
						| 
							 | 
						from jaxtyping import Float, Int64 | 
					
					
						
						| 
							 | 
						from torch import Tensor | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def sample_discrete_distribution(
						    pdf: Float[Tensor, "*batch bucket"],
						    num_samples: int,
						    eps: float = torch.finfo(torch.float32).eps,
						) -> tuple[
						    Int64[Tensor, "*batch sample"],  # index
						    Float[Tensor, "*batch sample"],  # probability density
						]:
						    """Randomly draw bucket indices from a discrete distribution.

						    The last axis of ``pdf`` holds one weight per bucket. The weights are
						    divided by their sum (plus ``eps``), accumulated into a CDF, and
						    uniform random numbers are mapped through that CDF to bucket indices
						    (inverse-transform sampling).

						    Args:
						        pdf: Per-bucket weights on the last axis; they do not need to sum
						            to one. Presumably non-negative -- the cumulative sum is only
						            a sorted sequence (as ``searchsorted`` expects) in that case.
						        num_samples: Number of indices drawn for every batch element.
						        eps: Added to the weight sum before dividing, so an all-zero row
						            yields zeros instead of a division by zero.

						    Returns:
						        A tuple ``(index, probability)``: the sampled bucket indices and
						        the normalized weight of each sampled bucket, both shaped
						        ``(*batch, num_samples)``.
						    """
						    *batch, bucket = pdf.shape

						    # Sum over the bucket axis, keeping it as a size-1 axis so that the
						    # division broadcasts across buckets.
						    total = pdf.sum(dim=-1, keepdim=True)
						    normalized_pdf = pdf / (eps + total)
						    cdf = normalized_pdf.cumsum(dim=-1)

						    # One uniform draw in [0, 1) per requested sample.
						    samples = torch.rand((*batch, num_samples), device=pdf.device)

						    # right=True selects the first bucket whose CDF value strictly exceeds
						    # the draw. Because of eps the CDF ends slightly below one, so a draw
						    # can land past the last entry; the clip keeps such indices in range.
						    index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1)
						    return index, normalized_pdf.gather(dim=-1, index=index)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def gather_discrete_topk(
						    pdf: Float[Tensor, "*batch bucket"],
						    num_samples: int,
						    eps: float = torch.finfo(torch.float32).eps,
						) -> tuple[
						    Int64[Tensor, "*batch sample"],  # index
						    Float[Tensor, "*batch sample"],  # probability density
						]:
						    """Select the ``num_samples`` highest-weight buckets of a distribution.

						    Deterministic counterpart to random sampling: instead of drawing
						    indices, the buckets with the largest weights along the last axis are
						    returned.

						    Args:
						        pdf: Per-bucket weights on the last axis; they do not need to sum
						            to one.
						        num_samples: Number of buckets to select per batch element. It is
						            passed to ``topk`` as ``k``.
						        eps: Added to the weight sum before dividing, so an all-zero row
						            yields zeros instead of a division by zero.

						    Returns:
						        A tuple ``(index, probability)``: the indices of the selected
						        buckets as returned by ``topk``, and the normalized weight of each
						        selected bucket, both shaped ``(*batch, num_samples)``.
						    """
						    # Sum over the bucket axis, keeping it as a size-1 axis so that the
						    # division broadcasts across buckets.
						    total = pdf.sum(dim=-1, keepdim=True)
						    normalized_pdf = pdf / (eps + total)

						    # The ranking is taken on the raw weights; the returned probabilities
						    # are read from the normalized ones.
						    index = pdf.topk(k=num_samples, dim=-1).indices
						    return index, normalized_pdf.gather(dim=-1, index=index)
					
					
						
						| 
							 | 
						
 |