import os
import tempfile
from typing import List, Callable

import gradio as gr
import pandas as pd
from autorag.data.parse import langchain_parse
from autorag.data.parse.base import _add_last_modified_datetime
from autorag.data.parse.llamaparse import llama_parse
from autorag.data.qa.schema import Raw
from autorag.utils import result_to_dataframe
from llama_index.llms.openai import OpenAI

from src.create import default_create, fast_create, advanced_create
from src.util import on_submit_openai_key, on_submit_llama_cloud_key, on_submit_upstage_key

@result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
def original_parse(fn: Callable, **kwargs):
	"""Run parser ``fn`` with ``kwargs`` and attach last-modified datetimes to its output.

	The decorator presumably converts the returned result into a DataFrame with the
	columns "texts", "path", "page" and "last_modified_datetime" — confirm against
	``autorag.utils.result_to_dataframe``.
	"""
	return _add_last_modified_datetime(fn(**kwargs))

def change_lang_choice(lang: str) -> str:
	"""Translate a UI language label into its two-letter language code.

	Accepted labels are "English", "한국어" and "日本語"; any other label raises KeyError.
	"""
	label_to_code = dict(zip(("English", "한국어", "日本語"), ("en", "ko", "ja")))
	return label_to_code[lang]

def change_visible_status_api_key(parse_method: str):
	"""Return visibility updates for the (Llama Cloud row, Upstage row) API-key inputs.

	Only the row belonging to the selected hosted parser is shown; every other
	choice hides both rows.
	"""
	show_llama_row = parse_method == "llama-parse"
	show_upstage_row = parse_method == "upstage🇰🇷"
	return gr.update(visible=show_llama_row), gr.update(visible=show_upstage_row)



def run_parse(file_lists: List[str], parse_method: str, original_raw_df, progress=gr.Progress()):
	"""Parse uploaded documents into a raw DataFrame; handler for the "Run Parsing" button.

	Args:
		file_lists: Filepaths from the multi-file upload; None or empty when nothing was uploaded.
		parse_method: Dropdown choice — one of the local parsers ("pdfminer", "pdfplumber",
			"pypdfium2", "pypdf", "pymupdf"), "llama-parse" or "upstage🇰🇷".
		original_raw_df: Current raw state; returned unchanged on failure so it is not clobbered.
		progress: Gradio progress tracker.

	Returns:
		``(status_message, raw_df)``.
	"""
	def fail(message: str):
		# Show a toast and keep the previous state; the message also goes to the status box.
		gr.Warning(message)
		return message, original_raw_df

	if not file_lists:
		return fail("Please upload files first.")

	progress(0.05)
	# ``__wrapped__`` is the undecorated function (exposed by functools.wraps); presumably used
	# so original_parse can apply its own post-processing to the parser's raw result.
	langchain_parse_original = langchain_parse.__wrapped__

	if parse_method in ["pdfminer", "pdfplumber", "pypdfium2", "pypdf", "pymupdf"]:
		raw_df: pd.DataFrame = original_parse(langchain_parse_original,
											  data_path_list=file_lists, parse_method=parse_method)
	elif parse_method == "llama-parse":
		# The key is only checked here; it is read from the environment variable.
		if os.getenv("LLAMA_CLOUD_API_KEY") is None:
			return fail("Please submit your Llama Cloud API key first.")
		raw_df: pd.DataFrame = original_parse(llama_parse.__wrapped__, data_path_list=file_lists)
	elif parse_method == "upstage🇰🇷":
		if os.getenv("UPSTAGE_API_KEY") is None:
			return fail("Please submit your Upstage API key first.")
		raw_df: pd.DataFrame = original_parse(langchain_parse_original,
											  data_path_list=file_lists, parse_method="upstagedocumentparse")
	else:
		return fail("Unsupported parse method.")
	progress(0.8)

	return "Parsing Complete. Download at the bottom button.", raw_df


def run_chunk(use_existed_raw: bool, raw_df: pd.DataFrame, raw_file: str, chunk_method: str, chunk_size: int, chunk_overlap: int,
			  lang: str = "English", original_corpus_df = None, progress=gr.Progress()):
	"""Chunk a raw (parsed) DataFrame into a corpus; handler for the "Run Chunking" button.

	Args:
		use_existed_raw: When True, chunk ``raw_df`` from the parsing step; when False,
			read the uploaded ``raw_file`` instead.
		raw_df: The parsed DataFrame held in ``raw_df_state``; None before parsing has run.
		raw_file: Filepath of an uploaded raw.parquet; None when nothing was uploaded.
		chunk_method: "Token", "Sentence", "Semantic" or "Recursive".
		chunk_size: Chunk size for the size-based methods (not used by "Semantic").
		chunk_overlap: Chunk overlap for the size-based methods (not used by "Semantic").
		lang: UI language label; converted to a code and passed as ``add_file_name``.
		original_corpus_df: Current corpus state; returned unchanged on failure.
		progress: Gradio progress tracker.

	Returns:
		``(status_message, corpus_df)``.
	"""
	def fail(message: str):
		# gr.Error only has an effect when raised; a warning toast keeps the status-box update.
		gr.Warning(message)
		return message, original_corpus_df

	lang = change_lang_choice(lang)

	if use_existed_raw:
		if raw_df is None:
			return fail("No parsed data yet. Run parsing first, or uncheck 'Use previous raw.parquet' and upload one.")
	elif raw_file is None:
		return fail("Please upload a raw.parquet file first.")
	else:
		raw_df = pd.read_parquet(raw_file, engine="pyarrow")

	# The sliders allow an overlap larger than the chunk size, which the underlying
	# splitters are expected to reject with an exception; report it nicely instead.
	if chunk_method in ("Token", "Sentence", "Recursive") and chunk_overlap > chunk_size:
		return fail("Chunk overlap must not be larger than chunk size.")

	raw_instance = Raw(raw_df)

	if chunk_method in ("Token", "Sentence"):
		corpus = raw_instance.chunk("llama_index_chunk", chunk_method=chunk_method, chunk_size=chunk_size,
									chunk_overlap=chunk_overlap, add_file_name=lang)
	elif chunk_method == "Semantic":
		# LlamaIndex's SemanticSplitterNodeParser takes ``breakpoint_percentile_threshold`` as a
		# percentile on a 0-100 scale (default 95). The previous keyword was misspelled
		# ("percnetile") and used 0.95, which would mean the 0.95th percentile.
		corpus = raw_instance.chunk("llama_index_chunk", chunk_method="Semantic_llama_index",
									embed_model="openai", breakpoint_percentile_threshold=95,
									add_file_name=lang)
	elif chunk_method == "Recursive":
		corpus = raw_instance.chunk("langchain_chunk", chunk_method="recursivecharacter",
									add_file_name=lang, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
	else:
		return fail("Unsupported chunk method.")
	progress(0.8)
	return "Chunking Complete. Download at the bottom button.", corpus.data


def run_qa(use_existed_corpus: bool, corpus_df: pd.DataFrame, corpus_file: str, qa_method: str,
		   model_name: str, qa_cnt: int, batch_size: int, lang: str = "English", original_qa_df = None,
		   progress=gr.Progress()):
	"""Generate a QA dataset from a corpus; handler for the "Run QA Creation" button.

	Args:
		use_existed_corpus: When True, use ``corpus_df`` from the chunking step; when False,
			read the uploaded ``corpus_file`` instead.
		corpus_df: The corpus held in ``corpus_df_state``; None before chunking has run.
		corpus_file: Filepath of an uploaded corpus.parquet; None when nothing was uploaded.
		qa_method: "default", "fast" or "advanced" — selects the matching ``*_create`` helper.
		model_name: OpenAI model name used to build the LLM.
		qa_cnt: Number of QA pairs requested (passed as ``n``).
		batch_size: Batch size forwarded to the create helper.
		lang: UI language label; converted to a language code.
		original_qa_df: Current QA state; returned unchanged on failure.
		progress: Gradio progress tracker, forwarded to the create helper.

	Returns:
		``(status_message, qa_df)``.
	"""
	def fail(message: str):
		# gr.Error only has an effect when raised; a warning toast keeps the status-box update.
		gr.Warning(message)
		return message, original_qa_df

	lang = change_lang_choice(lang)

	# Validate the cheap preconditions before loading a parquet file or building the LLM client.
	if os.getenv("OPENAI_API_KEY") is None:
		return fail("Please submit your OpenAI API key first.")
	if model_name is None:
		return fail("Please select a model first.")

	create_fn = {
		"default": default_create,
		"fast": fast_create,
		"advanced": advanced_create,
	}.get(qa_method)
	if create_fn is None:
		return fail("Unsupported QA method.")

	if use_existed_corpus:
		if corpus_df is None:
			return fail("No corpus yet. Run chunking first, or uncheck 'Use previous corpus.parquet' and upload one.")
	elif corpus_file is None:
		return fail("Please upload a corpus.parquet file first.")
	else:
		corpus_df = pd.read_parquet(corpus_file, engine="pyarrow")

	llm = OpenAI(model=model_name)
	qa = create_fn(corpus_df, llm=llm, n=qa_cnt, lang=lang, progress=progress, batch_size=batch_size)

	return "QA Creation Complete. Download at the bottom button.", qa.data


def download_state(state: pd.DataFrame, change_name: str):
	"""Write ``state`` to ``<change_name>.parquet`` and yield its path for a DownloadButton.

	This is the callable ``value`` of the download buttons, so Gradio calls it on app load
	and again whenever the bound state changes. It is a generator that yields exactly once.

	Args:
		state: DataFrame to export, or None when the corresponding step has not run yet.
		change_name: File stem used for the download (e.g. "raw", "corpus", "qa").

	Yields:
		The parquet filepath, or None (an empty button value) when there is no data.
	"""
	if state is None:
		# Nothing to export yet, e.g. on page load. Yield None explicitly: inside a generator
		# a ``return ""`` yields nothing, so its value was silently dropped. No toast here,
		# because this branch runs on every page load.
		yield None
		return
	with tempfile.TemporaryDirectory() as temp_dir:
		filename = os.path.join(temp_dir, f"{change_name}.parquet")
		state.to_parquet(filename, engine="pyarrow")
		# The temporary directory is deleted once the generator is resumed or closed, so this
		# relies on Gradio copying the yielded file into its cache first.
		# NOTE(review): confirm this holds for the Gradio version in use.
		yield filename


with gr.Blocks(theme="earneleh/paris") as demo:
	# Session state that carries each step's DataFrame into the next step:
	# parse -> raw_df_state -> chunk -> corpus_df_state -> QA -> qa_df_state.
	raw_df_state = gr.State()
	corpus_df_state = gr.State()
	qa_df_state = gr.State()
	gr.HTML("<h1>AutoRAG Data Creation 🛠️</h1>")
	with gr.Row():
		openai_key_textbox = gr.Textbox(label="Please input your OpenAI API key and press Enter.", type="password",
										info="You can get your API key from https://platform.openai.com/account/api-keys\n\n"
											 "AutoRAG does not store your API key.",
										autofocus=True)
		api_key_status_box = gr.Textbox(label="OpenAI API status", value="Not Set", interactive=False)
		lang_choice = gr.Radio(["English", "한국어", "日本語"], label="Language",
							   value="English", info="Choose language. En, Ko, Ja are supported.",
							   interactive=True)

	# API-key rows for the hosted parsers; hidden until the matching parser is selected.
	with gr.Row(visible=False) as llama_cloud_api_key_row:
		llama_key_textbox = gr.Textbox(label="Please input your Llama Cloud API key and press Enter.", type="password",
									   info="You can get your API key from https://docs.cloud.llamaindex.ai/llamacloud/getting_started/api_key\n\n"
											"AutoRAG does not store your API key.",)
		llama_key_status_box = gr.Textbox(label="Llama Cloud API status", value="Not Set", interactive=False)

	with gr.Row(visible=False) as upstage_api_key_row:
		upstage_key_textbox = gr.Textbox(label="Please input your Upstage API key and press Enter.", type="password",
										 info="You can get your API key from https://upstage.ai/\n\n"
											  "AutoRAG does not store your API key.",)
		upstage_key_status_box = gr.Textbox(label="Upstage API status", value="Not Set", interactive=False)

	with gr.Row():
		# Step 1: documents -> raw.parquet
		with gr.Column(scale=1):
			gr.Markdown("## 1. Parse your PDF files\n\nUpload your PDF files and convert them into raw.parquet.")
			document_file_input = gr.File(label="Upload Files", type="filepath", file_count="multiple")
			parse_choice = gr.Dropdown(
				["pdfminer", "pdfplumber", "pypdfium2", "pypdf", "pymupdf", "llama-parse", "upstage🇰🇷"],
				label="Parsing Method", info="Choose parsing method that you want")
			parse_button = gr.Button(value="Run Parsing")
			parse_status = gr.Textbox(value="Not Started", interactive=False)
			# A callable value with `inputs` is recomputed whenever those inputs change, so the
			# button always offers the latest state.
			raw_download_button = gr.DownloadButton(value=download_state, inputs=[raw_df_state, gr.State("raw")],
													label="Download raw.parquet")

		# Step 2: raw.parquet -> corpus.parquet
		with gr.Column(scale=1):
			gr.Markdown(
				"## 2. Chunk your raw.parquet\n\nUse parsed raw.parquet or upload your own. It will make a corpus.parquet."
			)
			# The upload box is only shown when "Use previous raw.parquet" is unchecked.
			raw_file_input = gr.File(label="Upload raw.parquet", type="filepath", file_count="single", visible=False)
			use_previous_raw_file = gr.Checkbox(label="Use previous raw.parquet", value=True)

			chunk_choice = gr.Dropdown(
				["Token", "Sentence", "Semantic", "Recursive"],
				label="Chunking Method", info="Choose chunking method that you want")
			chunk_size = gr.Slider(minimum=128, maximum=1024, step=128, label="Chunk Size", value=256)
			chunk_overlap = gr.Slider(minimum=16, maximum=256, step=16, label="Chunk Overlap", value=32)
			chunk_button = gr.Button(value="Run Chunking")
			chunk_status = gr.Textbox(value="Not Started", interactive=False)
			corpus_download_button = gr.DownloadButton(label="Download corpus.parquet",
													   value=download_state, inputs=[corpus_df_state, gr.State("corpus")])

		# Step 3: corpus.parquet -> qa.parquet
		with gr.Column(scale=1):
			gr.Markdown(
				"## 3. Create QA dataset from your corpus.parquet\n\nQA dataset is essential to run AutoRAG. Upload corpus.parquet & select QA method and run.")
			gr.HTML("<b style='color: red; background-color: black; font-weight: bold;'>Warning: QA Creation uses an OpenAI model, which can be costly. Start with a small batch to gauge expenses.</b>")
			# The upload box is only shown when "Use previous corpus.parquet" is unchecked.
			corpus_file_input = gr.File(label="Upload corpus.parquet", type="filepath", file_count="single",
										visible=False)
			use_previous_corpus_file = gr.Checkbox(label="Use previous corpus.parquet", value=True)

			qa_choice = gr.Radio(["default", "fast", "advanced"], label="QA Method",
								 info="Choose QA method that you want")
			model_choice = gr.Radio(["gpt-4o-mini", "gpt-4o"], label="Select model for data creation")
			qa_cnt = gr.Slider(minimum=20, maximum=150, step=5, label="Number of QA pairs", value=80)
			batch_size = gr.Slider(minimum=1, maximum=16, step=1,
								   label="Batch Size to OpenAI model. If there is an error, decrease this.", value=16)
			run_qa_button = gr.Button(value="Run QA Creation")
			qa_status = gr.Textbox(value="Not Started", interactive=False)
			gr.Markdown("### Do you want to customize your QA dataset? Join a waitlist for AutoRAG data creation studio.")
			gr.Button("Join Data Creation Studio Waitlist", link="https://tally.so/r/wdDo6N")
			qa_download_button = gr.DownloadButton(label="Download qa.parquet",
												   value=download_state, inputs=[qa_df_state, gr.State("qa")])

	#================================================================================================#
	# Logics

	# Toggle the upload boxes: show them only when the previous step's result is NOT reused.
	use_previous_raw_file.change(lambda x: gr.update(visible=not x), inputs=[use_previous_raw_file],
								 outputs=[raw_file_input])
	use_previous_corpus_file.change(lambda x: gr.update(visible=not x), inputs=[use_previous_corpus_file],
									outputs=[corpus_file_input])
	openai_key_textbox.submit(on_submit_openai_key, inputs=[openai_key_textbox], outputs=api_key_status_box)

	# Parsing: the current state is passed in so it can be returned unchanged on failure.
	parse_button.click(run_parse, inputs=[document_file_input, parse_choice, raw_df_state],
					   outputs=[parse_status, raw_df_state])

	# Chunking
	chunk_button.click(run_chunk, inputs=[use_previous_raw_file, raw_df_state, raw_file_input, chunk_choice, chunk_size, chunk_overlap,
										  lang_choice, corpus_df_state],
					   outputs=[chunk_status, corpus_df_state])

	# QA Creation
	run_qa_button.click(run_qa, inputs=[use_previous_corpus_file, corpus_df_state, corpus_file_input, qa_choice,
										model_choice, qa_cnt, batch_size, lang_choice,
										qa_df_state],
						outputs=[qa_status, qa_df_state])

	# API Key visibility
	parse_choice.change(change_visible_status_api_key, inputs=[parse_choice],
						outputs=[llama_cloud_api_key_row, upstage_api_key_row])
	llama_key_textbox.submit(on_submit_llama_cloud_key, inputs=[llama_key_textbox], outputs=llama_key_status_box)
	upstage_key_textbox.submit(on_submit_upstage_key, inputs=[upstage_key_textbox], outputs=upstage_key_status_box)


# if __name__ == "__main__":
# 	demo.launch(share=False, debug=True)
# NOTE(review): launch runs unconditionally at import time; the __main__ guard above was
# deliberately commented out — confirm the hosting setup needs that before restoring it.
demo.launch(share=False, debug=False)