import gradio as gr
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
import cv2
import soundfile as sf
from PIL import Image
import time

# --- Set up Models ---

# Stable Diffusion for image generation
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    scheduler=scheduler,
    torch_dtype=torch.float16
).to("cuda")
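
# Optional: on GPUs with limited VRAM, attention slicing trades a little speed
# for a lower peak memory footprint (a suggestion, not required for this demo)
# pipe.enable_attention_slicing()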

# LLaVA for vision-based language understanding, loaded through the multimodal
# processor and LlavaForConditionalGeneration (a plain tokenizer cannot accept images)
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
model = LlavaForConditionalGeneration.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers", torch_dtype=torch.float16
).to("cuda")

# Open-source language model for text generation (e.g., GPT-Neo)
gpt_neo_pipe = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B")
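
# Note: the text-generation pipeline defaults to CPU; passing device=0 would
# run it on the first CUDA GPU (optional, assumes a single-GPU machine):
#
#     gpt_neo_pipe = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B", device=0)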

# Text-to-Speech (the ESPnet FastSpeech2 checkpoint is not loadable through the
# transformers pipeline; facebook/mms-tts-eng is a pipeline-compatible swap)
text_to_speech = pipeline("text-to-speech", model="facebook/mms-tts-eng")
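
# A sketch of the pipeline's output format (assuming an MMS/VITS-style model):
# a dict holding the raw waveform and its sampling rate, e.g.
#
#     sample = text_to_speech("Hello there")
#     print(sample["audio"].shape, sample["sampling_rate"])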

# --- Functions ---

def process_image(image, chat_history):
    """Processes a PIL image, sends it to LLaVA, and generates a response."""
    # Prepare LLaVA input: the processor pairs the text prompt (with an <image>
    # placeholder) and the pixel values; check the model card for the exact
    # chat template this checkpoint expects
    prompt = "<image>\nWhat do you see in this image?"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")

    # Generate a response with model.generate (a single forward pass plus an
    # argmax over the logits would not decode a full answer)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128)
    response = processor.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    # Generate speech from the response; the pipeline returns a dict holding
    # the waveform and its sampling rate
    speech = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, speech["audio"].squeeze(), samplerate=speech["sampling_rate"])

    # Update chat history
    chat_history += "You:  Image\n"
    chat_history += "Model: " + response + "\n"

    return chat_history, audio_path

def generate_image(prompt, chat_history):
    """Generates an image using Stable Diffusion based on a prompt."""
    image = pipe(
        prompt=prompt,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]

    # Update chat history
    chat_history += "You: " + prompt + "\n"
    chat_history += "Model:  Image\n"

    return chat_history, image
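
# For reproducible generations, a seeded torch.Generator can be passed through
# to the pipeline (sketch; the seed value 42 is arbitrary):
#
#     generator = torch.Generator(device="cuda").manual_seed(42)
#     image = pipe(prompt=prompt, generator=generator, num_inference_steps=50).images[0]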

def process_text(text, chat_history):
    """Processes text, generates a response using GPT-Neo, and generates speech."""
    # Generate response using GPT-Neo
    response = gpt_neo_pipe(
        text,
        max_length=100,
        num_return_sequences=1,
    )[0]["generated_text"]

    # Generate speech from the response
    speech = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, speech["audio"].squeeze(), samplerate=speech["sampling_rate"])

    # Update chat history
    chat_history += "You: " + text + "\n"
    chat_history += "Model: " + response + "\n"

    return chat_history, audio_path
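
# Note: max_length above counts the prompt tokens too, so a long input can
# squeeze out the reply; max_new_tokens=100 would bound only the continuation.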

# --- Webcam Capture ---

def capture_image():
    """Captures a single frame from the webcam and returns it as a PIL image."""
    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    # OpenCV delivers BGR frames; convert to RGB before building the PIL image
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
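
# Quick local check of the capture path (assumes a webcam at index 0):
#
#     img = capture_image()
#     if img is not None:
#         img.save("captured_image.jpg")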

# --- Gradio Interface ---

with gr.Blocks() as demo:
    gr.Markdown("## Llama-LLaVA Vision Speech Assistant")
    chat_history = gr.Textbox(label="Chat History", lines=10, interactive=False)
    image_input = gr.Image(label="Uploaded Image", type="pil")
    text_input = gr.Textbox(label="Enter Text")
    audio_output = gr.Audio(label="Audio Response")

    # Screenshot button
    screenshot_button = gr.Button("Capture Screenshot")
    screenshot_button.click(fn=capture_image, outputs=image_input)

    # Image processing (LLaVA)
    image_input.change(fn=process_image, inputs=[image_input, chat_history], outputs=[chat_history, audio_output])

    # Text processing (GPT-Neo)
    text_input.submit(fn=process_text, inputs=[text_input, chat_history], outputs=[chat_history, audio_output])

    # Image generation (Stable Diffusion)
    with gr.Tab("Image Generation"):
        image_prompt = gr.Textbox(label="Enter image prompt:")
        image_generation_output = gr.Image(label="Generated Image")
        generate_image_button = gr.Button("Generate Image")
        generate_image_button.click(
            fn=generate_image, inputs=[image_prompt, chat_history], outputs=[chat_history, image_generation_output]
        )

    # Webcam stream
    with gr.Tab("Webcam"):
        webcam_output = gr.Image(label="Webcam Feed", interactive=False)

        # Stream frames into the component: Gradio treats a generator callback
        # as a stream, so each yielded frame refreshes the image
        def update_webcam():
            cap = cv2.VideoCapture(0)
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # Convert OpenCV's BGR frame to RGB for display
                    yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    time.sleep(1)  # Update every second
            finally:
                cap.release()

        demo.load(fn=update_webcam, outputs=webcam_output)

demo.launch(share=True)