from model import get_model_tokenizer_classifier, InferenceArguments
from utils import jaccard, safe_print
from transformers import HfArgumentParser
from preprocess import get_words, clean_text
from shared import GeneralArguments, DatasetArguments
from predict import predict
from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
import pandas as pd
from dataclasses import dataclass, field
from typing import Optional
from tqdm import tqdm
import json
import os
import random
from shared import seconds_to_time
from urllib.parse import quote
import logging

# Attach a default stderr handler to the root logger (a no-op if the root
# logger already has handlers), then obtain this module's named logger.
logging.basicConfig()
logger = logging.getLogger(name=__name__)


@dataclass
class EvaluationArguments(InferenceArguments):
    """Command-line options for an evaluation run.

    Extends ``InferenceArguments`` with the path of the per-video metrics
    file and two switches that disable either check performed by ``main``:
    prediction of missing segments, and re-classification of existing ones.
    """

    # Destination of the per-video metrics table (written with DataFrame.to_csv).
    output_file: Optional[str] = field(
        default='metrics.csv',
        metadata={'help': 'Save metrics to output file'},
    )

    # True -> no predictions are made, so missed segments are not looked for.
    skip_missing: bool = field(
        default=False,
        metadata={'help': 'Whether to skip checking for missing segments. If False, predictions will be made.'},
    )
    # True -> existing labelled segments are not re-classified.
    skip_incorrect: bool = field(
        default=False,
        metadata={'help': 'Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'},
    )


def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
    """Link every prediction to the labelled segment it overlaps most.

    Each prediction dict (which must have 'start' and 'end') is updated in
    place with two keys:

    - ``'best_overlap'``: the largest ``jaccard`` value against any labelled
      segment, or 0 if none is larger than 0;
    - ``'best_sponsorship'``: the segment achieving that value, or None.

    Ties keep the earliest segment. ``sponsor_segments`` is not modified and
    is returned as-is.
    """
    for prediction in predictions:
        best_score = 0
        best_segment = None

        # Compare against every actual (labelled) segment
        for segment in sponsor_segments:
            score = jaccard(prediction['start'], prediction['end'],
                            segment['start'], segment['end'])
            if score > best_score:
                best_score, best_segment = score, segment

        prediction['best_overlap'] = best_score
        prediction['best_sponsorship'] = best_segment

    return sponsor_segments


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 1 when the denominator is not positive.

    Implements the convention used by ``calculate_metrics``: an undefined
    ratio (e.g. nothing predicted, or nothing to find) counts as a perfect score.
    """
    return (numerator / denominator) if denominator > 0 else 1


def calculate_metrics(labelled_words, predictions):
    """Compute duration-weighted confusion-matrix metrics for one video.

    Every word except the last contributes ``word_end(word) - word_start(word)``
    to exactly one of the four confusion-matrix cells. A word counts as
    *predicted* sponsor content when its ``'start'`` lies inside
    ``[p['start'], p['end']]`` (inclusive) of any prediction ``p``. It counts
    as *actual* sponsor content when its ``'category'`` is not None.

    Args:
        labelled_words: Word dicts with a ``'start'`` key and an optional
            ``'category'`` key.
        predictions: Dicts with ``'start'`` and ``'end'`` keys, in the same
            time unit as the word timestamps.

    Returns:
        A dict with, in this order: ``'true_positive'``, ``'false_negative'``,
        ``'false_positive'`` and ``'true_negative'`` (summed durations);
        ``'video_duration'``; and ``'accuracy'``, ``'precision'``,
        ``'recall'`` and ``'f-score'``. An empty ``labelled_words`` gives a
        duration of 0 and zero counts instead of raising IndexError.
    """
    metrics = {
        'true_positive': 0,  # Is sponsor, predicted sponsor
        # Is sponsor, predicted not sponsor (i.e., missed it - bad)
        'false_negative': 0,
        # Is not sponsor, predicted sponsor (classified incorrectly, not that bad since we do manual checking afterwards)
        'false_positive': 0,
        'true_negative': 0,  # Is not sponsor, predicted not sponsor
    }

    if labelled_words:
        metrics['video_duration'] = word_end(
            labelled_words[-1]) - word_start(labelled_words[0])
    else:
        # No words at all: report a zero-length video rather than crashing
        metrics['video_duration'] = 0

    # The final word is excluded from the confusion-matrix totals.
    # NOTE(review): the reason for excluding it is not visible here; confirm.
    for word in labelled_words[:-1]:
        duration = word_end(word) - word_start(word)

        # Is in some prediction
        predicted_sponsor = any(
            p['start'] <= word['start'] <= p['end'] for p in predictions)
        is_sponsor = word.get('category') is not None  # Is actual sponsor

        if predicted_sponsor:
            cell = 'true_positive' if is_sponsor else 'false_positive'
        else:
            cell = 'false_negative' if is_sponsor else 'true_negative'
        metrics[cell] += duration

    # NOTE In cases where we encounter division by 0, we say that the value is 1
    # https://stats.stackexchange.com/a/1775
    # (Precision) TP+FP=0: means that all instances were predicted as negative
    # (Recall)    TP+FN=0: means that there were no positive cases in the input data
    tp = metrics['true_positive']
    tn = metrics['true_negative']
    fp = metrics['false_positive']
    fn = metrics['false_negative']

    # The fraction of predictions our model got right (full formula, not simplified)
    metrics['accuracy'] = _safe_ratio(tp + tn, tp + tn + fp + fn)

    # What proportion of positive identifications was actually correct?
    metrics['precision'] = _safe_ratio(tp, tp + fp)

    # What proportion of actual positives was identified correctly?
    metrics['recall'] = _safe_ratio(tp, tp + fn)

    # https://deepai.org/machine-learning-glossary-and-terms/f-score
    precision, recall = metrics['precision'], metrics['recall']
    s = precision + recall
    metrics['f-score'] = (2 * (precision * recall) / s) if s > 0 else 0

    return metrics


def _attach_prediction_text(prediction):
    """Replace a prediction's ``'words'`` list with the joined ``'text'``.

    Mutates ``prediction`` in place (``'words'`` is removed; a missing
    ``'words'`` key yields an empty ``'text'``) and returns it.
    """
    prediction_words = prediction.pop('words', [])
    prediction['text'] = ' '.join(x['text'] for x in prediction_words)
    return prediction


def _print_issues(video_id, video_index, missed_segments, incorrect_segments, output_as_json):
    """Report the issues found for a single video.

    Args:
        video_id: ID of the video the issues belong to.
        video_index: Position of the video in the evaluation order (shown in
            the human-readable header only).
        missed_segments: Predictions that overlap no labelled segment. Each
            needs 'start', 'end', 'text' and 'category' keys, and may have
            'probability'.
        incorrect_segments: Labelled segments whose classifier prediction
            disagrees with the stored category. Each needs 'start', 'end',
            'text', 'uuid', 'votes', 'views', 'locked', 'category',
            'predicted' and 'scores'.
        output_as_json: If True, print one JSON object instead of the
            human-readable report.
    """
    if output_as_json:
        to_print = {'video_id': video_id}

        if missed_segments:
            to_print['missed'] = missed_segments

        if incorrect_segments:
            to_print['incorrect'] = incorrect_segments

        safe_print(json.dumps(to_print))
        return

    safe_print(
        f'Issues identified for {video_id} (#{video_index})')

    # Potentially missed segments (model predicted, but not in database)
    if missed_segments:
        safe_print(' - Missed segments:')
        segments_to_submit = []
        for i, missed_segment in enumerate(missed_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
            safe_print('\t\tText: "',
                       missed_segment['text'], '"', sep='')
            safe_print('\t\tCategory:',
                       missed_segment.get('category'))
            if 'probability' in missed_segment:
                safe_print('\t\tProbability:',
                           missed_segment['probability'])

            # NOTE(review): unlike the print above, this requires 'category'
            # to be present and a string; confirm predict() always sets it.
            segments_to_submit.append({
                'segment': [missed_segment['start'], missed_segment['end']],
                'category': missed_segment['category'].lower(),
                'actionType': 'skip'
            })

        # One "Submit" link encoding all missed segments in its URL fragment
        json_data = quote(json.dumps(segments_to_submit))
        safe_print(
            f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')

    # Incorrect segments (in database, but incorrectly classified)
    if incorrect_segments:
        safe_print(' - Incorrect segments:')
        for i, incorrect_segment in enumerate(incorrect_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))

            safe_print(
                '\t\tText: "', incorrect_segment['text'], '"', sep='')
            safe_print(
                '\t\tUUID:', incorrect_segment['uuid'])
            safe_print(
                '\t\tVotes:', incorrect_segment['votes'])
            safe_print(
                '\t\tViews:', incorrect_segment['views'])
            safe_print('\t\tLocked:',
                       incorrect_segment['locked'])

            safe_print('\t\tCurrent Category:',
                       incorrect_segment['category'])
            safe_print('\t\tPredicted Category:',
                       incorrect_segment['predicted'])
            safe_print('\t\tProbabilities:')
            for label, score in incorrect_segment['scores'].items():
                safe_print(
                    f"\t\t\t{label}: {score}")

    safe_print()


def main():
    """Evaluate the model against the processed segment database.

    For each selected video that has a transcript:

    - unless ``--skip_missing``: run ``predict`` on the video and score the
      predictions against the labelled segments with ``calculate_metrics``.
      Predictions that overlap no labelled segment (or every prediction,
      when the video has no labelled segments) are reported as missed.
    - unless ``--skip_incorrect``: re-classify unlocked labelled segments
      that have at least 1.5 words per second, and report those whose
      predicted category differs from the stored one as incorrect.

    Per-video metrics are written to ``output_file`` as CSV, and their
    numeric column means are logged. Ctrl-C stops the loop early, but the
    metrics collected so far are still written.
    """
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        EvaluationArguments,
        DatasetArguments,
        SegmentationArguments,
        GeneralArguments
    ))

    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
        logger.error('ERROR: Nothing to do')
        return

    # Load labelled data:
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)

    if not os.path.exists(final_path):
        logger.error('ERROR: Processed database not found.\n'
                     f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        evaluation_args, general_args)

    with open(final_path) as fp:
        final_data = json.load(fp)

    if evaluation_args.video_ids:  # Use specified
        video_ids = evaluation_args.video_ids

    else:  # Use items found in preprocessed database
        video_ids = list(final_data.keys())
        # NOTE(review): start_index only selects the same videos across runs
        # if the `random` module is seeded elsewhere; confirm.
        random.shuffle(video_ids)

        if evaluation_args.start_index is not None:
            video_ids = video_ids[evaluation_args.start_index:]

        if evaluation_args.max_videos is not None:
            video_ids = video_ids[:evaluation_args.max_videos]

    out_metrics = []

    all_metrics = {}
    if not evaluation_args.skip_missing:
        all_metrics['total_prediction_accuracy'] = 0
        all_metrics['total_prediction_precision'] = 0
        all_metrics['total_prediction_recall'] = 0
        all_metrics['total_prediction_fscore'] = 0

    if not evaluation_args.skip_incorrect:
        all_metrics['classifier_segment_correct'] = 0
        all_metrics['classifier_segment_count'] = 0

    # Number of videos whose prediction metrics were actually computed;
    # the divisor for the running averages shown in the progress bar.
    metric_count = 0

    postfix_info = {}

    try:
        with tqdm(video_ids) as progress:
            for video_index, video_id in enumerate(progress):
                progress.set_description(f'Processing {video_id}')

                words = get_words(video_id)
                if not words:
                    continue

                # Get labels
                sponsor_segments = final_data.get(video_id)

                # Reset previous
                missed_segments = []
                incorrect_segments = []

                current_metrics = {
                    'video_id': video_id
                }

                if not evaluation_args.skip_missing:  # Make predictions
                    predictions = predict(video_id, model, tokenizer, segmentation_args,
                                          classifier=classifier,
                                          min_probability=evaluation_args.min_probability)

                    if sponsor_segments:
                        labelled_words = add_labels_to_words(
                            words, sponsor_segments)

                        current_metrics.update(
                            calculate_metrics(labelled_words, predictions))

                        # Only videos that were actually scored count towards the averages
                        metric_count += 1
                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
                        all_metrics['total_prediction_precision'] += current_metrics['precision']
                        all_metrics['total_prediction_recall'] += current_metrics['recall']
                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']

                        # Just for display purposes
                        postfix_info.update({
                            'accuracy': all_metrics['total_prediction_accuracy']/metric_count,
                            'precision':  all_metrics['total_prediction_precision']/metric_count,
                            'recall':  all_metrics['total_prediction_recall']/metric_count,
                            'f-score': all_metrics['total_prediction_fscore']/metric_count,
                        })

                        sponsor_segments = attach_predictions_to_sponsor_segments(
                            predictions, sponsor_segments)

                        # Identify possible issues: predictions overlapping no labelled segment
                        missed_segments = [
                            _attach_prediction_text(prediction)
                            for prediction in predictions
                            if prediction['best_sponsorship'] is None
                        ]

                    else:
                        # Not in database (all segments missed); attach text so
                        # reporting can print it like any other missed segment
                        missed_segments = [
                            _attach_prediction_text(prediction)
                            for prediction in predictions
                        ]

                if not evaluation_args.skip_incorrect and sponsor_segments:
                    # Check for incorrect segments using the classifier

                    segments_to_check = []
                    cleaned_texts = []  # Texts to send through tokenizer
                    for sponsor_segment in sponsor_segments:
                        segment_words = extract_segment(
                            words,  sponsor_segment['start'],  sponsor_segment['end'])
                        sponsor_segment['text'] = ' '.join(
                            x['text'] for x in segment_words)

                        duration = sponsor_segment['end'] - \
                            sponsor_segment['start']
                        wps = (len(segment_words) /
                               duration) if duration > 0 else 0
                        # Skip sparse segments (under 1.5 words per second)
                        if wps < 1.5:
                            continue

                        # Do not worry about locked segments
                        # (a vote threshold such as votes > 5 is not applied)
                        if sponsor_segment['locked']:
                            continue

                        cleaned_texts.append(
                            clean_text(sponsor_segment['text']))
                        segments_to_check.append(sponsor_segment)

                    if segments_to_check:  # Some segments to check

                        segments_scores = classifier(cleaned_texts)

                        num_correct = 0
                        for segment, scores in zip(segments_to_check, segments_scores):

                            fixed_scores = {
                                score['label']: score['score']
                                for score in scores
                            }

                            all_metrics['classifier_segment_count'] += 1

                            prediction = max(scores, key=lambda x: x['score'])
                            predicted_category = prediction['label'].lower()

                            if predicted_category == segment['category']:
                                num_correct += 1
                                continue  # Ignore correct segments

                            segment.update({
                                'predicted': predicted_category,
                                'scores': fixed_scores
                            })

                            incorrect_segments.append(segment)

                        current_metrics['num_segments'] = len(
                            segments_to_check)
                        current_metrics['classified_correct'] = num_correct

                        all_metrics['classifier_segment_correct'] += num_correct

                    if all_metrics['classifier_segment_count'] > 0:
                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
                            all_metrics['classifier_segment_count']

                out_metrics.append(current_metrics)
                progress.set_postfix(postfix_info)

                if missed_segments or incorrect_segments:
                    _print_issues(video_id, video_index, missed_segments,
                                  incorrect_segments, evaluation_args.output_as_json)

    except KeyboardInterrupt:
        # Stop early on Ctrl-C; metrics gathered so far are still saved below
        pass

    df = pd.DataFrame(out_metrics)

    df.to_csv(evaluation_args.output_file)
    # 'video_id' is a string column; pandas >= 2.0 raises TypeError when
    # averaging it unless non-numeric columns are excluded explicitly.
    logger.info(df.mean(numeric_only=True))


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()