import streamlit as st
import time
import base64
import io
import zipfile
from PIL import Image
from together import Together
import os
from dotenv import load_dotenv
from pydantic import BaseModel
from openai import OpenAI
import pandas as pd 

load_dotenv()
api_together = os.getenv("TOGETHER_API_KEY")
api_gemini = os.getenv("API_GEMINI")
MODEL = "gemini-2.0-flash-exp"
clientOpenAI = OpenAI(
    api_key=api_gemini,  
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)

LOGO_STYLES = {
    "Minimalist": {
        "nome": "Minimalist",
        "stile_immagine": (
            "A minimalist logo design prompt focusing on simplicity and essential elements. "
            "Clean lines, a limited color palette, and strategic use of negative space. "
            "Ideal for conveying modernity and clarity."
        )
    },
    "Modern": {
        "nome": "Modern",
        "stile_immagine": (
            "A modern logo design prompt featuring abstract forms, innovative color combinations, "
            "and dynamic geometric shapes. Sleek typography and a contemporary aesthetic make it perfect for tech companies and startups."
        )
    },
    "Retro": {
        "nome": "Retro",
        "stile_immagine": (
            "A retro logo design prompt evoking nostalgia with vintage color schemes, classic typography, "
            "and design elements reminiscent of the past. Perfect for brands that want to express heritage and authenticity."
        )
    },
    "Vintage": {
        "nome": "Vintage",
        "stile_immagine": (
            "A vintage logo design prompt inspired by bygone eras. Emphasizes handcrafted details, worn textures, "
            "and a nostalgic atmosphere, ideal for artisanal products or brands with a long-standing tradition."
        )
    },
    "Geometric": {
        "nome": "Geometric",
        "stile_immagine": (
            "A geometric logo design prompt that leverages simple, precise shapes, clean lines, and symmetry. "
            "Communicates order, professionalism, and a rational approach to design."
        )
    },
    "Typographic": {
        "nome": "Typographic",
        "stile_immagine": (
            "A typographic logo design prompt focused on the creative use of lettering. "
            "Bold typography paired with minimal color usage highlights the strength of word-based identities."
        )
    },
}

class Logo(BaseModel):
    nome: str 
    descrizione: str
    english_description:str

class Loghi(BaseModel):
    loghi: list[Logo]

def generate_ai(num_loghi, tema, creativita):
    prompt = (
        f"Genera {num_loghi} prompt per la generazione immagini che trasmetta questo: {tema}" 
        "Sii molto SINTETICO e usa SIMBOLI stilizzati e non troppi oggetti. Restituisci il risultato in formato JSON seguendo lo schema fornito")
    completion = clientOpenAI.beta.chat.completions.parse(
        model=MODEL,
        messages=[
            {"role": "system", "content": f"Sei un assistente utile per la generazione di IDEE per la generazioni immagini su questo tema: {tema}."},
            {"role": "user", "content": prompt},
        ],
        temperature=creativita,
        response_format=Loghi,
    )
    loghi = completion.choices[0].message.parsed
    print(loghi)
    return loghi

# Funzione per generare le immagini, con gestione errori e retry dopo 10 secondi
def generate_image(prompt, max_retries=5):
    client = Together(api_key=api_together)
    retries = 0
    while retries < max_retries:
        try:
            response = client.images.generate(
                prompt=prompt,
                model="black-forest-labs/FLUX.1-schnell-Free",
                width=1024,
                height=1024,
                steps=4,
                n=1,
                response_format="b64_json"
            )
            return response.data  # Una lista di oggetti con attributo b64_json
        except Exception as e:
            print(f"Errore durante la generazione delle immagini: {e}. Riprovo tra 10 secondi...")
            time.sleep(9)
            retries += 1
    st.error("Numero massimo di tentativi raggiunto. Impossibile generare le immagini.")
    return None

def generate_images(logo: Logo, nome_stile, stile_immagine, num_immagini, colori: list[str] = None):
    if logo:
        images_bytes_list = []
        colors_str = " ".join(colori) if colori else ""

        prompt = f"Create a Simple Logo in {nome_stile} STYLE background white, and COLORS '{colors_str}' of '{logo.english_description}'. use this style {stile_immagine}"
        st.subheader(f"{logo.nome} 🖌️")
        st.write(logo.descrizione)
    print(prompt)
    for numero in range(num_immagini):
        images_data = generate_image(prompt)
        if images_data is not None:
            for i, img_obj in enumerate(images_data):
                try:
                    image_bytes = base64.b64decode(img_obj.b64_json)
                    image = Image.open(io.BytesIO(image_bytes))   
                    st.image(image, caption="")
                    img_byte_arr = io.BytesIO()
                    image.save(img_byte_arr, format='PNG')
                    images_bytes_list.append((f"image_{numero+1}_{i+1}.png", img_byte_arr.getvalue()))
                except Exception as e:
                    st.error(f"Errore nella visualizzazione dell'immagine {i+1}: {e}")
        else:
            st.error("Non è stato possibile generare le immagini. Riprova più tardi.")
        time.sleep(0.5)
    return images_bytes_list

def main():
    st.title("Logo Generator AI 🎨")
    st.sidebar.header("Impostazioni")
    selected_stile = st.sidebar.selectbox("Ambientazione", list(LOGO_STYLES.keys()), index=0)
    stile_default = LOGO_STYLES[selected_stile]["stile_immagine"]
    nome_stile = LOGO_STYLES[selected_stile]["nome"]
    stile_immagine = st.sidebar.text_area("Stile Immagine", stile_default, disabled=False)
    
    auto = True
    tema = st.sidebar.text_input("Tema Logo", value="Formazione Aziendale")
    #auto = st.sidebar.toggle(label= 'Generazione automatica', value = True)
    prompt_input = ""
    colori = st.sidebar.multiselect("Colori", ["Blue", "Orange", "Green", "Yellow", "Red", "Purple"], ["Blue", "Orange"])    
    num_loghi = st.sidebar.slider("Loghi", min_value=0, max_value=30, value=10, disabled=not auto)
    num_immagini = st.sidebar.slider("Variazioni", min_value=1, max_value=6, value=2)
    creativita = st.sidebar.slider("Creativita", min_value=0.1, max_value=1.0, value=0.95, step=0.1)
    submit_button = st.sidebar.button(label="Genera Immagine", type="primary", use_container_width=True)
    st.write("Genera il tuo **Logo Aziendale** tramite l'AI")
    
    if submit_button:
        if auto: 
            if num_loghi > 0: 
                with st.spinner('Generazione Loghi'):
                    loghi = generate_ai(num_loghi, tema, creativita)
                    st.subheader('Loghi 💡')
                    df = pd.DataFrame([{k: v for k, v in logo.model_dump().items() if k != ""} for logo in loghi.loghi])
                    st.dataframe(df, hide_index=True, use_container_width=True)                
                    st.divider()
        with st.spinner('Generazione Immagini'):
            images = []
            if loghi: 
                for logo in loghi.loghi:
                    images.extend(generate_images(logo, nome_stile, stile_immagine, num_immagini, colori))
            if images:
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
                    for file_name, file_bytes in images:
                        zip_file.writestr(file_name, file_bytes)
                zip_buffer.seek(0)
                st.download_button(
                    label="Download All Images",
                    data=zip_buffer,
                    file_name="images.zip",
                    mime="application/zip", 
                    type='primary'
                )
            st.success("Immagini generate con successo!")

if __name__ == "__main__":
    st.set_page_config(page_title="Logo Generator AI", page_icon="🎨", layout="wide")
    main()