# Based on liuhaotian/LLaVA-1.6

import sys
import os
import argparse
import time
import subprocess

import gradio as gr
import llava.serve.gradio_web_server as gws

# Install build tooling, then flash-attn, into the running interpreter's
# environment at import time. --no-build-isolation makes pip build flash-attn
# against the packages already installed here instead of an isolated build env.
# Either install failing raises subprocess.CalledProcessError and stops the app.
_PIP_INSTALL = [sys.executable, '-m', 'pip', 'install']
subprocess.check_call(_PIP_INSTALL + ['wheel', 'setuptools'])
subprocess.check_call(_PIP_INSTALL + ['flash-attn', '--no-build-isolation', '-U'])


def start_controller():
    """Spawn the LLaVA controller (``llava.serve.controller``) as a child process.

    The controller listens on host 0.0.0.0, port 10000, using the current
    Python interpreter.

    Returns:
        subprocess.Popen: handle for the spawned controller process.
    """
    print("Starting the controller")
    cmd = [sys.executable, "-m", "llava.serve.controller"]
    cmd += ["--host", "0.0.0.0"]
    cmd += ["--port", "10000"]
    print("Controller Command:", cmd)
    return subprocess.Popen(cmd)


def start_worker(
    model_path: str,
    bits: int = 16,
    controller_url: str = "http://localhost:10000",
    model_base: str = "liuhaotian/llava-1.5-7b",
):
    """Spawn a LLaVA model worker (``llava.serve.model_worker``) as a child process.

    Args:
        model_path: Model repo id or local path passed as ``--model-path``.
        bits: Load precision; must be 16, 8 or 4. For 8 or 4 the worker is
            started with ``--load-8bit`` / ``--load-4bit`` and the advertised
            model name gets a ``-8bit`` / ``-4bit`` suffix.
        controller_url: Controller address passed as ``--controller``.
        model_base: Base model passed as ``--model-base``.

    Returns:
        subprocess.Popen: handle for the spawned worker process.

    Raises:
        ValueError: if ``bits`` is not one of 4, 8 or 16.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")

    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    if bits != 16:
        model_name += f"-{bits}bit"
    # NOTE(review): the "-lora" suffix presumably makes LLaVA's loader treat
    # model_path as LoRA weights to apply on top of --model-base; confirm.
    model_name += "-lora"

    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        controller_url,
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--model-base",
        model_base,
        "--use-flash-attn",
    ]
    if bits != 16:
        # Actually request quantized loading; without this flag the worker
        # loads full-precision weights despite the "-<bits>bit" model name.
        worker_command.append(f"--load-{bits}bit")

    print("Worker Command:", worker_command)
    return subprocess.Popen(worker_command)


def handle_text_prompt(text, temperature=0.2, top_p=0.7, max_new_tokens=512):
    """
    Placeholder handler for the text-prompt API endpoint.

    Logs the prompt and sampling settings to stdout, then returns a string
    that restates them. No model inference is performed yet.

    Args:
        text: The user's prompt.
        temperature: Sampling temperature (default 0.2).
        top_p: Nucleus-sampling cutoff (default 0.7).
        max_new_tokens: Generation length limit (default 512).

    Returns:
        str: an echo of the prompt and its parameters.
    """
    # TODO: Replace the following placeholder with actual model inference code
    print(f"Received prompt: {text}")
    print(f"Parameters - Temperature: {temperature}, Top P: {top_p}, Max New Tokens: {max_new_tokens}")

    return (
        f"Model response to '{text}' with temperature={temperature}, "
        f"top_p={top_p}, max_new_tokens={max_new_tokens}"
    )


def add_text_with_image(text, image, mode):
    """
    Placeholder handler for the text-plus-image API endpoint.

    Logs the text, the image value and the processing mode to stdout, then
    returns a string restating them. No image processing is performed yet.

    Args:
        text: Text to attach.
        image: Image value as delivered by the UI component; it is only
            formatted into strings here. NOTE(review): this is a file path
            only if the wired-up gr.Image is configured to deliver one.
        mode: Image processing mode label (e.g. "Crop", "Resize", "Pad", "Default").

    Returns:
        str: a sentence echoing the three inputs.
    """
    # TODO: Replace the following placeholder with actual processing code
    log_lines = (
        f"Adding text: {text}",
        f"Image path: {image}",
        f"Image processing mode: {mode}",
    )
    for line in log_lines:
        print(line)

    return f"Added text '{text}' with image at '{image}' using mode '{mode}'."


def build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=5):
    """
    Build the AstroLLaVA Gradio Blocks UI and register its named API endpoints.

    Endpoints registered:
      * ``prompt_model``: handle_text_prompt(text, temperature, top_p, max_new_tokens)
      * ``add_text_with_image``: add_text_with_image(text, image, mode)

    Args:
        embed_mode: Currently unused; accepted so existing callers keep working.
        cur_dir: Currently unused; accepted so existing callers keep working.
        concurrency_count: Currently unused; no per-event concurrency limit
            is configured here.

    Returns:
        gr.Blocks: the constructed (not yet queued or launched) demo.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AstroLLaVA")
        gr.Markdown("Welcome to the AstroLLaVA interface. Use the API endpoints to interact with the model.")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Prompt the Model")
                text_input = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...")
                temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Top P")
                max_tokens_slider = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max New Tokens")
                submit_button = gr.Button("Submit Prompt")
            with gr.Column():
                response_output = gr.Textbox(label="Model Response", interactive=False)

        submit_button.click(
            fn=handle_text_prompt,
            inputs=[text_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=response_output,
            api_name="prompt_model",  # Custom API endpoint name
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Add Text with Image")
                add_text_input = gr.Textbox(label="Add Text", placeholder="Enter text to add...")
                # type="filepath" so the handler receives a file path (what it
                # logs and reports as "Image path") rather than the component's
                # default in-memory image representation.
                add_image_input = gr.Image(label="Upload Image", type="filepath")
                image_process_mode = gr.Radio(
                    choices=["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Image Process Mode",
                )
                add_submit_button = gr.Button("Add Text with Image")
            with gr.Column():
                add_output = gr.Textbox(label="Add Text Response", interactive=False)

        add_submit_button.click(
            fn=add_text_with_image,
            inputs=[add_text_input, add_image_input, image_process_mode],
            outputs=add_output,
            api_name="add_text_with_image",  # Another custom API endpoint
        )

        # Additional API endpoints can be added here following the same structure

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="AstroLLaVA Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Hostname to listen on")
    parser.add_argument("--port", type=int, default=7860, help="Port number")
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000", help="Controller URL")
    parser.add_argument("--concurrency-count", type=int, default=5, help="Number of concurrent requests")
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"], help="Model list mode")
    parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")
    parser.add_argument("--moderate", action="store_true", help="Enable moderation")
    parser.add_argument("--embed", action="store_true", help="Enable embed mode")
    args = parser.parse_args()

    # Populate the module globals of llava's gradio_web_server.
    gws.args = args
    gws.models = []

    gws.title_markdown += """ AstroLLaVA """

    print(f"AstroLLaVA arguments: {gws.args}")

    # Environment variables take precedence; --concurrency-count is the fallback
    # (previously the CLI flag was silently ignored).
    model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", args.concurrency_count))

    controller_proc = None
    worker_proc = None
    exit_status = 0
    try:
        # Children are started inside the try so that a failure while starting
        # the worker, or a Ctrl-C during the wait, still reaches the cleanup.
        controller_proc = start_controller()
        worker_proc = start_worker(model_path, bits=bits)

        # Fixed delay, not a readiness check: give the worker time to load and
        # register with the controller before the UI starts serving.
        print("Waiting for worker and controller to start...")
        time.sleep(30)

        # Build the custom Gradio demo with additional API endpoints
        demo = build_custom_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=concurrency_count)
        print("Launching Gradio with custom API endpoints...")
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share
        )

    except Exception as e:
        print(f"An error occurred: {e}")
        exit_status = 1
    finally:
        # Kill and reap whichever children were actually started.
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
                proc.wait()

    # Outside ``finally`` so in-flight exceptions (e.g. KeyboardInterrupt) are
    # not masked by a SystemExit.
    sys.exit(exit_status)