# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding:utf-8 -*-
import os
import logging
from . import bleu
from . import weighted_ngram_match
from . import syntax_match
from . import dataflow_match


def _load_keywords(lang):
    """Return the set of keywords listed in ``keywords/<lang>.txt`` next to this module."""
    path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                        'keywords', lang + '.txt')
    # Context manager guarantees the file handle is closed even on error.
    with open(path, 'r', encoding='utf-8') as keyword_file:
        return {line.strip() for line in keyword_file}


def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score of predictions against references.

    The total is ``alpha * ngram + beta * weighted_ngram + gamma * syntax +
    theta * dataflow``, where the four component scores come from the
    ``bleu``, ``weighted_ngram_match``, ``syntax_match`` and
    ``dataflow_match`` modules respectively.

    Args:
        predictions (list[str]): candidate code strings, one per example.
        references (list[list[str]] | list[str]): for each example, either a
            list of reference code strings or a single reference string
            (treated as a one-element list).
        lang (str): language name; a ``keywords/<lang>.txt`` file must exist
            beside this module, e.g. one of
            ['java','js','c_sharp','php','go','python','ruby'].
        tokenizer (callable, optional): maps a code string to a list of
            tokens. Defaults to ``lambda s: s.split()``.
        params (str | Sequence[float], optional): the four weights
            (alpha, beta, gamma, theta), either as a comma-separated string
            or a sequence of numbers. Defaults to '0.25,0.25,0.25,0.25'.

    Returns:
        dict: ``'CodeBLEU'`` (the weighted total) together with
        ``'ngram_match_score'``, ``'weighted_ngram_match_score'``,
        ``'syntax_match_score'`` and ``'dataflow_match_score'``.

    Raises:
        ValueError: if ``params`` does not yield exactly four weights, or if
            the number of predictions differs from the number of references.
        OSError: if the keyword file for ``lang`` cannot be opened.
    """
    if isinstance(params, str):
        weights = [float(x) for x in params.split(',')]
    else:
        weights = [float(x) for x in params]
    if len(weights) != 4:
        raise ValueError(
            f'params must contain exactly 4 weights, got {len(weights)}: {params!r}')
    alpha, beta, gamma, theta = weights

    # Preprocess inputs. A bare string reference is one reference for that
    # example; iterating it directly would treat each character as a reference.
    references = [[ref] if isinstance(ref, str) else ref for ref in references]
    references = [[x.strip() for x in ref] for ref in references]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            f'Number of predictions ({len(hypothesis)}) does not match '
            f'number of references ({len(references)})')

    # Calculate n-gram match (BLEU).
    if tokenizer is None:
        tokenizer = lambda s: s.split()

    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # Calculate weighted n-gram match: language keywords get weight 1,
    # every other token gets 0.2. A set keeps the membership test O(1).
    keywords = _load_keywords(lang)

    def make_weights(reference_tokens):
        return {token: 1 if token in keywords else 0.2
                for token in reference_tokens}

    tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens)]
                                    for reference_tokens in reference]
                                   for reference in tokenized_refs]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # Calculate syntax match (operates on the untokenized, stripped code).
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang)

    # Calculate dataflow match (operates on the untokenized, stripped code).
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang)

    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * dataflow_match_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }