import tqdm

import torch
from einops import rearrange

def scalar_to_batch_tensor(x, batch_size):
    """Build a tensor from ``x`` and tile it ``batch_size`` times.

    ``x`` is converted with ``torch.tensor`` and then ``Tensor.repeat`` is
    applied. For a Python scalar this yields a 1-D tensor of length
    ``batch_size`` whose entries all equal ``x``.
    """
    value = torch.tensor(x)
    batched = value.repeat(batch_size)
    return batched


def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs,
):
    """Map ``fn`` over ``iterables`` with a tqdm progress bar.

    In every mode ``fn`` is called as ``fn(*args)``, where ``args`` holds one
    item from each iterable (like the builtin ``map``). The results are
    returned as a list.

    Args:
        fn: callable applied to each group of items.
        *iterables: one or more iterables, consumed in lockstep.
        parallel: execution strategy, one of
            ``"thread_map"`` (``tqdm.contrib.concurrent.thread_map``),
            ``"process_map"`` (``tqdm.contrib.concurrent.process_map``) or
            ``"single"`` (a plain sequential loop in the calling thread).
        **kwargs: forwarded to ``thread_map`` / ``process_map``. In
            ``"single"`` mode the executor-only options (``max_workers``,
            ``chunksize``, ``lock_name``) are dropped, ``tqdm_class`` selects
            the progress-bar class (default ``tqdm.tqdm``), and the remaining
            options are passed to the progress bar.

    Returns:
        list of ``fn`` results.

    Raises:
        ValueError: if ``parallel`` is not a supported mode.
    """
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(fn, *iterables, **kwargs)
    elif parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(fn, *iterables, **kwargs)
    elif parallel == "single":
        # zip() gives fn the same fn(*args) calling convention as the executor
        # modes; passing *iterables straight to tqdm would instead turn the
        # 2nd+ iterables into tqdm's positional options (desc, total, ...).
        bar_kwargs = dict(kwargs)
        bar_class = bar_kwargs.pop("tqdm_class", tqdm.tqdm)
        # Options consumed by the pool executors in thread_map/process_map;
        # tqdm itself would reject them.
        for executor_only in ("max_workers", "chunksize", "lock_name"):
            bar_kwargs.pop(executor_only, None)
        if "total" not in bar_kwargs:
            # zip() has no len(), so compute the bar length up front when possible.
            try:
                bar_kwargs["total"] = min(len(it) for it in iterables)
            except (TypeError, ValueError):
                pass  # unsized iterable or no iterables: bar runs without a total
        return [fn(*args) for args in bar_class(zip(*iterables), **bar_kwargs)]
    else:
        raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
    
def codebook_flatten(tokens: torch.Tensor):
    """
    Flatten tokens from (batch, codebook, time) to (batch, time * codebook).

    The flat axis is time-major: the entries for every codebook at time step 0
    come first, then every codebook at time step 1, and so on.
    """
    pattern = "b c t -> b (t c)"
    flat_tokens = rearrange(tokens, pattern)
    return flat_tokens

def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    Unflatten tokens from (batch, time * codebook) to (batch, codebook, time).

    This is the inverse of ``codebook_flatten``. The flat axis is read as
    time-major: each run of ``n_c`` consecutive entries holds all codebooks
    for one time step.

    Args:
        flat_tokens: tensor of shape (batch, time * n_c).
        n_c: number of codebooks. Effectively required: the pattern
            ``"b (t c)"`` cannot infer ``c`` on its own.

    Returns:
        tensor of shape (batch, n_c, time).

    Raises:
        ValueError: if ``n_c`` is None or not positive.
    """
    # Fail early with a clear message instead of an opaque einops error.
    if n_c is None or n_c <= 0:
        raise ValueError(f"codebook_unflatten requires a positive n_c (number of codebooks), got {n_c!r}")
    tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
    return tokens