import json
import os
import traceback
from typing import List, Tuple

import gradio as gr
import requests
from huggingface_hub import HfApi

# Hub client used for the one-off dataset listing performed at import time.
hf_api = HfApi()

# Map each dataset's short name (repo id with everything up to the last "/"
# removed) to the listing entry returned by the Hub for the "bigscience-data"
# author. The token is read from the environment and may be None.
roots_datasets = {
    dataset_info.id.rsplit("/", 1)[-1]: dataset_info
    for dataset_info in hf_api.list_datasets(
        author="bigscience-data",
        use_auth_token=os.environ.get("bigscience_data_token"),
    )
}


def get_docid_html(docid):
    """Render a document id as HTML: a dataset link followed by the id suffix.

    Args:
        docid: Identifier of the form ``"<org>/<dataset>/<id>"``. It must split
            into exactly three ``/``-separated parts.

    Returns:
        An HTML fragment. Private datasets get a grey, padlocked link; other
        datasets get a coloured link whose tooltip names the license.

    Raises:
        ValueError: If ``docid`` does not have exactly three parts.
        KeyError: If the dataset name is not a key of ``roots_datasets``.
    """
    _, dataset, doc_suffix = docid.split("/")
    metadata = roots_datasets[dataset]
    locked_color = "LightGray"
    open_color = "#7978FF"

    if metadata.private:
        return """
        <a title="This dataset is private. See the introductory text for more information"
            style="color:{locked_color}; font-weight: bold; text-decoration:none"
            onmouseover="style='color:{locked_color}; font-weight: bold; text-decoration:underline'"
            onmouseout="style='color:{locked_color}; font-weight: bold; text-decoration:none'"
            href="https://huggingface.co/datasets/bigscience-data/{dataset}"
            target="_blank">
            🔒{dataset}
        </a>
        <span style="color:{open_color}; ">/{docid}</span>""".format(
            dataset=dataset,
            docid=doc_suffix,
            locked_color=locked_color,
            open_color=open_color,
        )

    # Prefer an explicit "license:" tag. Fall back to the first tag (the
    # previous behaviour) and finally to a placeholder, so a dataset with no
    # tags no longer raises while the result page is being rendered.
    tags = metadata.tags or []
    license_tag = next(
        (tag for tag in tags if tag.startswith("license:")),
        tags[0] if tags else "unknown",
    )
    return """
        <a title="This dataset is licensed {metadata}"
            style="color:{open_color}; font-weight: bold; text-decoration:none"
            onmouseover="style='color:{open_color}; font-weight: bold; text-decoration:underline'"
            onmouseout="style='color:{open_color}; font-weight: bold; text-decoration:none'"
            href="https://huggingface.co/datasets/bigscience-data/{dataset}"
            target="_blank">
            {dataset}
        </a>
        <span style="color:{open_color}; ">/{docid}</span>""".format(
        metadata=license_tag.split(":")[-1],
        dataset=dataset,
        docid=doc_suffix,
        open_color=open_color,
    )


# Tag names that, when prefixed with PII_PREFIX, mark a redacted span in the
# indexed text.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every ``PI:<TAG>`` marker in *text* with a highlighted badge.

    Each known tag is substituted by a bold ``<mark>`` element reading
    ``REDACTED <TAG>``; text without markers is returned unchanged.
    """
    redacted = text
    for pii_kind in PII_TAGS:
        marker = PII_PREFIX + pii_kind
        badge = f"""<b><mark style="background: Fuchsia; color: Lime;">REDACTED {pii_kind}</mark></b>"""
        redacted = redacted.replace(marker, badge)
    return redacted


def extract_lang_from_docid(docid):
    """Return the second ``_``-separated field of *docid*.

    Raises ``IndexError`` when *docid* contains no underscore.
    """
    # Splitting at most twice still yields the same second field.
    fields = docid.split("_", 2)
    return fields[1]


def format_result(result, highlight_terms, exact_search, datasets_filter=None):
    """Render a single search hit as an HTML paragraph.

    Args:
        result: ``(text, url, docid)`` tuple; ``url`` may be ``None``.
        highlight_terms: In exact search, the query string to embolden at its
            first occurrence. Otherwise a collection checked with ``in`` for
            every whitespace-separated token of the text.
        exact_search: Selects between the two highlighting modes above.
        datasets_filter: Optional collection of dataset names. A hit whose
            dataset (second ``/``-separated part of ``docid``) is not in it
            is rendered as the empty string.

    Returns:
        The HTML for the hit wrapped in ``<p>...</p>``, or ``""`` when the hit
        is filtered out.
    """
    text, url, docid = result
    if datasets_filter is not None:
        dataset = docid.split("/")[1]
        if dataset not in datasets_filter:
            return ""

    if exact_search:
        query_start = text.find(highlight_terms)
        if query_start == -1:
            # str.find reports "not found" as -1. Using that as a slice index
            # would drop the last character of the text and embolden an
            # unrelated span, so render the text without highlighting.
            tokens_html = text
        else:
            query_end = query_start + len(highlight_terms)
            tokens_html = "{}<b>{}</b>{}".format(
                text[:query_start], text[query_start:query_end], text[query_end:]
            )
    else:
        # Note: splitting and re-joining collapses runs of whitespace.
        tokens_html = " ".join(
            "<b>{}</b>".format(token) if token in highlight_terms else token
            for token in text.split()
        )
    # NOTE(review): the text is inserted into the page without HTML escaping.
    tokens_html = process_pii(tokens_html)

    if url is None:
        url_html = ""
    else:
        url_html = """
        <span style='font-size:12px; font-family: Arial; color:Silver; text-align: left;'>
            <a style='text-decoration:none; color:Silver;'
                onmouseover="style='text-decoration:underline; color:Silver;'"
                onmouseout="style='text-decoration:none; color:Silver;'"
                href='{url}'
                target="_blank">
                {url}
            </a>
        </span><br>
    """.format(
            url=url
        )

    docid_html = get_docid_html(docid)
    language = extract_lang_from_docid(docid)
    result_html = """{}
        <span style='font-size:14px; font-family: Arial; color:MediumAquaMarine'>Language: {} | </span>
        <span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {} | </span>
        <a href="https://forms.gle/AdBLLwRApqcLkHYA8" target="_blank">
            <button style="color:#ffcdf8; ">🏴‍☠️ Flag result 🏴‍☠️</button>
        </a><br>
        <span style='font-family: Arial;'>{}</span><br>
        <br>
    """.format(
        url_html, language, docid_html, tokens_html
    )
    return "<p>" + result_html + "</p>"


def format_result_page(
    language, results, highlight_terms, num_results, exact_search, datasets_filter=None
) -> str:
    """Build the HTML for a whole result page.

    Args:
        language: Language chosen in the UI. ``"detect_language"`` adds a
            "Detected language" header and ``"all"`` wraps each language's
            hits in a collapsible ``<details>`` element; both only apply to
            non-exact searches.
        results: Mapping from language key to a list of ``(text, url, docid)``
            hits.
        highlight_terms: Forwarded to ``format_result``.
        num_results: Total match count shown in the header, or ``None`` to
            omit that header.
        exact_search: Whether the page is for an exact-match search.
        datasets_filter: Optional dataset names forwarded to ``format_result``.

    Returns:
        The header HTML followed by the per-language result HTML, as a string.
    """
    header_parts = []
    if language == "detect_language" and not exact_search:
        # The first key is reported as the detected language. An empty mapping
        # has none, so the header is skipped instead of raising IndexError.
        detected = next(iter(results), None)
        if detected is not None:
            header_parts.append(
                """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
            Detected language: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                    detected
                )
            )

    body_parts = []
    for lang, results_for_lang in results.items():
        print("Processing language", lang)
        if not results_for_lang:
            if exact_search:
                body_parts.append(
                    """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
                    No results found.</div>"""
                )
            else:
                body_parts.append(
                    """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
                    No results for language: <b>{}</b></div>""".format(
                        lang
                    )
                )
            continue

        # Hits excluded by the dataset filter render as "" and simply vanish.
        results_for_lang_html = "".join(
            format_result(result, highlight_terms, exact_search, datasets_filter)
            for result in results_for_lang
        )
        if language == "all" and not exact_search:
            results_for_lang_html = f"""
                <details>
                    <summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
                        Results for language: <b>{lang}</b>
                    </summary>
                    {results_for_lang_html}
                </details>"""
        body_parts.append(results_for_lang_html)

    if num_results is not None:
        header_parts.append(
            """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
            Total number of matches: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                num_results
            )
        )
    return "".join(header_parts) + "".join(body_parts)


def extract_results_from_payload(query, language, payload, exact_search):
    """Normalise a backend payload into the structures the UI works with.

    Args:
        query: The user's query; used as the highlight string in exact search.
        language: Unused here; kept so existing callers keep working.
        payload: Decoded backend response. Reads ``payload["results"]`` plus
            ``payload["num_results"]`` (exact search) or
            ``payload["highlight_terms"]`` (otherwise).
        exact_search: Whether the payload answers an exact-match search, in
            which case its flat result list is filed under the key ``"dummy"``.

    Returns:
        ``(processed_results, highlight_terms, num_results, datasets)`` where
        ``processed_results`` maps language key to a list of
        ``(text, url, docid)`` tuples, ``num_results`` is ``None`` outside
        exact search, and ``datasets`` is the sorted list of dataset names
        found in the doc ids.

    Raises:
        ValueError: If a ``docid`` does not have exactly three ``/`` parts.
    """
    results = payload["results"]
    num_results = None

    if exact_search:
        highlight_terms = query
        num_results = payload["num_results"]
        results = {"dummy": results}
    else:
        highlight_terms = payload["highlight_terms"]

    processed_results = {}
    datasets = set()
    for lang, results_for_lang in results.items():
        hits = []
        for result in results_for_lang:
            # "meta" may be missing or None; either way there is no url.
            meta = result.get("meta") or {}
            url = meta.get("url")
            docid = result["docid"]
            _, dataset, _ = docid.split("/")
            datasets.add(dataset)
            hits.append((result["text"], url, docid))
        processed_results[lang] = hits

    # Sorted so the dataset filter lists its choices in a stable order; plain
    # set iteration order can change between runs.
    return processed_results, highlight_terms, num_results, sorted(datasets)


def no_query_error_message():
    """Return the HTML notice shown when the submitted query is empty."""
    notice = """
        <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
        Please provide a non-empty query.
        </p><br><hr><br>"""
    return notice


def process_error(error_type, payload):
    """Turn a backend error into an HTML message for the results pane.

    Args:
        error_type: The value of ``payload["err"]["type"]``.
        payload: The decoded backend response carrying the error details.

    Returns:
        An HTML string. ``"unsupported_lang"`` gets a dedicated message naming
        the detected language; any other type gets a generic message instead
        of ``None``, so the user always sees that the search failed.
    """
    import html

    if error_type == "unsupported_lang":
        detected_lang = payload["err"]["meta"]["detected_lang"]
        return f"""
            <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
            Detected language <b>{detected_lang}</b> is not supported.<br>
            Please choose a language from the dropdown or type another query.
            </p><br><hr><br>"""

    # The error type comes from the backend response, so escape it before
    # placing it into the page.
    safe_error_type = html.escape(str(error_type))
    return f"""
            <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
            The search backend reported an error: <b>{safe_error_type}</b>.<br>
            Please try again or type another query.
            </p><br><hr><br>"""


def extract_error_from_payload(payload):
    """Return ``payload["err"]["type"]`` when the payload has an ``"err"`` entry, else ``None``."""
    return payload["err"]["type"] if "err" in payload else None


def request_payload(query, language, exact_search, num_results=10, received_results=0):
    """POST a search request to the backend and return the decoded JSON reply.

    The request body carries ``query``, ``k`` (``num_results``) and
    ``received_results``; ``lang`` is added unless *language* is
    ``"detect_language"``. Exact searches go to a fixed address, all others to
    the address read from the ``address`` environment variable.
    """
    body = {"query": query, "k": num_results, "received_results": received_results}
    if language != "detect_language":
        body["lang"] = language

    if exact_search:
        endpoint = "http://34.105.160.81:8080"
    else:
        endpoint = os.environ.get("address")

    response = requests.post(
        endpoint,
        data=json.dumps(body),
        headers={"Content-type": "application/json"},
        timeout=60,
    )
    return json.loads(response.text)


# Page heading, rendered through gr.Markdown as raw HTML.
title = """<p style="text-align: center; font-size:28px"> 🌸 🔎 ROOTS search tool 🔍 🌸 </p>"""

# Introductory Markdown shown under the heading.
description = """
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows
you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in
ROOTS. You can read more about the details of the tool design
[here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more
information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""


# Build the Gradio UI only when the file is executed as a script.
if __name__ == "__main__":
    demo = gr.Blocks(css=".underline-on-hover:hover { text-decoration: underline; }")

    with demo:
        # State carried between events; these components are the first
        # outputs of the `submit` / `next_page` handlers wired further down.
        # NOTE(review): processed_results_state starts as a list, although the
        # handlers store the dict produced by extract_results_from_payload in it.
        processed_results_state = gr.State([])
        highlight_terms_state = gr.State([])
        num_results_state = gr.State(0)
        exact_search_state = gr.State(False)
        received_results_state = gr.State(0)

        # Static heading and introduction.
        with gr.Row():
            gr.Markdown(value=title)
        with gr.Row():
            gr.Markdown(value=description)
        # Single-line query box.
        with gr.Row():
            query = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Put your query in double quotes for exact search.",
                label="Query",
            )
        # Language selector and result-count slider share one row.
        with gr.Row():
            lang = gr.Dropdown(
                choices=[
                    "ar",
                    "ca",
                    "code",
                    "en",
                    "es",
                    "eu",
                    "fr",
                    "id",
                    "indic",
                    "nigercongo",
                    "pt",
                    "vi",
                    "zh",
                    "detect_language",
                    "all",
                ],
                value="en",
                label="Language",
            )
            k = gr.Slider(
                1,
                100,
                value=10,
                step=1,
                label="Max Results in fuzzy search or Max Results per page in exact search",
            )
        with gr.Row():
            submit_btn = gr.Button("Submit")
        # Hidden at start; the handlers toggle this row's visibility depending
        # on whether the current page has any results.
        with gr.Row(visible=False) as datasets_filter:
            available_datasets = gr.Dropdown(
                type="value",
                choices=[],
                value=[],
                label="Datasets Filter",
                multiselect=True,
            )
        with gr.Row():
            result_page_html = gr.HTML(label="Results")

        # Hidden at start; the handlers show it when more exact-search results
        # remain to be fetched.
        with gr.Row(visible=False) as pagination:
            next_page_btn = gr.Button("Next Page")

        def run_query(query, lang, k, dropdown_input, received_results):
            """Run one search and return everything needed to render it.

            A query wrapped in double quotes triggers exact search on the text
            between the quotes; any other query has its whitespace collapsed.

            Args:
                query: Raw text from the query box (``None`` counts as empty).
                lang: Selected language option.
                k: Requested number of results.
                dropdown_input: Current dataset filter value; not used here.
                received_results: Number of results already fetched, forwarded
                    to the backend.

            Returns:
                ``(processed_results, highlight_terms, num_results,
                exact_search, result_page, datasets)``. For an empty query or
                a backend error, ``processed_results`` is an empty dict,
                ``result_page`` holds the error message and the other values
                are ``[]``, ``0``, ``False`` and ``[]``.
            """
            # Guard against None before calling string methods on the query.
            query = (query or "").strip()
            exact_search = (
                len(query) >= 2 and query.startswith('"') and query.endswith('"')
            )
            if exact_search:
                query = query[1:-1]
            else:
                query = " ".join(query.split())

            if not query:
                # An empty dict (not a list) keeps the results value usable by
                # callers that iterate it with .values() / .items().
                return ({}, [], 0, False, no_query_error_message(), [])

            payload = request_payload(query, lang, exact_search, k, received_results)
            err = extract_error_from_payload(payload)
            if err is not None:
                return ({}, [], 0, False, process_error(err, payload), [])

            (
                processed_results,
                highlight_terms,
                num_results,
                ds,
            ) = extract_results_from_payload(query, lang, payload, exact_search)
            result_page = format_result_page(
                lang, processed_results, highlight_terms, num_results, exact_search
            )
            return (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                ds,
            )

        def submit(query, lang, k, dropdown_input):
            """Handle a fresh search (Enter in the query box or the Submit button).

            Returns the values for the nine outputs wired to this handler:
            the four state values, the visibility update for the dataset
            filter row, the dropdown update, the visibility update for the
            pagination row, the number of results received so far, and the
            result page HTML.
            """
            print("submitting", query, lang, k)
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                datasets,
            ) = run_query(query, lang, k, dropdown_input, 0)
            # num_results is only compared in exact search; it can be None
            # otherwise, and the `and` short-circuits before the comparison.
            has_more_results = exact_search and (num_results > k)
            # Count hits across every language (as next_page does) rather than
            # only the first one, so the dataset filter is still shown when
            # the first language has no hits but another one does.
            current_results = (
                sum(len(hits) for hits in processed_results.values())
                if processed_results
                else 0
            )
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=current_results > 0),
                gr.Dropdown.update(choices=datasets, value=datasets),
                gr.update(visible=has_more_results),
                current_results,
                result_page,
            ]

        def next_page(
            query,
            lang,
            k,
            dropdown_input,
            received_results,
            processed_results,
        ):
            """Fetch the next page of results for the current query.

            Args:
                query, lang, k, dropdown_input: Same as for ``submit``.
                received_results: Number of results fetched by earlier pages.
                processed_results: Previous page's results; replaced by the
                    newly fetched page.

            Returns:
                The same nine output values as ``submit``, with the running
                count advanced by the size of the new page.
            """
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                datasets,
            ) = run_query(query, lang, k, dropdown_input, received_results)
            # The results value is empty when the query was empty or the
            # backend reported an error; treat that as zero hits instead of
            # calling .values() on something that may not be a dict.
            current_results = (
                sum(len(results) for results in processed_results.values())
                if processed_results
                else 0
            )
            has_more_results = exact_search and (
                received_results + current_results < num_results
            )
            print("received_results", received_results)
            print("current_results", current_results)
            print("has_more_results", has_more_results)
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=current_results > 0),
                gr.Dropdown.update(choices=datasets, value=datasets),
                # A short page means there is nothing left to fetch.
                gr.update(visible=current_results >= k and has_more_results),
                received_results + current_results,
                result_page,
            ]

        def filter_datasets(
            lang,
            processed_results,
            highlight_terms,
            num_results,
            exact_search,
            datasets_filter,
        ):
            """Re-render the stored results, keeping only the selected datasets."""
            return format_result_page(
                language=lang,
                results=processed_results,
                highlight_terms=highlight_terms,
                num_results=num_results,
                exact_search=exact_search,
                datasets_filter=datasets_filter,
            )

        query.submit(
            fn=submit,
            inputs=[query, lang, k, available_datasets],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
                result_page_html,
            ],
        )
        submit_btn.click(
            submit,
            inputs=[query, lang, k, available_datasets],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
                result_page_html,
            ],
        )

        next_page_btn.click(
            next_page,
            inputs=[
                query,
                lang,
                k,
                available_datasets,
                received_results_state,
                processed_results_state,
            ],
            outputs=[
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                datasets_filter,
                available_datasets,
                pagination,
                received_results_state,
                result_page_html,
            ],
        )

        available_datasets.change(
            filter_datasets,
            inputs=[
                lang,
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                available_datasets,
            ],
            outputs=result_page_html,
        )
    demo.launch(enable_queue=True, debug=True)