import argparse
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
import html
from itertools import zip_longest
import os
import textwrap
from typing import Dict, List, Tuple
from urllib.parse import quote_plus

from dotenv import load_dotenv
import gradio as gr
from pymongo import MongoClient

from llm_rules import Role, Message, models, scenarios


# MongoDB connection-string template; ``main`` fills in the placeholders.
MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
# Collection handle used for logging turns. While it is None, logging is
# skipped; ``main`` assigns the real collection.
MONGO_DB = None
# Placeholder text for the message input box.
PLACEHOLDER = "Enter message"

# Chatbot display history: one [user, assistant] row per turn (the assistant
# cell may be None when the user turn has no reply yet).
History = List[List[str]]


def parse_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``hf_proxy`` (bool flag, default False) and
        ``port`` (int, default 7860).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--hf_proxy", default=False, action="store_true")
    cli.add_argument("--port", default=7860, type=int)
    options = cli.parse_args()
    return options


@dataclass
class State:
    """Per-session chat state: selected scenario, model, and conversation.

    Two message lists are kept side by side. ``messages`` is what gets sent to
    the model. ``redacted_messages`` is what the chatbot displays; it is seeded
    from the scenario's ``redacted_prompt`` instead of its ``prompt``.
    """

    # Key into ``scenarios.SCENARIOS``.
    scenario_name: str
    # Key into ``models.MODEL_BUILDERS`` (lower-case provider id).
    provider_name: str
    # Model identifier passed to the provider's builder.
    model_name: str
    # Built in __post_init__ from ``scenario_name``.
    scenario: scenarios.scenario.BaseScenario = None
    # Built in __post_init__ from ``provider_name`` / ``model_name``.
    model: models.BaseModel = None
    # Key into ``models.PROMPTS``; used only when not using system instructions.
    system_message: str = None
    # If True, the scenario prompt itself is sent as the system message.
    use_system_instructions: bool = False
    # Conversation as sent to the model.
    messages: List[Message] = field(default_factory=list)
    # Conversation as displayed in the chatbot.
    redacted_messages: List[Message] = field(default_factory=list)
    # Set when a user message has been accepted and still awaits a reply/evaluation.
    last_user_message_valid: bool = False

    def __post_init__(self):
        """Instantiate the scenario and model, then seed both message lists."""
        scenario_cls = scenarios.SCENARIOS[self.scenario_name]
        self.scenario = scenario_cls()
        build_model = models.MODEL_BUILDERS[self.provider_name]
        self.model = build_model(model=self.model_name, stream=True, temperature=0)
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

    def get_initial_messages(self, redacted=False) -> List[Message]:
        """Return the opening messages of a fresh conversation.

        With ``use_system_instructions`` the scenario prompt is the single
        system message. Otherwise the system message is
        ``models.PROMPTS[system_message]``, followed by the scenario prompt as
        a user turn and the scenario's ``initial_response`` as the assistant
        turn.

        Args:
            redacted: use the scenario's ``redacted_prompt`` instead of ``prompt``.
        """
        if redacted:
            prompt = self.scenario.redacted_prompt
        else:
            prompt = self.scenario.prompt
        if self.use_system_instructions:
            return [Message(Role.SYSTEM, prompt)]
        return [
            Message(Role.SYSTEM, models.PROMPTS[self.system_message]),
            Message(Role.USER, prompt),
            Message(Role.ASSISTANT, self.scenario.initial_response),
        ]

    def get_history(self) -> History:
        """Pair the displayed messages into [user, assistant] chatbot rows.

        The leading system message is skipped. Content is HTML-escaped (quotes
        are left as-is). A trailing user message without a reply is paired
        with ``None``.
        """
        turns = self.redacted_messages[1:]
        user_turns = turns[0::2]
        assistant_turns = turns[1::2]
        rows = []
        for user_msg, asst_msg in zip_longest(user_turns, assistant_turns):
            shown_user = html.escape(user_msg.content, quote=False)
            if asst_msg is None:
                shown_asst = None
            else:
                shown_asst = html.escape(asst_msg.content, quote=False)
            rows.append([shown_user, shown_asst])
        return rows

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Append a streamed chunk to the last model message and the last chatbot cell.

        Only ``messages[-1]`` is mutated here. ``redacted_messages[-1]`` reflects
        the change only if both lists hold the same ``Message`` object, which is
        how ``send_assistant_message`` appends the assistant reply. The chatbot
        copy of the chunk is HTML-escaped.
        """
        last_message = self.messages[-1]
        last_message.content += delta
        last_row = history[-1]
        last_row[-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the input textbox label, prefixed with the scenario's format hint if any."""
        base = "Return to send message. Shift + Return to add a new line."
        prefix = self.scenario.format_message
        if not prefix:
            return base
        return prefix + " " + base

    def unescape_messages(self) -> List[Message]:
        """Return new ``Message`` copies of ``messages`` with HTML entities unescaped."""
        unescaped = []
        for msg in self.messages:
            unescaped.append(Message(msg.role, html.unescape(msg.content)))
        return unescaped


def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Switch to another provider and select that provider's default model.

    Args:
        state: session state; its provider, model name and model are replaced.
        provider_name: display name from the dropdown; lower-cased to get the key.

    Returns:
        ``(state, model_dropdown_update)``. The update lists the new provider's
        models with the default one selected.
    """
    provider = provider_name.lower()
    state.provider_name = provider
    default_model = models.MODEL_DEFAULTS[provider]
    state.model_name = default_model
    builder = models.MODEL_BUILDERS[provider]
    state.model = builder(model=default_model, stream=True, temperature=0)
    model_dropdown = gr.update(
        choices=models.MODEL_NAMES_BY_PROVIDER[provider],
        value=default_model,
    )
    return state, model_dropdown


def change_model(state: State, model_name: str) -> State:
    """Select ``model_name`` under the current provider and rebuild ``state.model``."""
    state.model_name = model_name
    builder = models.MODEL_BUILDERS[state.provider_name]
    state.model = builder(model=model_name, stream=True, temperature=0)
    return state


def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Switch to a new scenario and relabel the input textbox.

    Only the scenario object and name are replaced. The conversation itself is
    left untouched here.

    Returns:
        ``(state, textbox_update)``. The label shows the new scenario's
        format hint.
    """
    new_scenario = scenarios.SCENARIOS[scenario]()
    state.scenario = new_scenario
    state.scenario_name = scenario
    label = state.get_info()
    return state, gr.update(placeholder=PLACEHOLDER, label=label)


def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Validate the user's input and, if accepted, add it to the conversation.

    Args:
        state: session state.
        input: raw textbox contents.

    Returns:
        ``(state, history, textbox_update)``.

        - Valid message: it is appended to both ``messages`` and
          ``redacted_messages``, the pending-turn flag is set, and the textbox
          is cleared.
        - Invalid message: a warning with the scenario's format hint is shown,
          and the textbox is left unchanged so the user can correct it.
    """
    user_msg = Message(Role.USER, input)
    if not state.scenario.is_valid_user_message(user_msg):
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        update = gr.update()
    else:
        # The same Message object goes into both lists; user text is not redacted.
        state.messages.append(user_msg)
        state.redacted_messages.append(user_msg)
        state.last_user_message_valid = True
        update = gr.update(placeholder=PLACEHOLDER, value="")
    return state, state.get_history(), update


def send_assistant_message(state: State, api_key: str) -> Tuple[State, History]:
    """Stream the model's reply into the session state and the chatbot.

    Always yields the current history first, so the chatbot immediately shows
    the user's message. If no valid user message is pending, stops there.
    Otherwise it calls the model with the full (unredacted) conversation and
    yields after each streamed chunk.

    Args:
        state: session state.
        api_key: key typed by the user. An empty string is passed to the
            model as ``None``.

    Raises:
        gr.Error: if the model call fails, either when the request is made or
            while iterating the streamed response.
    """
    history = state.get_history()
    yield state, history

    if not state.last_user_message_valid:
        return

    api_key = None if api_key == "" else api_key
    try:
        response = state.model(state.messages, api_key=api_key)
    except Exception as e:
        raise gr.Error(f"API error: {e} Please reset the scenario and try again.") from e

    asst_msg = Message(Role.ASSISTANT, "")
    # Append the same object to both lists so streamed text shows up in the
    # displayed (redacted) conversation too.
    state.messages.append(asst_msg)
    state.redacted_messages.append(asst_msg)
    history = state.get_history()

    try:
        for delta in response:
            history = state.update_state_and_history(history, delta)
            yield state, history
    except Exception as e:
        # Errors can also surface while iterating a streamed response; report
        # them the same way as failures at call time.
        raise gr.Error(f"API error: {e} Please reset the scenario and try again.") from e


def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate the latest exchange, log it, and update the input textbox.

    Does nothing (empty update) unless a valid user message is pending.
    Otherwise:

    1. The scenario evaluates the HTML-unescaped conversation, and the pending
       flag is cleared.
    2. If ``MONGO_DB`` is set, the turn is inserted as one document. A database
       failure is shown as a warning and does not abort the evaluation.
    3. A failing result disables the textbox and shows the reason; a passing
       result re-enables it.

    Returns:
        ``(state, textbox_update)``.
    """
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    # MONGO_DB is only read here, so no ``global`` declaration is needed.
    if MONGO_DB is not None:
        doc = {
            # Timezone-aware UTC: PyMongo stores naive datetimes as if they were
            # UTC, so a naive local time would be mis-recorded on non-UTC hosts.
            "timestamp": datetime.now(timezone.utc),
            "scenario": state.scenario_name,
            "params": asdict(state.scenario.p),
            "provider": state.provider_name,
            "model": state.model_name,
            "system_instructions": state.use_system_instructions,
            "messages": Message.serialize(state.messages),
            "result": asdict(result),
        }

        try:
            MONGO_DB.insert_one(doc)
        except Exception as e:
            # Logging is best-effort: notify the user but still return the
            # evaluation outcome. (A bare ``gr.Error(...)`` without ``raise``
            # would display nothing.)
            gr.Warning(f"Database failed: {e}")

    if not result.passed:
        gr.Info("Rule broken!")
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason} Please reset the scenario or select another.",
        )
    return state, gr.update(
        interactive=True,
        placeholder=PLACEHOLDER,
    )


def reset_scenario(state: State) -> Tuple[State, History, dict]:
    """Restart the conversation for the current scenario and model.

    Rebuilds both message lists from the scenario's opening messages and
    clears the pending-turn flag.

    Returns:
        ``(state, history, textbox_update)``. The update clears and
        re-enables the input textbox.
    """
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    # A fresh conversation has no pending user turn. If this flag stayed set,
    # an in-flight chain could answer or evaluate the reset conversation.
    state.last_user_message_valid = False
    history = state.get_history()
    update = gr.update(placeholder=PLACEHOLDER, value="", interactive=True)
    return state, history, update


def make_block(initial_state):
    """Build the Gradio UI for the RuLES demo.

    Args:
        initial_state: ``State`` used as the per-session default value. Its
            provider, model and scenario also set the dropdowns' initial
            selections, so the UI and the state agree on first load.

    Returns:
        The configured ``gr.Blocks`` app (not yet queued or launched).
    """
    markdown = (
        "<h1>RuLES: Rule-following Language Evaluation Scenarios</h1>"
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. Please do"
        " not enter any sensitive information.</p>"
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, which you can verify "
        "in the demo's source code.</p>"
        "<p>See the RuLES <a href=\"https://github.com/normster/llm_rules\">github repo</a> for more information.</p>"
    )

    # The provider dropdown shows display names; ``change_provider`` lower-cases
    # the selection to get the provider key, so map keys back to display names.
    provider_choices = ["Anthropic", "OpenAI", "Google"]
    provider_display = {name.lower(): name for name in provider_choices}

    with gr.Blocks(
        gr.themes.Monochrome(
            font=[
                gr.themes.GoogleFont("Source Sans Pro"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ],
            radius_size=gr.themes.sizes.radius_sm,
        )
    ) as block:
        gr.Markdown(markdown, sanitize_html=False)
        state = gr.State(value=initial_state)
        with gr.Row():
            provider_select = gr.Dropdown(
                provider_choices,
                value=provider_display[initial_state.provider_name],
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER[initial_state.provider_name],
                value=initial_state.model_name,
                label="Model",
            )
            scenario_select = gr.Dropdown(
                list(scenarios.SCENARIOS),
                value=initial_state.scenario_name,
                label="Scenario",
            )
        # Password field so the user's key is not displayed on screen.
        apikey = gr.Textbox(placeholder="sk-...", label="API Key", type="password")
        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")

        # Event listeners
        # Submit: record user message -> stream model reply -> evaluate and log.
        textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        ).then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        ).then(
            evaluate_and_log, state, [state, textbox], queue=True
        )
        # Change to default model for new provider when provider is changed
        provider_select.change(
            change_provider,
            [state, provider_select],
            [state, model_select],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        # Change to specified model
        model_select.change(
            change_model,
            [state, model_select],
            [state],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        # Change to specified scenario
        scenario_select.change(
            change_scenario,
            [state, scenario_select],
            [state, textbox],
            queue=False,
        ).then(reset_scenario, state, [state, chatbot, textbox], queue=False)
        # Reset scenario state, chat history, and input textbox
        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)

    return block


def main(args):
    """Connect the logging database, build the UI, and launch the server.

    Loads a ``.env`` file, then reads ``MONGO_USERNAME``, ``MONGO_PASSWORD``
    and ``MONGO_HOST`` from the environment to set the module-level
    ``MONGO_DB`` collection.

    Args:
        args: namespace from ``parse_args`` (``port``, ``hf_proxy``).
    """
    load_dotenv()

    # State.__post_init__ already seeds ``messages`` and ``redacted_messages``
    # as lists. Re-assigning them here with a trailing comma had wrapped each
    # list in a 1-tuple, which broke ``.append`` until the first reset.
    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )

    # Comment this out to disable logging
    global MONGO_DB
    # Credentials must be percent-escaped. Characters such as '@', ':' or '/'
    # would otherwise corrupt the connection URI.
    mongo_uri = MONGO_URI.format(
        username=quote_plus(os.environ["MONGO_USERNAME"]),
        password=quote_plus(os.environ["MONGO_PASSWORD"]),
        host=os.environ["MONGO_HOST"],
    )
    client = MongoClient(mongo_uri)
    MONGO_DB = client["messages"]["v1.0"]

    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )


if __name__ == "__main__":
    # Script entry point: parse CLI options and hand them straight to main().
    main(parse_args())