import gradio as gr
from huggingface_hub import InferenceClient
import langdetect
import json

# Initialize the Hugging Face Inference API client with an OpenAssistant chat model
client = InferenceClient(model="OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
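# Note: if the endpoint requires authentication, InferenceClient picks up a
# token from the HF_TOKEN environment variable, or one can be passed
# explicitly via InferenceClient(..., token="hf_...").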

# Default system message to guide the assistant
default_system_message = (
    "You are NLPToolkit Agent, an advanced assistant specializing in NLP tasks such as text summarization, "
    "sentiment analysis, text classification, and entity recognition. Adapt your responses to the selected task."
)

# Predefined task-specific instructions
task_instructions = {
    "Summarization": "Summarize the text clearly and concisely.",
    "Sentiment Analysis": "Analyze the sentiment of the text (positive, neutral, negative).",
    "Text Classification": "Classify the text into relevant categories.",
    "Entity Recognition": "Identify and list named entities in the text."
}
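
# For example, selecting "Summarization" composes a system message like:
#   "<default_system_message> Task: Summarize the text clearly and concisely."
# (respond() below performs this composition)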


# Text preprocessing: strip whitespace, validate input, and detect language
def preprocess_input(text):
    """
    Clean and validate the user's input text.

    Returns the stripped text on success, or a human-readable error message
    (checked by the caller) when the input is not English or cannot be analyzed.
    """
    try:
        # Detect input language
        language = langdetect.detect(text)
        if language != "en":
            return f"Input language detected as {language}. Please provide input in English."
    except langdetect.lang_detect_exception.LangDetectException:
        return "Unable to detect language. Please provide valid text input."
    except Exception as e:
        return f"An error occurred while processing the text: {str(e)}"

    return text.strip()
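
# Note: langdetect is non-deterministic on short or ambiguous text by default;
# for reproducible detection it can be seeded once at import time:
#
#   from langdetect import DetectorFactory
#   DetectorFactory.seed = 0
#
# Illustrative examples: preprocess_input("  Hello world  ") -> "Hello world";
# French input instead returns the language-warning string handled in respond().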


# Model inference: stream tokens from the chat completion endpoint
def run_model_inference(messages, max_tokens, temperature, top_p):
    """
    Stream the model's response for the given messages, yielding the
    accumulated text after each received chunk.
    """
    try:
        response = ""
        for chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""  # delta.content may be None
            response += token
            yield response
    except Exception as e:
        yield f"Error generating response: {str(e)}"


# Output postprocessing: final cleanup before display
def postprocess_output(response):
    """
    Postprocess the model's response before presenting it to the user.
    """
    # Example: Clean up the response or format it if necessary
    return response.strip()
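
# Assumption: some chat models leak special tokens (for OpenAssistant models,
# e.g. "<|assistant|>" or "<|endoftext|>"); if that shows up in practice, they
# could be stripped here as well -- a minimal sketch:
#
#   for special in ("<|assistant|>", "<|prompter|>", "<|endoftext|>"):
#       response = response.replace(special, "")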


# Respond: tie preprocessing, inference, and postprocessing together
def respond(task, message, history, system_message, max_tokens, temperature, top_p):
    """
    Handle a user message and generate a response using the NLP model.

    Returns a (response, updated_history) pair so the Gradio chat state stays
    in sync with what the model saw.
    """
    # Apply task-specific instructions
    system_message = f"{system_message} Task: {task_instructions.get(task, 'General NLP task')}"

    # Preprocess the user's input; bail out on any preprocessing error message
    message = preprocess_input(message)
    if message.startswith(("Input language detected", "Unable to detect", "An error occurred")):
        return message, history  # Early exit on language or preprocessing issues

    # Prepare conversation history
    messages = [{"role": "system", "content": system_message}]
    for user_message, assistant_message in history:
        if user_message:
            messages.append({"role": "user", "content": user_message})
        if assistant_message:
            messages.append({"role": "assistant", "content": assistant_message})

    messages.append({"role": "user", "content": message})

    # Consume the streamed chunks; each chunk carries the accumulated response
    response = ""
    for chunk in run_model_inference(
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p
    ):
        response = chunk

    # Postprocess the response and append the completed turn to the history
    response = postprocess_output(response)
    return response, history + [(message, response)]
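
# Illustrative call (hypothetical values), showing the (response, history) pair:
#
#   reply, new_history = respond(
#       "Summarization", "Summarize: Gradio makes ML demos easy.", [],
#       default_system_message, max_tokens=256, temperature=0.7, top_p=0.95,
#   )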


# Chat history persistence helpers
def save_history(history, filename="chat_history.json"):
    try:
        with open(filename, "w") as f:
            json.dump(history, f)
        return "Chat history saved successfully."
    except Exception as e:
        return f"Error saving chat history: {str(e)}"


def load_history(filename="chat_history.json"):
    try:
        with open(filename, "r") as f:
            history = json.load(f)
        return history
    except FileNotFoundError:
        return []
    except json.JSONDecodeError:
        return []  # Handle case where the file is empty or corrupt
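
# Note: json round-trips the (user, assistant) tuples as two-element lists,
# which respond() unpacks just as well. For extra robustness the save could be
# made atomic -- a minimal standard-library sketch:
#
#   import os, tempfile
#   fd, tmp = tempfile.mkstemp(dir=os.path.dirname(filename) or ".")
#   with os.fdopen(fd, "w") as f:
#       json.dump(history, f)
#   os.replace(tmp, filename)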


# Build the Gradio interface
def create_interface():
    """
    Create the Gradio Blocks interface for the chatbot.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 🧠 NLPToolkit Agent\nAn advanced assistant for NLP tasks, powered by Hugging Face.")

        # Task selection
        with gr.Row():
            task = gr.Dropdown(
                choices=["Summarization", "Sentiment Analysis", "Text Classification", "Entity Recognition"],
                value="Summarization",
                label="Select NLP Task"
            )

        with gr.Row():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_message = gr.Textbox(value=default_system_message, label="System Message")

        # gr.State is invisible, so it does not need to sit inside a layout row
        chat_history = gr.State(value=[])

        with gr.Row():
            assistant_response = gr.Textbox(label="Assistant Response", interactive=False)

        with gr.Row():
            max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (Nucleus Sampling)")
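
        # Rough guide: lower temperature gives more deterministic output; top_p
        # limits sampling to the smallest token set whose cumulative probability
        # exceeds the threshold (nucleus sampling).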

        with gr.Row():
            save_button = gr.Button("Save Chat History")
            load_button = gr.Button("Load Chat History")

        with gr.Row():
            submit_button = gr.Button("Generate Response")

        # Connect button actions; respond returns (response, updated_history)
        submit_button.click(
            fn=respond,
            inputs=[task, user_input, chat_history, system_message, max_tokens, temperature, top_p],
            outputs=[assistant_response, chat_history]
        )

        # save_history returns a status string; surface it in the response box
        save_button.click(fn=save_history, inputs=chat_history, outputs=assistant_response)
        load_button.click(fn=load_history, inputs=None, outputs=chat_history)

        gr.Markdown("### 🚀 Powered by Hugging Face and Gradio | Developed by Canstralian")

    return demo


# Run the Gradio app
if __name__ == "__main__":
    demo = create_interface()
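    # demo.launch(share=True) would additionally create a temporary public URL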
    demo.launch()