from typing import Any, List, Dict, Literal, Optional
from argparse import ArgumentParser
import threading
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

import facefusion.globals
import facefusion.processors.frame.core as frame_processors
from facefusion import wording
from facefusion.face_analyser import clear_face_analyser
from facefusion.content_analyser import clear_content_analyser
from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, map_device, create_metavar, update_status
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices

# Cached RealESRGANer instance; created by get_frame_processor() and reset by clear_frame_processor().
FRAME_PROCESSOR = None
# Default-constructed semaphore (initial value 1); held by enhance_frame() while a frame is enhanced and blended.
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
# Held by get_frame_processor() while it checks for and builds FRAME_PROCESSOR.
THREAD_LOCK : threading.Lock = threading.Lock()
# Passed as the second argument to update_status() by pre_process().
NAME = 'FACEFUSION.FRAME_PROCESSOR.FRAME_ENHANCER'
# Supported models keyed by name. Each entry holds the download 'url', the local file 'path'
# and the 'scale' value that get_frame_processor() hands to both RRDBNet and RealESRGANer.
MODELS: Dict[str, ModelValue] =\
{
	'real_esrgan_x2plus':
	{
		'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/real_esrgan_x2plus.pth',
		'path': resolve_relative_path('../.assets/models/real_esrgan_x2plus.pth'),
		'scale': 2
	},
	'real_esrgan_x4plus':
	{
		'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/real_esrgan_x4plus.pth',
		'path': resolve_relative_path('../.assets/models/real_esrgan_x4plus.pth'),
		'scale': 4
	},
	'real_esrnet_x4plus':
	{
		'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/real_esrnet_x4plus.pth',
		'path': resolve_relative_path('../.assets/models/real_esrnet_x4plus.pth'),
		'scale': 4
	}
}
# Active options; stays None until get_options() fills it from frame_processors_globals.frame_enhancer_model.
OPTIONS : Optional[OptionsWithModel] = None


def get_frame_processor() -> Any:
	"""
	Return the cached RealESRGANer instance, building it on first use.

	The check and the construction both run while THREAD_LOCK is held. The
	model path and scale are read from the active 'model' option, and the
	device is derived from facefusion.globals.execution_providers.
	"""
	global FRAME_PROCESSOR

	with THREAD_LOCK:
		if FRAME_PROCESSOR is not None:
			return FRAME_PROCESSOR
		model_options = get_options('model')
		model_path = model_options.get('path')
		model_scale = model_options.get('scale')
		upscale_network = RRDBNet(num_in_ch = 3, num_out_ch = 3, scale = model_scale)
		FRAME_PROCESSOR = RealESRGANer(
			model_path = model_path,
			model = upscale_network,
			device = map_device(facefusion.globals.execution_providers),
			scale = model_scale
		)
		return FRAME_PROCESSOR


def clear_frame_processor() -> None:
	"""
	Drop the cached frame processor.

	The next get_frame_processor() call will build a fresh instance.
	"""
	global FRAME_PROCESSOR

	FRAME_PROCESSOR = None


def get_options(key : Literal['model']) -> Any:
	"""
	Return the option stored under `key`.

	On first use, OPTIONS is populated with the MODELS entry selected by
	frame_processors_globals.frame_enhancer_model. A key that is not present
	yields None.
	"""
	global OPTIONS

	if OPTIONS is None:
		selected_model = MODELS[frame_processors_globals.frame_enhancer_model]
		OPTIONS = { 'model': selected_model }
	return OPTIONS.get(key)


def set_options(key : Literal['model'], value : Any) -> None:
	"""
	Store `value` under `key` in the module-level OPTIONS.

	OPTIONS starts out as None and is otherwise only created by
	get_options(); it is initialised here as well so that setting an option
	before any read does not fail with a TypeError on None.
	"""
	global OPTIONS

	if OPTIONS is None:
		OPTIONS = {}
	OPTIONS[key] = value


def register_args(program : ArgumentParser) -> None:
	"""
	Add the frame enhancer command line options to `program`.

	Registers --frame-enhancer-model (default 'real_esrgan_x2plus') and
	--frame-enhancer-blend (integer, default 80), each restricted to the
	choices declared in frame_processors_choices.
	"""
	program.add_argument(
		'--frame-enhancer-model',
		help = wording.get('frame_processor_model_help'),
		dest = 'frame_enhancer_model',
		default = 'real_esrgan_x2plus',
		choices = frame_processors_choices.frame_enhancer_models
	)
	program.add_argument(
		'--frame-enhancer-blend',
		help = wording.get('frame_processor_blend_help'),
		dest = 'frame_enhancer_blend',
		type = int,
		default = 80,
		choices = frame_processors_choices.frame_enhancer_blend_range,
		metavar = create_metavar(frame_processors_choices.frame_enhancer_blend_range)
	)


def apply_args(program : ArgumentParser) -> None:
	"""
	Parse the command line and copy the frame enhancer settings
	(model and blend) into frame_processors_globals.
	"""
	parsed_args = program.parse_args()
	frame_processors_globals.frame_enhancer_model = parsed_args.frame_enhancer_model
	frame_processors_globals.frame_enhancer_blend = parsed_args.frame_enhancer_blend


def pre_check() -> bool:
	"""
	Trigger the download of the selected model unless downloads are skipped.

	Always returns True.
	"""
	if facefusion.globals.skip_download:
		return True
	models_directory_path = resolve_relative_path('../.assets/models')
	selected_model_url = get_options('model').get('url')
	conditional_download(models_directory_path, [ selected_model_url ])
	return True


def pre_process(mode : ProcessMode) -> bool:
	"""
	Check that the frame enhancer is ready to run.

	Returns False, after reporting the reason through update_status(), when
	the model download is not done (only checked if downloads are not
	skipped), when the model file is not present, or when `mode` is 'output'
	and no output path is set. Returns True otherwise.
	"""
	model_options = get_options('model')
	model_url = model_options.get('url')
	model_path = model_options.get('path')
	exclamation_key = 'exclamation_mark'

	if not (facefusion.globals.skip_download or is_download_done(model_url, model_path)):
		update_status(wording.get('model_download_not_done') + wording.get(exclamation_key), NAME)
		return False
	if not is_file(model_path):
		update_status(wording.get('model_file_not_present') + wording.get(exclamation_key), NAME)
		return False
	if mode == 'output' and not facefusion.globals.output_path:
		update_status(wording.get('select_file_or_directory_output') + wording.get(exclamation_key), NAME)
		return False
	return True


def post_process() -> None:
	"""
	Release cached state after processing: the frame processor, the face
	analyser, the content analyser and the read_static_image cache, in that
	order.
	"""
	cleanup_steps =\
	(
		clear_frame_processor,
		clear_face_analyser,
		clear_content_analyser,
		read_static_image.cache_clear
	)
	for cleanup_step in cleanup_steps:
		cleanup_step()


def enhance_frame(temp_frame : Frame) -> Frame:
	"""
	Run `temp_frame` through the frame processor and blend the enhanced
	output with the original frame.

	Both steps run while THREAD_SEMAPHORE is held.
	"""
	with THREAD_SEMAPHORE:
		# enhance() returns a pair; only the first element is used
		enhanced_frame, _ = get_frame_processor().enhance(temp_frame)
		return blend_frame(temp_frame, enhanced_frame)


def blend_frame(temp_frame : Frame, paste_frame : Frame) -> Frame:
	"""
	Mix `paste_frame` over `temp_frame` according to the configured blend.

	`temp_frame` is first resized to the width and height of `paste_frame`.
	frame_processors_globals.frame_enhancer_blend is treated as a percentage:
	the resized original is weighted by 1 - blend / 100 and `paste_frame`
	receives the remaining weight.
	"""
	original_weight = 1 - (frame_processors_globals.frame_enhancer_blend / 100)
	target_height, target_width = paste_frame.shape[0:2]
	resized_frame = cv2.resize(temp_frame, (target_width, target_height))
	return cv2.addWeighted(resized_frame, original_weight, paste_frame, 1 - original_weight, 0)


def process_frame(source_face : Face, reference_face : Face, temp_frame : Frame) -> Frame:
	"""
	Enhance a single frame.

	`source_face` and `reference_face` are accepted but not used here.
	"""
	enhanced_frame = enhance_frame(temp_frame)
	return enhanced_frame


def process_frames(source_path : str, temp_frame_paths : List[str], update_progress : Update_Process) -> None:
	"""
	Enhance every frame file in `temp_frame_paths` in place.

	Each file is read, processed and written back to the same path, after
	which `update_progress` is called once. `source_path` is not used.
	"""
	for frame_path in temp_frame_paths:
		enhanced_frame = process_frame(None, None, read_image(frame_path))
		write_image(frame_path, enhanced_frame)
		update_progress()


def process_image(source_path : str, target_path : str, output_path : str) -> None:
	"""
	Enhance the image at `target_path` and write the result to `output_path`.

	`source_path` is not used.
	"""
	enhanced_frame = process_frame(None, None, read_static_image(target_path))
	write_image(output_path, enhanced_frame)


def process_video(source_path : str, temp_frame_paths : List[str]) -> None:
	"""
	Hand the frame paths to frame_processors.multi_process_frames with
	process_frames as the worker.

	`source_path` is not forwarded; None is passed in its place.
	"""
	frame_processors.multi_process_frames(
		None,
		temp_frame_paths,
		process_frames
	)