import numpy as np
import random
import torch
import math


def rotate_axis(x, add_angle=0, axis=1):  # TODO Replace with a rotation matrix   # But this is more fun
    """Rotate 3-D points about one coordinate axis by ``add_angle`` radians.

    The two coordinates that are not on ``axis`` are converted to polar form
    (angle, radius), the angle is offset by ``add_angle`` and the result is
    converted back. The coordinate along ``axis`` is passed through unchanged,
    so a rotation by 0 returns the input (up to floating-point rounding).

    Args:
        x: Tensor whose last dimension holds the three coordinates.
        add_angle: Angle offset in radians, either a number or a tensor. A
            tensor with fewer dimensions than the batch of points gets
            trailing singleton dimensions appended so it broadcasts against
            the leading dimensions of ``x``.
        axis: Index (0, 1 or 2) of the coordinate axis to rotate about.

    Returns:
        Tensor of rotated points with the three coordinates on the last
        dimension, in the same order as in ``x``.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    axes = list(range(3))
    axes.remove(axis)
    ax1, ax2 = axes

    # Polar angle in the rotation plane, measured from ax2 towards ax1.
    angle = torch.atan2(x[..., ax1], x[..., ax2])
    if isinstance(add_angle, torch.Tensor):
        while add_angle.ndim < angle.ndim:
            add_angle = add_angle.unsqueeze(-1)
    angle = angle + add_angle

    # Radius must be measured in the rotation plane only; using the full 3-D
    # norm would stretch every point that has a non-zero `axis` coordinate.
    radius = x[..., [ax1, ax2]].norm(dim=-1)

    components = [None, None, None]
    components[axis] = x[..., axis]
    components[ax1] = torch.sin(angle) * radius
    components[ax2] = torch.cos(angle) * radius
    # A tensor `add_angle` may carry more batch dimensions than `x`; expand
    # all three components to a common shape before stacking.
    components = torch.broadcast_tensors(*components)
    return torch.stack(components, dim=-1)


# Offset subtracted from each random.random() draw in rand_perlin_2d_octaves.
# With 0.5 the random scale and random bias applied there lie in [-0.5, 0.5).
noise_level = 0.5


# stolen from https://gist.github.com/ac1b097753f217c5c11bc2ff396e0a57
# ported from https://github.com/pvigier/perlin-numpy/blob/master/perlin2d.py
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """Generate one layer of 2-D Perlin noise as a torch tensor.

    Args:
        shape: Pair ``(rows, cols)`` giving the requested output size.
        res: Pair giving the number of lattice cells along each axis. Each
            cell is sampled ``shape[k] // res[k]`` times along axis ``k``
            (presumably ``shape`` is meant to be a multiple of ``res``).
        fade: Interpolation curve applied to the fractional position inside a
            cell; defaults to the quintic ``6t^5 - 15t^4 + 10t^3``.

    Returns:
        Tensor of noise values scaled by ``sqrt(2)``. Random gradient
        directions are drawn with ``torch.rand``.
    """
    rows, cols = shape[0], shape[1]
    step = (res[0] / rows, res[1] / cols)
    samples_per_cell = (rows // res[0], cols // res[1])

    # Fractional position of every sample inside its lattice cell.
    coords0 = torch.arange(0, res[0], step[0])
    coords1 = torch.arange(0, res[1], step[1])
    frac = torch.stack(torch.meshgrid(coords0, coords1), dim=-1) % 1
    frac = frac[:rows, :cols]

    # One random unit gradient per lattice corner.
    theta = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
    lattice = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)

    def corner_contribution(row_span, col_span, shift):
        # Pick the gradients of one corner of every cell, repeat them so each
        # sample in the cell sees its corner, then dot with the offset vector
        # from that corner to the sample.
        grads = lattice[row_span[0]:row_span[1], col_span[0]:col_span[1]]
        grads = grads.repeat_interleave(samples_per_cell[0], 0)
        grads = grads.repeat_interleave(samples_per_cell[1], 1)
        grads = grads[:rows, :cols]
        offset = torch.stack((frac[..., 0] + shift[0], frac[..., 1] + shift[1]), dim=-1)
        return (offset * grads).sum(dim=-1)

    n00 = corner_contribution((0, -1), (0, -1), (0, 0))
    n10 = corner_contribution((1, None), (0, -1), (-1, 0))
    n01 = corner_contribution((0, -1), (1, None), (0, -1))
    n11 = corner_contribution((1, None), (1, None), (-1, -1))

    weight = fade(frac)
    along_axis0_low = torch.lerp(n00, n10, weight[..., 0])
    along_axis0_high = torch.lerp(n01, n11, weight[..., 0])
    return math.sqrt(2) * torch.lerp(along_axis0_low, along_axis0_high, weight[..., 1])


def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
    """Sum several octaves of 2-D Perlin noise, then randomly rescale and shift.

    Each successive octave doubles the lattice resolution passed to
    ``rand_perlin_2d`` and multiplies the amplitude by ``persistence``.

    Args:
        shape: Output size, forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        res: Pair giving the lattice resolution of the first octave.
        octaves: Number of noise layers to accumulate.
        persistence: Amplitude multiplier applied after every octave.

    Returns:
        Tensor holding the accumulated noise, multiplied by one
        ``random.random() - noise_level`` draw and offset by a second one.
    """
    total = torch.zeros(shape)
    freq, amp = 1, 1
    for _ in range(octaves):
        layer = rand_perlin_2d(shape, (freq * res[0], freq * res[1]))
        total.add_(amp * layer)
        freq, amp = freq * 2, amp * persistence

    # Two independent draws: the first scales the field, the second shifts it.
    gain = random.random() - noise_level
    total.mul_(gain)
    bias = random.random() - noise_level
    total.add_(bias)
    return total


def load_clip(model_name="ViT-B/16", device="cuda:0" if torch.cuda.is_available() else "cpu"):
    """Load a CLIP model and a trimmed-down preprocessing pipeline.

    Args:
        model_name: Model identifier passed to ``clip.load``.
        device: Device passed to ``clip.load``. The default is evaluated once,
            when this function is defined: ``"cuda:0"`` if CUDA is available
            at that moment, otherwise ``"cpu"``.

    Returns:
        ``(model, preprocess)`` as returned by ``clip.load`` with ``jit=False``.
        When the preprocessing pipeline has more than four transforms, it is
        cut down to its final transform only (presumably the normalisation
        step - confirm against the installed ``clip`` package).
    """
    import clip

    clip_model, pipeline = clip.load(model_name, device=device, jit=False)
    steps = pipeline.transforms
    if len(steps) > 4:
        pipeline.transforms = steps[-1:]
    return clip_model, pipeline


# http://blog.andreaskahler.com/2009/06/creating-icosphere-mesh-in-code.html
def ico():
    """Return the vertices and triangular faces of an icosahedron.

    The twelve vertices are the sign/axis permutations of ``(0, 1, phi)``
    (``phi`` being the golden ratio), all divided by ``phi``.

    Returns:
        A pair ``(vertices, faces)``: ``vertices`` is a ``(12, 3)`` numpy
        array, ``faces`` is a list of twenty ``[a, b, c]`` lists holding
        zero-based vertex indices.
    """
    golden = (1 + 5 ** 0.5) / 2

    corners = np.array((
        (-1, golden, 0),
        (1, golden, 0),
        (-1, -golden, 0),
        (1, -golden, 0),
        (0, -1, golden),
        (0, 1, golden),
        (0, -1, -golden),
        (0, 1, -golden),
        (golden, 0, -1),
        (golden, 0, 1),
        (-golden, 0, -1),
        (-golden, 0, 1),
    ))

    triangles = (
        # five faces sharing vertex 0
        (0, 11, 5),
        (0, 5, 1),
        (0, 1, 7),
        (0, 7, 10),
        (0, 10, 11),
        # five faces bordering that first fan
        (1, 5, 9),
        (5, 11, 4),
        (11, 10, 2),
        (10, 7, 6),
        (7, 1, 8),
        # five faces sharing vertex 3
        (3, 9, 4),
        (3, 4, 2),
        (3, 2, 6),
        (3, 6, 8),
        (3, 8, 9),
        # five faces bordering that second fan
        (4, 9, 5),
        (2, 4, 11),
        (6, 2, 10),
        (8, 6, 7),
        (9, 8, 1),
    )

    scaled = corners / golden
    faces = [list(tri) for tri in triangles]
    return scaled, faces


def ico_at(xyz=np.array([0, 0, 0]), radius=1.0, i=0):
    """Return the ``ico()`` mesh scaled, translated and re-indexed.

    Args:
        xyz: Offset added to every vertex after scaling.
        radius: Factor the ``ico()`` vertices are multiplied by.
        i: Integer added to every face index, so the faces can address
            vertices stored at position ``i`` onwards in a larger list.

    Returns:
        ``(vertices, faces)`` with ``vertices = ico_vertices * radius + xyz``
        and ``faces`` a new list of index lists shifted by ``i``.
    """
    base_vertices, base_faces = ico()
    placed = base_vertices * radius + xyz
    shifted = []
    for face in base_faces:
        shifted.append([corner + i for corner in face])
    return placed, shifted


def save_obj(points, out_path):
    """Write a set of coloured points to a text file as icosahedron meshes.

    Args:
        points: Iterable of exactly four torch tensors, in order: positions,
            colours, radii and opacities. They are iterated in lockstep, one
            entry per point; each colour entry must unpack into three values.
            Opacity is read but not written to the file.
        out_path: Path of the file to create or overwrite.

    Every point becomes one ``ico_at`` mesh. The file contains one
    ``v x y z r g b`` line per vertex followed by one ``f a b c`` line per
    face, with face indices converted to one-based.
    """
    with torch.inference_mode():
        positions, colours, radii, opacities = (t.detach().cpu().numpy() for t in points)

    vertex_rows = []
    face_rows = []
    # Opacities stay in the zip so iteration stops at the shortest input.
    for centre, size, (red, green, blue), _alpha in zip(positions, radii, colours, opacities):
        # The running vertex count is the index offset for this mesh's faces.
        mesh_vertices, mesh_faces = ico_at(centre, size, len(vertex_rows))
        vertex_rows.extend((px, py, pz, red, green, blue) for px, py, pz in mesh_vertices)
        face_rows.extend(mesh_faces)

    with open(out_path, "w") as out_file:
        out_file.writelines(
            "v " + " ".join(map(str, row)) + "\n" for row in vertex_rows
        )
        out_file.writelines(
            "f " + " ".join([str(index + 1) for index in face]) + "\n" for face in face_rows
        )