import os
import gradio as gr
from typing import Callable, Generator
import base64
from openai import OpenAI

END_POINT = os.environ.get("ENDPOINT")
SECRET_KEY = os.environ.get("SECRETKEY")
USERS = os.environ.get("USERS")
PWD = os.environ.get("PWD")

def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a chat function with the specified model."""
    
    # Instantiate an OpenAI client for a custom endpoint
    try:
        client = OpenAI(
            base_url=END_POINT,
            api_key=SECRET_KEY,  
        )
    except Exception as e:
        print(f"The API or base URL were not defined: {str(e)}")
        raise e  

    def predict(
        messages: list,  
        temperature: float,
        max_tokens: int,
        top_p: float
    ) -> Generator[str, None, None]:
        try:
            # Call the OpenAI API with the formatted messages
            response = client.chat.completions.create(
                model=model_name,  
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True, 
                response_format={"type": "text"},
            )

            response_text = ""
            for chunk in response:
                if len(chunk.choices[0].delta.content) > 0:
                    content = chunk.choices[0].delta.content
                    if content:
                        response_text += content
                        yield response_text.strip()
            
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict

def get_image_base64(url: str, ext: str) -> str:
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:image/{ext};base64,{encoded_string}"

def handle_user_msg(message: str) -> str:
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
                return f"{message.get('text', '')}\n![Image]({encoded_str})"
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
        else:
            return message.get("text", "")
    else:
        raise NotImplementedError("Unsupported message type")

def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x  # No additional postprocessing needed

    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess

def registry(name: str = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio Interface with similar styling and parameters."""
    
    # Retrieving preprocess and postprocess functions
    _, _, preprocess, postprocess = get_interface_args("chat")
    
    # Getting the predict function
    predict_fn = get_fn(model_name=name, **kwargs)
    
    # Defining a wrapper function that integrates preprocessing and postprocessing
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        # Preprocessing the inputs
        preprocessed = preprocess(message, history)
        
        # Extracting the preprocessed messages
        messages = preprocessed["messages"]
        
        # Calling the predict function and generate the response
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p
        )
        
        # Collecting the generated response
        response = ""
        for partial_response in response_generator:
            response = partial_response  # Gradio will handle streaming
            yield response

    # Creating the Gradio ChatInterface with the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt"
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    
    return interface