import streamlit as st
from PIL import Image
import torch
from transformers import (
    ViTFeatureExtractor, 
    ViTForImageClassification, 
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from diffusers import StableDiffusionPipeline

# Load models
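# st.cache_resource keeps the loaded models in memory across Streamlit reruns,
# so the downloads and weight loading happen only once.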
@st.cache_resource
def load_models():
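    # ViT image classifiers for age, gender, emotion and action,
    # each paired with the feature extractor it was trained with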
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
    
    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')
    
    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')
    
    # DETR object detector used to label the objects present in the scene
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    
    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    
    # Seq2seq prompt enhancer that expands the short attribute caption into a richer prompt
    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline('text2text-generation',
                               model=prompt_enhancer_model,
                               tokenizer=prompt_enhancer_tokenizer,
                               repetition_penalty=1.2,
                               device="cpu")
    
    # Load BK-SDM-Tiny for image generation; fp16 only works on GPU, so fall back to fp32 on CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=dtype).to(device)
    return (age_model, age_transforms, gender_model, gender_transforms, 
            emotion_model, emotion_transforms, object_detector, 
            action_model, action_transforms, prompt_enhancer, pipe)

models = load_models()
(age_model, age_transforms, gender_model, gender_transforms, 
 emotion_model, emotion_transforms, object_detector, 
 action_model, action_transforms, prompt_enhancer, pipe) = models

def predict(image, model, transforms):
    # Convert the image to RGB format if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    # Apply the feature extractor and run the classifier without tracking gradients
    inputs = transforms(images=[image], return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()

def detect_attributes(image):
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)
    
    objects = object_detector(image)
    
    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': sorted({obj['label'] for obj in objects})  # de-duplicate repeated detections
    }

def generate_prompt(attributes):
    # The age classifier returns a range label (e.g. "20-29"), so describe it as an age range
    prompt = f"A {attributes['gender']} person, aged {attributes['age']}, feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"The image contains {', '.join(attributes['objects'])}. "
    return prompt

def enhance_prompt(prompt):
    prefix = "enhance prompt: "
    enhanced = prompt_enhancer(prefix + prompt, max_length=256)
    return enhanced[0]['generated_text']

@st.cache_data
def generate_image(prompt):
    # Generate image from the prompt using the BK-SDM-Tiny model
    with torch.no_grad():
        image = pipe(prompt, num_inference_steps=50).images[0]
    return image
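
# Note: diffusion is stochastic; passing a seeded torch.Generator via the `generator`
# argument of pipe() would make the cached result reproducible, if that is desired.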

st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze Image'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)
        
        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
        st.image(generated_image, caption='Generated Image', use_column_width=True)
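
# To run the app locally (assuming Streamlit and the model dependencies are installed):
#   streamlit run <this_file>.py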