import numpy as np
import cv2


def crop_and_scaled_imgs(imgs, size=256, pad=10):
    """Crop every image to the drawn area of the last one and centre it on a square canvas.

    The last image is used as the reference: its non-white pixels (grayscale
    value <= 240) define a bounding box, which is grown by ``pad`` pixels on
    each side and clamped to the image borders. The same box is then cut out
    of every image, shrunk (never enlarged) to fit inside ``size`` x ``size``
    while keeping its aspect ratio, and pasted centred onto a white canvas.

    Args:
        imgs: sequence of 3-channel uint8 images. All images are assumed to
            have the same height/width as the last one.
        size: side length in pixels of the square output images.
        pad: margin in pixels added around the detected bounding box.

    Returns:
        A list with one ``(size, size, 3)`` uint8 array per input image.
        An empty input yields an empty list.
    """
    if len(imgs) == 0:
        return []

    # Use the last image to find the bounding box of the non-white area; the
    # resulting crop/scale parameters are shared by all images so that the
    # frames stay aligned with each other.
    reference = imgs[-1]
    img_h, img_w = reference.shape[:2]
    gray = cv2.cvtColor(reference, cv2.COLOR_BGR2GRAY)
    # Inverted threshold: pixels darker than (or equal to) 240 become non-zero.
    _, binary_mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
    coords = cv2.findNonZero(binary_mask)

    if coords is None:
        # Reference is entirely white: there is nothing to crop to, so fall
        # back to the whole frame instead of failing in boundingRect.
        left, top, right, bottom = 0, 0, img_w, img_h
    else:
        x, y, w, h = cv2.boundingRect(coords)
        # Clamp the padded box to the image; a negative start index would be
        # interpreted by NumPy as "count from the end" and select the wrong
        # (usually empty) region.
        left = max(x - pad, 0)
        top = max(y - pad, 0)
        right = min(x + w + pad, img_w)
        bottom = min(y + h + pad, img_h)

    crop_w = right - left
    crop_h = bottom - top

    # Decide the output size once, outside the loop, so every frame is
    # cropped and scaled with identical parameters.
    if crop_w > size or crop_h > size:
        scale = min(size / crop_w, size / crop_h)
        out_w = max(1, int(crop_w * scale))
        out_h = max(1, int(crop_h * scale))
    else:
        out_w, out_h = crop_w, crop_h
    needs_resize = (out_w, out_h) != (crop_w, crop_h)

    # Centre using the final (possibly scaled) dimensions.
    start_x = (size - out_w) // 2
    start_y = (size - out_h) // 2

    new_imgs = []
    for img in imgs:
        roi = img[top:bottom, left:right]
        if needs_resize:
            roi = cv2.resize(roi, (out_w, out_h), interpolation=cv2.INTER_AREA)
        canvas = np.full((size, size, 3), 255, dtype=np.uint8)
        canvas[start_y:start_y + out_h, start_x:start_x + out_w] = roi
        new_imgs.append(canvas)

    return new_imgs


# NOTE(review): HALF_INF, INF, EPS_DIST and EPS_ANGLE are not referenced in the
# visible part of this file; they are presumably consumed by code that imports
# this module -- confirm before changing or removing them.
HALF_INF = 63
INF = 126
EPS_DIST = 1/20
EPS_ANGLE = 2.86
# Turtle.forward() multiplies the requested distance by SCALE before drawing,
# i.e. this is the number of canvas pixels per distance unit.
SCALE = 15

# Animation pacing: Turtle.forward() / Turtle.turn_to() record one frame each
# time MOVE_SPEED pixels of travel or ROTATE_SPEED degrees of rotation have
# accumulated.
MOVE_SPEED = 25
ROTATE_SPEED = 30
# Frame rate passed to the GIF writer in Turtle.save_gif(); the final frame is
# also repeated FPS*2 times there.
FPS = 24

class Turtle:
    """Turtle-graphics pen that draws on a NumPy/OpenCV canvas and records animation frames.

    The canvas is a white 3-channel uint8 image handled in OpenCV's BGR
    channel order. The turtle starts at the canvas centre with heading 0
    (pointing towards +x, i.e. to the right); positive headings rotate
    counter-clockwise on screen because the image y axis points down.

    Motion is paced for animation: every ``MOVE_SPEED`` pixels of travel or
    ``ROTATE_SPEED`` degrees of rotation counts as one frame of time, and a
    snapshot (canvas plus turtle marker) is appended to ``frames`` each time
    a whole frame of time has elapsed.
    """

    def __init__(self, canvas_size=(800, 800)):
        """Create a white canvas of ``canvas_size`` = (width, height) pixels."""
        self.x = canvas_size[0] // 2
        self.y = canvas_size[1] // 2
        # Heading in radians.
        self.heading = 0
        self.canvas = np.ones((canvas_size[1], canvas_size[0], 3), dtype=np.uint8) * 255
        self.is_down = True
        # Fraction of a frame's worth of motion performed since the last
        # recorded frame.
        self.time_since_last_frame = 0
        # The first frame is the empty canvas, without the turtle marker.
        self.frames = [self.canvas.copy()]

    def _advance(self, amount, speed, step):
        """Perform a signed motion in frame-sized chunks, recording frames.

        Shared pacing logic for translation and rotation.

        Args:
            amount: signed total motion (pixels or degrees).
            speed: motion per animation frame, in the same unit as ``amount``.
            step: callable applying a signed partial motion; it is expected to
                add ``abs(partial) / speed`` to ``time_since_last_frame``.
        """
        sign = 1 if amount > 0 else -1
        remaining = abs(amount)

        if self.time_since_last_frame + remaining / speed < 1:
            # The whole motion fits before the next frame boundary.
            step(remaining * sign)
            return

        # Finish the frame that is currently in progress.
        first = (1 - self.time_since_last_frame) * speed
        step(first * sign)
        self.save_frame_with_turtle()
        self.time_since_last_frame = 0
        remaining -= first

        # Then emit one frame per full frame's worth of motion.
        whole_frames = int(remaining / speed)
        for _ in range(whole_frames):
            step(speed * sign)
            self.save_frame_with_turtle()
        remaining -= whole_frames * speed

        # Floating-point rounding can leave a full frame (or a tiny negative
        # amount) here; handle both so the leftover is always in [0, speed).
        if remaining >= speed:
            step(speed * sign)
            self.save_frame_with_turtle()
            remaining -= speed
        remaining = max(remaining, 0.0)

        step(remaining * sign)
        # The step callbacks keep accumulating time; set the exact leftover.
        self.time_since_last_frame = remaining / speed

    def forward(self, dist):
        """Move ``dist`` units along the current heading (negative moves backwards).

        The distance is multiplied by ``SCALE`` to get pixels. A line is drawn
        while the pen is down, and animation frames are recorded as needed.
        """
        self._advance(dist * SCALE, MOVE_SPEED, self.forward_step)

    def forward_step(self, dist):
        """Move ``dist`` pixels in one go, drawing if the pen is down; records no frame."""
        if dist == 0:
            return
        x0, y0 = self.x, self.y
        x1 = x0 + dist * np.cos(self.heading)
        # Image rows grow downwards, hence the minus sign.
        y1 = y0 - dist * np.sin(self.heading)
        if self.is_down:
            start = (int(np.rint(x0)), int(np.rint(y0)))
            end = (int(np.rint(x1)), int(np.rint(y1)))
            cv2.line(self.canvas, start, end, (0, 0, 0), 3)
        # The position is kept as floats so rounding errors do not accumulate.
        self.x, self.y = x1, y1
        self.time_since_last_frame += abs(dist) / MOVE_SPEED

    def _render_frame(self):
        """Return a copy of the canvas with the turtle marker drawn on it.

        The marker is a filled arrowhead (tip, two rear corners and a notch
        behind the centre) in colour (0, 0, 255), i.e. red in BGR order.
        """
        frame = self.canvas.copy()
        triangle_size = 10
        cx, cy = self.x, self.y

        def vertex(radius, angle):
            return (int(np.rint(cx + radius * np.cos(angle))),
                    int(np.rint(cy - radius * np.sin(angle))))

        tip = vertex(triangle_size, self.heading)
        rear_a = vertex(triangle_size, self.heading + 2 * np.pi / 3)
        rear_b = vertex(triangle_size, self.heading - 2 * np.pi / 3)
        notch = vertex(-0.25 * triangle_size, self.heading)
        outline = np.array([tip, rear_a, notch, rear_b], dtype=np.int32)
        cv2.fillPoly(frame, [outline], (0, 0, 255))
        return frame

    def save_frame_with_turtle(self):
        """Append a snapshot of the canvas, with the turtle marker, to ``frames``."""
        self.frames.append(self._render_frame())

    def left(self, angle):
        """Turn counter-clockwise by ``angle`` degrees."""
        self.turn_to(angle)

    def right(self, angle):
        """Turn clockwise by ``angle`` degrees."""
        self.turn_to(-angle)

    def turn_to(self, angle):
        """Rotate by ``angle`` degrees relative to the current heading.

        Despite the name this is a relative turn (positive = counter-clockwise),
        animated at ``ROTATE_SPEED`` degrees per frame.
        """
        self._advance(angle, ROTATE_SPEED, self.turn_to_step)

    def turn_to_step(self, angle):
        """Rotate by ``angle`` degrees in one go; records no frame."""
        self.heading += angle * np.pi / 180
        self.time_since_last_frame += abs(angle) / ROTATE_SPEED

    def penup(self):
        """Lift the pen: subsequent moves do not draw."""
        self.is_down = False

    def pendown(self):
        """Lower the pen: subsequent moves draw a line."""
        self.is_down = True

    def save(self, path):
        """Write the canvas to ``path`` (skipped if ``path`` is falsy) and return the canvas."""
        if path:
            cv2.imwrite(path, self.canvas)
        return self.canvas

    def save_gif(self, path):
        """Encode the recorded animation as a looping GIF and return the result of the write.

        ``path`` is currently ignored (kept for interface compatibility); the
        GIF is encoded to the ``'<bytes>'`` target and the value returned by
        ``imageio.v3.imwrite`` is handed back to the caller.

        A frame showing the current state is always added at the end, because
        motion performed after the last frame boundary (and any pose restored
        by ``_TurtleState``) is otherwise missing from ``frames``. That final
        frame is then held for two seconds. ``self.frames`` is not modified.
        """
        import imageio.v3 as iio

        frames_bgr = list(self.frames)
        frames_bgr.append(self._render_frame())
        frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames_bgr]
        print(f'number of frames: {len(frames_rgb)}')
        # Hold the final picture for 2 seconds before the GIF loops.
        frames_rgb.extend(FPS*2 * [frames_rgb[-1]])

        frames_rgb = crop_and_scaled_imgs(frames_rgb)
        return iio.imwrite('<bytes>', np.stack(frames_rgb), fps=FPS, loop=0, plugin='pillow', format='gif')

    class _TurtleState:
        """Context manager that restores the turtle's pose and pen state on exit.

        Drawing done inside the ``with`` block stays on the canvas; only the
        position, heading and pen up/down status are put back.
        """

        def __init__(self, turtle):
            self.turtle = turtle
            self.position = None
            self.heading = None
            self.pen_status = None

        def __enter__(self):
            self.position = (self.turtle.x, self.turtle.y)
            self.heading = self.turtle.heading
            self.pen_status = self.turtle.is_down
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Jump back without drawing, then restore the saved pen status.
            self.turtle.penup()
            self.turtle.x, self.turtle.y = self.position
            self.turtle.heading = self.heading
            if self.pen_status:
                self.turtle.pendown()

# Demo / manual test script. All helpers below operate on the single shared
# `turtle` instance, so the order in which the plot functions are called
# determines what ends up on the canvas and in `turtle.frames`.
if __name__ == "__main__":
    turtle = Turtle()

    # Thin module-style wrappers so the example programs can call
    # forward()/left()/... without naming the turtle instance.
    def forward(dist):
        turtle.forward(dist)

    def left(angle):
        turtle.left(angle)

    def right(angle):
        turtle.right(angle)

    def penup():
        turtle.penup()

    def pendown():
        turtle.pendown()

    def save(path):
        turtle.save(path)

    def fork_state():
        """
        Return a context manager that restores the turtle's position, heading
        and pen status when the block exits (drawing done inside is kept).

        Usage:
        with fork_state():
            forward(100)
            left(90)
            forward(100)
        """
        return turtle._TurtleState(turtle)

    # Example usage
    def example_plot():
        """Draw a small branching figure using nested fork_state() blocks, save it to test2.png and return the recorded frames."""
        forward(5)

        with fork_state():
            forward(10)
            left(90)
            forward(10)
            with fork_state():
                right(90)
                forward(20)
                left(90)
                forward(10)
            left(90)
            forward(10)

        right(90)
        forward(50)
        save("test2.png")
        return turtle.frames

    def plot2():
        """Draw two squares in sequence; nothing is written to disk and nothing is returned."""
        for j in range(2):
            forward(2)
            left(0.0)
            for i in range(4):
                forward(2)
                left(90)
            forward(0)
            left(180.0)
            forward(2)
            left(180.0)
        # save("") skips the file write and just returns the canvas; the
        # result is not used afterwards.
        FINAL_IMAGE = turtle.save("")

    def plot3():
        """Same drawing as plot2(), but collect a copy of the canvas after each stage and return those snapshots."""
        frames = []
        # np.array(...) copies the canvas, so later drawing does not alter
        # snapshots that were already collected.
        frames.append(np.array(turtle.save("")))
        for j in range(2):
            forward(2)
            frames.append(np.array(turtle.save("")))
            left(0.0)
            for i in range(4):
                forward(2)
                left(90)
                frames.append(np.array(turtle.save("")))
            forward(0)
            left(180.0)
            forward(2)
            left(180.0)
            frames.append(np.array(turtle.save("")))

        return frames

    # Alternative GIF writers, one per library. make_gif, make_gif2 and
    # make_gif5 convert the frames from BGR to RGB first; make_gif3 and
    # make_gif4 pass the frames through unchanged.
    def make_gif(frames, filename):
        """Write `frames` to `filename` with imageio.mimsave at 30 fps."""
        import imageio
        frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
        imageio.mimsave(filename, frames_rgb, fps=30)

    def make_gif2(frames, filename):
        """Write `frames` to `filename` with imageio.v3 (pillow plugin) at 30 fps, printing the frame count."""
        import imageio.v3 as iio
        frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
        print(f'number of frames: {len(frames_rgb)}')
        iio.imwrite(filename, np.stack(frames_rgb), fps=30, plugin='pillow')

    def make_gif3(frames, filename):
        """Write `frames` to `filename` with moviepy's ImageSequenceClip at 20 fps."""
        from moviepy.editor import ImageSequenceClip
        clip = ImageSequenceClip(list(frames), fps=20)
        clip.write_gif(filename, fps=20)

    def make_gif4(frames, filename):
        """Write `frames` to `filename` with array2gif at 20 fps."""
        from array2gif import write_gif
        write_gif(frames, filename, fps=20)

    def make_gif5(frames, filename):
        """Write `frames` to `filename` with Pillow: 100 ms per frame, looping forever."""
        from PIL import Image
        frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
        images = [Image.fromarray(frame) for frame in frames_rgb]
        images[0].save(filename, save_all=True, append_images=images[1:], duration=100, loop=0)



    def plot4():
        """Draw seven recursive trees at random rotations, save the canvas to test3.png and return the recorded frames."""
        # the following program draws a treelike pattern
        import random

        def draw_tree(level, length, angle):
            # Recursive binary tree: draw the trunk, recurse into a left and a
            # right branch (shorter, with a narrower angle), then walk back
            # to where the trunk started.
            if level == 0:
                return
            else:
                forward(length)
                left(angle)
                draw_tree(level-1, length*0.7, angle*0.8)
                right(angle*2)
                draw_tree(level-1, length*0.7, angle*0.8)
                left(angle)
                forward(-length)

        random.seed(0)  # Comment this line to change the randomness
        for _ in range(7):  # Adjust the number to control the density
            draw_tree(5, 5, 30)
            forward(0)
            left(random.randint(0, 360))
        turtle.save("test3.png")
        return turtle.frames

    def plot5():
        """Draw squares of side 2*i for i in 0..6, each from the same starting pose, and return the recorded frames."""
        for i in range(7):
            with fork_state():
                for j in range(4):
                    forward(2*i)
                    left(90.0)
        return turtle.frames


    # make_gif2(plot5(), "test.gif")
    # `frames` is only used by the commented-out experiments below.
    frames = plot5()
    # frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
    # breakpoint()
    # from moviepy.editor import ImageClip, concatenate_videoclips
    # clips = [ImageClip(frame).set_duration(1/24) for frame in frames]
    # concat_clip = concatenate_videoclips(clips, method="compose")
    # concat_clip.write_videofile("test.mp4", fps=24)



    # save_gif() does not use its path argument; it returns the encoded GIF,
    # which is written to disk here.
    img_bytes_string = turtle.save_gif("")
    # turtle.save('test3.png')
    with open("test5.gif", "wb") as f:
        f.write(img_bytes_string)




    # example_plot()
    # plot2()