import torch
from src.smb.level import MarioLevel
from src.gan.gans import nz
from src.utils.filesys import getpath

# Sample latent noise vectors
def sample_latvec(n=1, device='cpu', distribuion='uniform'):
    """Draw ``n`` random latent vectors shaped ``(n, nz, 1, 1)``.

    Args:
        n: Number of vectors to draw (the batch dimension).
        device: Torch device on which the tensor is allocated.
        distribuion: ``'uniform'`` for samples in [-1, 1) or ``'normal'``
            for standard-normal samples. The misspelled name is part of the
            public keyword interface and is kept for compatibility.

    Returns:
        A floating-point tensor of shape ``(n, nz, 1, 1)`` on ``device``.

    Raises:
        TypeError: If ``distribuion`` names an unsupported distribution.
    """
    shape = (n, nz, 1, 1)
    if distribuion == 'uniform':
        # torch.rand samples U[0, 1); rescale to U[-1, 1).
        return 2 * torch.rand(shape, device=device) - 1
    if distribuion == 'normal':
        return torch.randn(shape, device=device)
    raise TypeError(f'unknow noise distribution: {distribuion}')

# Convert one-hot level tensors into MarioLevel objects
def process_onehot(raw_tensor_onehot):
    """Decode a batch of one-hot level tensors into ``MarioLevel`` objects.

    Each item is cropped to ``MarioLevel.height`` along axis 1 and
    ``MarioLevel.seg_width`` along axis 2 (axis 0 is kept whole), converted
    to a NumPy array and decoded with ``MarioLevel.from_one_hot_arr``.

    Args:
        raw_tensor_onehot: Iterable of tensors (e.g. a batch tensor), each
            indexable with three axes. Presumably each item is laid out as
            (channels, height, width) — TODO confirm against the generator.

    Returns:
        A list of ``MarioLevel`` for a batch of two or more items, a single
        ``MarioLevel`` for a one-item batch (kept for backward
        compatibility), and an empty list for an empty batch.
    """
    height, width = MarioLevel.height, MarioLevel.seg_width
    # Crop each item to one level segment; detach + cpu are required before
    # .numpy(), which refuses tensors that track gradients or live on a GPU.
    levels = [
        MarioLevel.from_one_hot_arr(single[:, :height, :width].detach().cpu().numpy())
        for single in raw_tensor_onehot
    ]
    if not levels:
        # An empty batch yields an empty list instead of an opaque IndexError.
        return []
    return levels if len(levels) > 1 else levels[0]

def get_decoder(path='models/decoder.pth', device='cpu'):
    """Load a saved decoder module, freeze its parameters and set eval mode.

    Args:
        path: Location of the saved decoder, resolved through ``getpath``.
        device: Forwarded to ``torch.load`` as ``map_location`` so the
            stored tensors are loaded onto this device.

    Returns:
        The loaded decoder, with ``requires_grad_(False)`` applied and
        switched to eval mode.
    """
    import inspect

    load_kwargs = {'map_location': device}
    # The loaded object is used as a whole module (requires_grad_/eval below),
    # not a state dict. Since PyTorch 2.6, torch.load defaults to
    # weights_only=True and refuses to unpickle such objects, so opt out
    # explicitly when this torch version accepts the flag (older ones don't).
    # NOTE(security): full unpickling can execute arbitrary code — only load
    # trusted checkpoint files here.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    decoder = torch.load(getpath(path), **load_kwargs)
    decoder.requires_grad_(False)
    decoder.eval()
    return decoder