# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""

from collections import Counter
from math import exp, log
from random import randint, seed

import datasets
import evaluate


_CITATION = """\
@InProceedings{napoles-EtAl:2015:ACL-IJCNLP,
  author    = {Napoles, Courtney  and  Sakaguchi, Keisuke  and  Post, Matt  and  Tetreault, Joel},
  title     = {Ground Truth for Grammatical Error Correction Metrics},
  booktitle = {Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 2: Short Papers)},
  month     = {July},
  year      = {2015},
  address   = {Beijing, China},
  publisher = {Association for Computational Linguistics},
  pages     = {588--593},
  url       = {http://www.aclweb.org/anthology/P15-2097}
}
"""

_DESCRIPTION = """\
 GLEU metric can be used for any monolingual "translation" task, that is it can be used for Grammatical Error Correction and other text re-writing tasks. BLEU  computes n-gram precisions over the reference but assigns more weight to n-grams that have been correctly changed from the source. GLEU rewards corrections while also correctly crediting unchanged source text. 
"""


_KWARGS_DESCRIPTION = """
Computes the GLEU score of predictions against references.
Args:
    sources: Source sentences (not currently accepted by compute; the references are used as the sources).
    references: Reference for each prediction. Each reference should be a string with tokens separated by spaces.
    predictions: List of predictions to score. Each prediction should be a string with tokens separated by spaces.
    order: Maximum n-gram order used when computing the score. Defaults to 4.
Returns:
    gleu_score: The corpus-level GLEU score averaged over the sampling iterations.

Examples:

    >>> gleu_metric = evaluate.load("gleu")
    >>> references=["We may in actual fact be communicating with a hoax Facebook acccount of a cyberfriend , which we assume to be real but in reality , it is a fake account ."]
    >>> results = my_new_module.compute(references=references, predictions=["We may of actual fact communicating with a hoax Facebook acccount of a cyber friend , which we assumed to be real but in reality , it is a fake account ."])
    >>> print(results)
    {'gleu_score': 0.6}
    
    >>> results = my_new_module.compute(references=references, predictions=["We may be in actual fact communicating with a hoax Facebook acccount of a cyber friend , we assume to be real but in reality , it is a fake account ."])
    >>> print(results)
    {'gleu_score': 0.62}
    
    >>> results = my_new_module.compute(references=references, predictions=["We may in actual fact communicating with a hoax Facebook account of a cyber friend , which we assume to be real but in reality , it is a fake accounts ."])
    >>> print(results)
    {'gleu_score': 0.64}
    
"""


class GLEU:
    def __init__(self, order=4):
        self.order = order

    def load_hypothesis_sentence(self, hypothesis):
        self.hlen = len(hypothesis)
        self.this_h_ngrams = [self.get_ngram_counts(hypothesis, n)
                              for n in range(1, self.order + 1)]

    def load_sources(self, source_sents):
        self.all_s_ngrams = [[self.get_ngram_counts(source_sent.split(), n)
                              for n in range(1, self.order + 1)]
                             for source_sent in source_sents]

    def load_references(self, ref_sents):
        self.refs = [[] for i in range(len(self.all_s_ngrams))]
        self.rlens = [[] for i in range(len(self.all_s_ngrams))]
        for i, ref_sent in enumerate(ref_sents):
            self.refs[i].append(ref_sent.split())
            self.rlens[i].append(len(ref_sent.split()))

        # count the number of references each n-gram appears in
        self.all_rngrams_freq = [Counter() for i in range(self.order)]
    
        self.all_r_ngrams = []
        for refset in self.refs:
            all_ngrams = []
            self.all_r_ngrams.append(all_ngrams)

            for n in range(1, self.order + 1):
                ngrams = self.get_ngram_counts(refset[0], n)
                all_ngrams.append(ngrams)

                for k in ngrams.keys():
                    self.all_rngrams_freq[n - 1][k] += 1

                for ref in refset[1:]:
                    new_ngrams = self.get_ngram_counts(ref, n)
                    for nn in new_ngrams.elements():
                        if new_ngrams[nn] > ngrams.get(nn, 0):
                            ngrams[nn] = new_ngrams[nn]

    def get_ngram_counts(self, sentence, n):
        return Counter([tuple(sentence[i:i + n]) for i in range(len(sentence) + 1 - n)])
    
    # returns the n-grams that occur in a but not in b
    def get_ngram_diff(self, a, b):
        diff = Counter(a)
        for k in (set(a) & set(b)):
            del diff[k]
        return diff
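
    # Illustration (hypothetical example, not executed): for sentence = "the cat sat".split()
    # and n = 2, get_ngram_counts returns Counter({('the', 'cat'): 1, ('cat', 'sat'): 1});
    # get_ngram_diff keeps only the n-grams of its first argument whose keys never occur
    # in the second, which is how source-only n-grams are isolated in gleu_stats below.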
    
    def normalization(self, ngram, n):
        return 1.0 * self.all_rngrams_freq[n - 1][ngram] / len(self.rlens[0])
    
    # Collect BLEU-relevant statistics for a single hypothesis/reference pair.
    # Return value is a generator yielding:
    # (c, r, numerator1, denominator1, ... numerator4, denominator4)
    # Summing the columns across calls to this function on an entire corpus
    # will produce a vector of statistics that can be used to compute GLEU
    def gleu_stats(self, i, r_ind=None):
        hlen = self.hlen
        rlen = self.rlens[i][r_ind]

        yield hlen
        yield rlen

        for n in range(1, self.order + 1):
            h_ngrams = self.this_h_ngrams[n - 1]
            s_ngrams = self.all_s_ngrams[i][n - 1]
            r_ngrams = self.get_ngram_counts(self.refs[i][r_ind], n)

            s_ngram_diff = self.get_ngram_diff(s_ngrams, r_ngrams)

            yield max([sum((h_ngrams & r_ngrams).values())
                       - sum((h_ngrams & s_ngram_diff).values()), 0])

            yield max([hlen + 1 - n, 0])
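
    # Illustration (hypothetical values): for a 20-token hypothesis scored with the default
    # order of 4, gleu_stats yields 10 numbers such as (20, 19, m1, 20, m2, 19, m3, 18, m4, 17),
    # where each m_n is the clipped n-gram match count after subtracting matches with
    # n-grams that occur only in the source.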
    
    # Compute GLEU from the summed statistics obtained by call(s) to gleu_stats
    def compute_gleu(self, stats, smooth=False):
        # smooth 0 counts for sentence-level scores
        if smooth:
            stats = [s if s != 0 else 1 for s in stats]
        if any(s == 0 for s in stats):
            return 0
        (c, r) = stats[:2]
        log_gleu_prec = sum([log(float(x) / y)
                             for x, y in zip(stats[2::2], stats[3::2])]) / self.order
        return exp(min([0, 1 - float(r) / c]) + log_gleu_prec)
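
    # Equivalently (a restatement of compute_gleu above, not an extra step):
    #     GLEU = exp(min(0, 1 - r / c) + (1 / order) * sum_n log(num_n / den_n))
    # i.e. a brevity penalty combined with the geometric mean of the n-gram precisions.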
    


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class gleu(evaluate.Metric):
    """TODO: Short description of my evaluation module."""

    def _info(self):
        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                    "order": datasets.Value("Int32"),
                }
            ),
            codebase_urls=["https://github.com/cnap/gec-ranking/"],
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        # TODO: Download external resources if needed
        pass

    def _compute(self, references, predictions, order=4):
        """Returns the scores"""
        
        num_iterations = 500

        if len(references) == 1:
            num_iterations = 1

        gleu_calculator = GLEU(order=order)

        # if sources:
        #     gleu_calculator.load_sources(sources)
        # else:
        gleu_calculator.load_sources(references)
        gleu_calculator.load_references(references)

        # first generate a random list of indices (choosing which reference to use
        # for each sentence), using a different seed for each iteration
        indices = []
        for j in range(num_iterations):
            seed(j * 101)
            indices.append([randint(0, len(gleu_calculator.refs[i]) - 1)
                            for i in range(len(predictions))])

        iter_stats = [[0 for i in range(2 * order + 2)] for j in range(num_iterations)]
        
        for i, h in enumerate(predictions):
            # the GLEU class expects a token list, so split the hypothesis string
            gleu_calculator.load_hypothesis_sentence(h.split())

            # we are going to store the stats of this sentence for each ref
            # so we don't have to recalculate them on every iteration
            stats_by_ref = [None for r in range(len(gleu_calculator.refs[i]))]

            for j in range(num_iterations):
                ref = indices[j][i]
                this_stats = stats_by_ref[ref]

                if this_stats is None:
                    this_stats = [s for s in gleu_calculator.gleu_stats(i, r_ind=ref)]
                    stats_by_ref[ref] = this_stats

                iter_stats[j] = [sum(scores) for scores in zip(iter_stats[j], this_stats)]
        
        # average the corpus-level GLEU score over all sampling iterations
        scores = [gleu_calculator.compute_gleu(stats) for stats in iter_stats]
        final_gleu_score = sum(scores) / len(scores)
        return {"gleu_score": final_gleu_score}