import os
import pathlib
import shutil
import subprocess
import sys
import zipfile

import gradio as gr
import pandas as pd
import PIL.Image
import tensorflow as tf

def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Return *image* centred on a black square canvas.

    Square inputs are returned unchanged (the same object). Otherwise the
    shorter side is padded on both sides up to the length of the longer side;
    when the difference is odd, the extra pixel of padding goes to the
    bottom/right.

    Palette images ('P' / 'PA') are converted to RGB (or RGBA when they carry
    transparency) before padding, because ``paste`` copies only palette indices
    and the new canvas would not have the source palette. Every other mode is
    preserved in the returned image.
    """
    width, height = image.size
    if width == height:
        return image

    if image.mode in ('P', 'PA'):
        keep_alpha = image.mode == 'PA' or 'transparency' in image.info
        image = image.convert('RGBA' if keep_alpha else 'RGB')

    # The fill value must match the number of bands: Pillow rejects a 3-tuple
    # for single-band ('L', '1', 'I', 'F') and two-band ('LA') images.
    bands = len(image.getbands())
    if bands == 1:
        fill = 0
    elif bands == 2:
        fill = (0, 255)  # black, fully opaque
    else:
        fill = (0, 0, 0)

    side = max(width, height)
    padded = PIL.Image.new(image.mode, (side, side), fill)
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded


class ModelTrainer:
    def __init__(self):
        self.training_pictures = []
        self.training_model = None

    def unzip_file(self, zip_file_path):
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            extracted_path = zip_file_path.replace('.zip', '')
            zip_ref.extractall(extracted_path)
            file_names = zip_ref.namelist()
            for file_name in file_names:
                if file_name.endswith(('.jpeg', '.jpg', '.png')):
                    self.training_pictures.append(f'{extracted_path}/{file_name}')

    def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
        output_model_name = 'a-xyz-model'
        resolution = 512
        repo_dir = pathlib.Path(__file__).parent
        subdirs = ['train-instance', 'train-class', 'experiments']
        dir_paths = []

        for subdir in subdirs:
            dir_path = repo_dir / subdir / output_model_name
            dir_paths.append(dir_path)
            shutil.rmtree(dir_path, ignore_errors=True)
            os.makedirs(dir_path, exist_ok=True)

        instance_data_dir, class_data_dir, output_dir = dir_paths

        for i, temp_path in enumerate(instance_images):
            image = PIL.Image.open(temp_path.name)
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            image = image.convert('RGB')
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

        command = [
            'python', '-u',
            'train_dreambooth_cloneofsimo_lora.py',
            '--pretrained_model_name_or_path', pretrained_model_name_or_path,
            '--instance_data_dir', instance_data_dir,
            '--class_data_dir', class_data_dir,
            '--resolution', '768',
            '--output_dir', output_dir,
            '--instance_prompt', 'a photo of a pwsm dog',
            '--with_prior_preservation',
            '--class_prompt', 'a dog',
            '--prior_loss_weight', '1.0',
            '--num_class_images', '100',
            '--learning_rate', '0.0004',
            '--train_batch_size', '1',
            '--sample_batch_size', '1',
            '--max_train_steps', '400',
            '--gradient_accumulation_steps', '1',
            '--gradient_checkpointing',
            '--train_text_encoder',
            '--learning_rate_text', '5e-6',
            '--save_steps', '100',
            '--seed', '1337',
            '--lr_scheduler', 'constant',
            '--lr_warmup_steps', '0'
        ]

        result = subprocess.run(command)
        return result

    def generate_picture(self, row):
        num_of_training_steps, learning_rate, checkpoint_steps, abc = row
        return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'

    def generate_pictures(self, csv_input):
        csv = pd.read_csv(csv_input.name)
        result = []
        for index, row in csv.iterrows():
            result.append(self.generate_picture(row))
        return "\n".join(str(item) for item in result)

loader = ModelTrainer()


def _train_and_report(model_name, images):
    """Gradio callback for the Train button; returns a Markdown status message.

    ``loader.train`` returns a ``subprocess.CompletedProcess``, but the
    ``gr.Markdown`` output it is wired to expects text, so the result is
    summarised here. Missing inputs are reported instead of starting a run that
    cannot succeed.
    """
    model_name = (model_name or '').strip()
    if not images:
        return 'Please upload at least one instance image.'
    if not model_name:
        return 'Please enter `pretrained_model_name_or_path` (e.g. `stabilityai/stable-diffusion-2-1`).'
    result = loader.train(model_name, images)
    if result.returncode == 0:
        return 'Training finished successfully.'
    return f'Training failed with exit code {result.returncode}.'


with gr.Blocks() as demo:
    with gr.Box():
        instance_images = gr.Files(label='Instance images')
        pretrained_model_name_or_path = gr.Textbox(lines=1, label='pretrained_model_name_or_path', placeholder='stabilityai/stable-diffusion-2-1')
        output_message = gr.Markdown()
        train_button = gr.Button('Train')
        train_button.click(fn=_train_and_report, inputs=[pretrained_model_name_or_path, instance_images], outputs=[output_message])
    with gr.Box():
        csv_input = gr.File(label='CSV File')
        output_message2 = gr.Markdown()
        generate_button = gr.Button('Generate Pictures from CSV')
        generate_button.click(fn=loader.generate_pictures, inputs=[csv_input], outputs=[output_message2])

demo.launch()