# Import necessary libraries
import nltk
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.special import rel_entr
from collections import Counter
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Module-level holder for distortion values; starts out empty. The
# commented-out example at the bottom of this file assigns the calculator's
# combined distortions to this name.
distortion_val={}
# Download NLTK data if not already present
# NOTE(review): this runs at import time. The only NLTK call visible in this
# file is nltk.edit_distance, which does not appear to need the 'punkt' data;
# confirm whether this download is still required.
nltk.download('punkt', quiet=True)

class SentenceDistortionCalculator:
    """
    Calculate and analyze distortion metrics between an original sentence and
    a list of modified sentences.

    Four raw metrics are computed for every modified sentence:

    * character-level Levenshtein distance to the original sentence,
    * proportion of word positions that differ from the original,
    * KL divergence between the word-frequency distributions,
    * GPT-2 perplexity of the modified sentence itself.

    Every result dictionary is keyed by ``"Sentence_<n>"``, where ``n`` is the
    1-based position of the sentence in ``modified_sentences``.

    Expected call order::

        calculate_all_metrics() -> normalize_metrics()
            -> calculate_combined_distortion() -> getters / plot_metrics()
    """

    def __init__(self, original_sentence, modified_sentences):
        """
        Initialize the calculator with the original sentence and a list of modified sentences.

        Args:
            original_sentence: The reference sentence (a string).
            modified_sentences: Sequence of sentences to compare against it.

        Note:
            Loads the pretrained "gpt2" tokenizer and model, which are used by
            the perplexity metric.
        """
        self.original_sentence = original_sentence
        self.modified_sentences = modified_sentences

        # Raw metric dictionaries
        self.levenshtein_distances = {}
        self.word_level_changes = {}
        self.kl_divergences = {}
        self.perplexities = {}

        # Normalized metric dictionaries
        self.normalized_levenshtein = {}
        self.normalized_word_changes = {}
        self.normalized_kl_divergences = {}
        self.normalized_perplexities = {}

        # Combined distortion dictionary
        self.combined_distortions = {}

        # Initialize GPT-2 model and tokenizer for perplexity calculation
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.model.eval()  # Set model to evaluation mode

    def calculate_all_metrics(self):
        """
        Calculate all raw distortion metrics for each modified sentence.

        Fills ``levenshtein_distances``, ``word_level_changes``,
        ``kl_divergences`` and ``perplexities``.
        """
        for idx, modified_sentence in enumerate(self.modified_sentences, start=1):
            key = f"Sentence_{idx}"
            self.levenshtein_distances[key] = self._calculate_levenshtein_distance(modified_sentence)
            self.word_level_changes[key] = self._calculate_word_level_change(modified_sentence)
            self.kl_divergences[key] = self._calculate_kl_divergence(modified_sentence)
            self.perplexities[key] = self._calculate_perplexity(modified_sentence)

    def normalize_metrics(self):
        """
        Min-max normalize every raw metric to the range 0..1.

        Each metric is scaled independently across all sentences. With no
        raw metrics available, the normalized dictionaries are left empty.
        """
        self.normalized_levenshtein = self._normalize_dict(self.levenshtein_distances)
        self.normalized_word_changes = self._normalize_dict(self.word_level_changes)
        self.normalized_kl_divergences = self._normalize_dict(self.kl_divergences)
        self.normalized_perplexities = self._normalize_dict(self.perplexities)

    def calculate_combined_distortion(self):
        """
        Calculate the combined distortion as the root mean square of the four
        normalized metrics, stored per sentence in ``combined_distortions``.

        Requires ``normalize_metrics()`` to have been called first; otherwise
        there are no keys to iterate and nothing is written.
        """
        for key in self.normalized_levenshtein:
            squared_sum = (
                self.normalized_levenshtein[key] ** 2 +
                self.normalized_word_changes[key] ** 2 +
                self.normalized_kl_divergences[key] ** 2 +
                self.normalized_perplexities[key] ** 2
            )
            self.combined_distortions[key] = np.sqrt(squared_sum / 4)

    def plot_metrics(self):
        """
        Plot each normalized metric and the combined distortion in separate figures.

        Requires ``normalize_metrics()`` and ``calculate_combined_distortion()``
        to have been called; a sentence key missing from
        ``combined_distortions`` raises ``KeyError``.
        """
        import matplotlib.pyplot as plt

        keys = list(self.normalized_levenshtein)
        indices = np.arange(len(keys))

        # Prepare data for plotting: figure title suffix -> y values in key order
        metrics = {
            'Levenshtein Distance': [self.normalized_levenshtein[key] for key in keys],
            'Word-Level Changes': [self.normalized_word_changes[key] for key in keys],
            'KL Divergence': [self.normalized_kl_divergences[key] for key in keys],
            'Perplexity': [self.normalized_perplexities[key] for key in keys],
            'Combined Distortion': [self.combined_distortions[key] for key in keys]
        }

        # Plot each metric separately; every figure gets a random RGB line color
        for metric_name, values in metrics.items():
            plt.figure(figsize=(12, 6))
            plt.plot(indices, values, marker='o', color=np.random.rand(3,))
            plt.xlabel('Sentence Index')
            plt.ylabel('Normalized Value (0-1)')
            plt.title(f'Normalized {metric_name}')
            plt.grid(True)
            plt.tight_layout()
            plt.show()

    # Private methods for metric calculations
    def _calculate_levenshtein_distance(self, modified_sentence):
        """
        Calculate the Levenshtein distance between the original and modified
        sentence, delegating to ``nltk.edit_distance`` on the two strings.
        """
        return nltk.edit_distance(self.original_sentence, modified_sentence)

    def _calculate_word_level_change(self, modified_sentence):
        """
        Calculate the proportion of word-level changes between the original and modified sentence.

        Words are obtained by whitespace splitting and compared position by
        position; every surplus word in the longer sentence counts as one
        change. The count is divided by the length of the longer sentence.

        Returns:
            A float in 0..1, or 0.0 when neither sentence contains any word.
        """
        original_words = self.original_sentence.split()
        modified_words = modified_sentence.split()
        total_words = max(len(original_words), len(modified_words))
        if total_words == 0:
            # Two empty sentences are identical; also avoids dividing by zero.
            return 0.0

        changed_words = sum(1 for o, m in zip(original_words, modified_words) if o != m)
        # zip() stops at the shorter sentence, so add the unmatched tail.
        changed_words += abs(len(original_words) - len(modified_words))
        return changed_words / total_words

    def _calculate_kl_divergence(self, modified_sentence):
        """
        Calculate the KL divergence between the word distributions of the original and modified sentence.

        Both sentences are lower-cased and split on whitespace. Word counts
        over the union vocabulary are smoothed with 1e-10 and normalized to
        probabilities before summing ``rel_entr(original, modified)``.
        """
        original_counts = Counter(self.original_sentence.lower().split())
        modified_counts = Counter(modified_sentence.lower().split())
        all_words = set(original_counts).union(modified_counts)
        original_probs = np.array([original_counts.get(word, 0) for word in all_words], dtype=float)
        modified_probs = np.array([modified_counts.get(word, 0) for word in all_words], dtype=float)

        # Add smoothing so words missing from one sentence do not yield an
        # infinite divergence.
        original_probs += 1e-10
        modified_probs += 1e-10

        # Normalize to create probability distributions
        original_probs /= original_probs.sum()
        modified_probs /= modified_probs.sum()

        return np.sum(rel_entr(original_probs, modified_probs))

    def _calculate_perplexity(self, sentence):
        """
        Calculate the perplexity of a sentence using GPT-2.

        The token sequence is processed in non-overlapping chunks of at most
        ``model.config.n_positions`` tokens. Each chunk's loss is weighted by
        its length, the weighted losses are summed, divided by the total
        number of tokens, and exponentiated.

        Raises:
            ValueError: If the sentence tokenizes to zero tokens.
        """
        encodings = self.tokenizer(sentence, return_tensors='pt')
        total_tokens = encodings.input_ids.size(1)
        if total_tokens == 0:
            raise ValueError("Cannot compute perplexity: the sentence tokenizes to zero tokens.")

        max_length = self.model.config.n_positions
        stride = max_length  # stride equals the window size, so chunks never overlap

        # NOTE(review): a chunk holding a single token has no next-token
        # target, so the model loss is presumably NaN there - confirm and
        # decide how one-token sentences should be scored.
        chunk_losses = []
        for begin_loc in range(0, total_tokens, stride):
            end_loc = min(begin_loc + stride, total_tokens)
            trg_len = end_loc - begin_loc

            input_ids = encodings.input_ids[:, begin_loc:end_loc]
            target_ids = input_ids.clone()

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                # Scale the chunk loss by the chunk length so the later
                # division by the total token count acts as a weighted mean.
                chunk_losses.append(outputs.loss * trg_len)

        ppl = torch.exp(torch.stack(chunk_losses).sum() / total_tokens)
        return ppl.item()

    def _normalize_dict(self, metric_dict):
        """
        Min-max normalize the values in a dictionary to be between 0 and 1.

        Args:
            metric_dict: Mapping of sentence key to a numeric metric value.

        Returns:
            A new dict with the same keys in the same order. All values are
            0 when every input value is identical; an empty input yields an
            empty dict.
        """
        if not metric_dict:
            # min()/max() of an empty array would raise; nothing to scale.
            return {}

        values = np.array(list(metric_dict.values()), dtype=float)
        min_val = values.min()
        value_range = values.max() - min_val
        # Avoid division by zero if all values are the same
        if value_range == 0:
            normalized_values = np.zeros_like(values)
        else:
            normalized_values = (values - min_val) / value_range
        return dict(zip(metric_dict.keys(), normalized_values))

    # Getter methods
    def get_normalized_metrics(self):
        """
        Get all normalized metrics as a dictionary mapping the metric's
        display name to its per-sentence dictionary.
        """
        return {
            'Levenshtein Distance': self.normalized_levenshtein,
            'Word-Level Changes': self.normalized_word_changes,
            'KL Divergence': self.normalized_kl_divergences,
            'Perplexity': self.normalized_perplexities
        }

    def get_combined_distortions(self):
        """
        Get the dictionary of combined distortion values (sentence key -> RMS value).
        """
        return self.combined_distortions

# # Example usage
# if __name__ == "__main__":
#     # Original sentence
#     original_sentence = "The quick brown fox jumps over the lazy dog"


#     paraphrased_sentences = [
#     # Original 1: "A swift auburn fox leaps across a sleepy canine."
#     "The swift auburn fox leaps across a sleepy canine.",
#     "A quick auburn fox leaps across a sleepy canine.",
#     "A swift ginger fox leaps across a sleepy canine.",
#     "A swift auburn fox bounds across a sleepy canine.",
#     "A swift auburn fox leaps across a tired canine.",
#     "Three swift auburn foxes leap across a sleepy canine.",
#     "The vulpine specimen rapidly traverses over a dormant dog.",
#     "Like lightning, the russet hunter soars over the drowsy guardian.",
#     "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
#     "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
#     "A swift auburn predator navigates across a lethargic pet.",
#     "Subject A (fox) demonstrates velocity over Subject B (dog).",

#     # Original 2: "The agile russet fox bounds over an idle hound."
#     "Some agile russet foxes bound over an idle hound.",
#     "The nimble russet fox bounds over an idle hound.",
#     "The agile brown fox bounds over an idle hound.",
#     "The agile russet fox jumps over an idle hound.",
#     "The agile russet fox bounds over a lazy hound.",
#     "Two agile russet foxes bound over an idle hound.",
#     "A dexterous vulpine surpasses a stationary canine.",
#     "Quick as thought, the copper warrior sails over the guardian.",
#     "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
#     "A dexterous V. vulpes exceeds the plane of an inactive canine.",
#     "An agile russet hunter maneuvers above a resting hound.",
#     "Test subject F-1 achieves displacement superior to subject D-1.",

#     # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
#     "The nimble mahogany vulpine vaults above a drowsy dog.",
#     "A swift mahogany vulpine vaults above a drowsy dog.",
#     "A nimble reddish vulpine vaults above a drowsy dog.",
#     "A nimble mahogany fox vaults above a drowsy dog.",
#     "A nimble mahogany vulpine leaps above a drowsy dog.",
#     "Four nimble mahogany vulpines vault above a drowsy dog.",
#     "An agile specimen of reddish fur surpasses a somnolent canine.",
#     "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
#     "Tha quick brown beastie jumps o'er the tired pup, aye.",
#     "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
#     "A nimble rust-colored predator crosses above a drowsy pet.",
#     "Observed: Subject Red executes vertical motion over Subject Gray.",

#     # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
#     "A speedy copper-colored fox hops over the lethargic pup.",
#     "The quick copper-colored fox hops over the lethargic pup.",
#     "The speedy bronze fox hops over the lethargic pup.",
#     "The speedy copper-colored fox jumps over the lethargic pup.",
#     "The speedy copper-colored fox hops over the tired pup.",
#     "Multiple speedy copper-colored foxes hop over the lethargic pup.",
#     "A rapid vulpine of bronze hue traverses an inactive young canine.",
#     "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
#     "Tha fast copper beastie leaps o'er the sleepy wee dog.",
#     "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
#     "A fleet copper-toned predator moves past a sluggish young dog.",
#     "Field note: Adult fox subject exceeds puppy subject vertically.",

#     # Original 5: "A rapid tawny fox springs over a sluggish dog."
#     "The rapid tawny fox springs over a sluggish dog.",
#     "A quick tawny fox springs over a sluggish dog.",
#     "A rapid golden fox springs over a sluggish dog.",
#     "A rapid tawny fox jumps over a sluggish dog.",
#     "A rapid tawny fox springs over a lazy dog.",
#     "Six rapid tawny foxes spring over a sluggish dog.",
#     "An expeditious yellowish vulpine surpasses a torpid canine.",
#     "Fast as a bullet, the golden hunter vaults over the idle guard.",
#     "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
#     "One V. vulpes displays rapid transit over one inactive C. familiaris.",
#     "A speedy yellow-brown predator bypasses a motionless dog.",
#     "Log entry: Vulpine subject achieves swift vertical displacement.",

#     # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
#     "A fleet-footed chestnut fox soars above an indolent canine.",
#     "The swift chestnut fox soars above an indolent canine.",
#     "The fleet-footed brown fox soars above an indolent canine.",
#     "The fleet-footed chestnut fox leaps above an indolent canine.",
#     "The fleet-footed chestnut fox soars above a lazy canine.",
#     "Several fleet-footed chestnut foxes soar above an indolent canine.",
#     "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
#     "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
#     "Tha quick brown beastie sails o'er the sleepy hound, ken.",
#     "Single agile V. vulpes achieves elevation above stationary canine.",
#     "A nimble brown predator glides over an unmoving domestic animal.",
#     "Research note: Brown subject displays superior vertical mobility.",

#     # Original 7: "A fast ginger fox hurdles past a slothful dog."
#     "The fast ginger fox hurdles past a slothful dog.",
#     "A quick ginger fox hurdles past a slothful dog.",
#     "A fast red fox hurdles past a slothful dog.",
#     "A fast ginger fox jumps past a slothful dog.",
#     "A fast ginger fox hurdles past a lazy dog.",
#     "Five fast ginger foxes hurdle past a slothful dog.",
#     "A rapid orange vulpine bypasses a lethargic canine.",
#     "Quick as lightning, the flame-colored hunter races past the lazy guard.",
#     "Tha swift ginger beastie leaps past the tired doggy, ye see.",
#     "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
#     "A speedy red-orange predator overtakes a motionless dog.",
#     "Data point: Orange subject demonstrates rapid transit past Gray subject.",

#     # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
#     "A spry rusty-colored fox jumps across a dozing hound.",
#     "The agile rusty-colored fox jumps across a dozing hound.",
#     "The spry reddish fox jumps across a dozing hound.",
#     "The spry rusty-colored fox leaps across a dozing hound.",
#     "The spry rusty-colored fox jumps across a sleeping hound.",
#     "Multiple spry rusty-colored foxes jump across a dozing hound.",
#     "An agile rust-toned vulpine traverses a somnolent canine.",
#     "Nimble as thought, the copper hunter bounds over the resting guard.",
#     "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
#     "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
#     "A lithe rust-tinted predator moves past a slumbering dog.",
#     "Observation: Russet subject exhibits agility over dormant subject.",

#     # Original 9: "A quick tan fox leaps over an inactive dog."
#     "The quick tan fox leaps over an inactive dog.",
#     "A swift tan fox leaps over an inactive dog.",
#     "A quick beige fox leaps over an inactive dog.",
#     "A quick tan fox jumps over an inactive dog.",
#     "A quick tan fox leaps over a motionless dog.",
#     "Seven quick tan foxes leap over an inactive dog.",
#     "A rapid light-brown vulpine surpasses a stationary canine.",
#     "Fast as wind, the sand-colored hunter soars over the still guard.",
#     "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
#     "One agile fawn V. vulpes traverses one immobile C. familiaris.",
#     "A fleet tan-colored predator bypasses an unmoving dog.",
#     "Field report: Tan subject demonstrates movement over static subject.",

#     # Original 10: "The brisk auburn vulpine bounces over a listless canine."
#     "Some brisk auburn vulpines bounce over a listless canine.",
#     "The quick auburn vulpine bounces over a listless canine.",
#     "The brisk russet vulpine bounces over a listless canine.",
#     "The brisk auburn fox bounces over a listless canine.",
#     "The brisk auburn vulpine jumps over a listless canine.",
#     "Five brisk auburn vulpines bounce over a listless canine.",
#     "The expeditious specimen supersedes a quiescent Canis lupus.",
#     "Swift as wind, the russet hunter vaults over the idle guardian.",
#     "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
#     "One V. vulpes achieves displacement over inactive C. familiaris.",
#     "A high-velocity auburn predator traverses an immobile animal.",
#     "Final observation: Red subject shows mobility over Gray subject."
#     ]
    

#     # Initialize the calculator
#     calculator = SentenceDistortionCalculator(original_sentence, paraphrased_sentences)

#     # Calculate all metrics
#     calculator.calculate_all_metrics()

#     # Normalize the metrics
#     calculator.normalize_metrics()

#     # Calculate combined distortion
#     calculator.calculate_combined_distortion()

#     # Retrieve the normalized metrics and combined distortions
#     normalized_metrics = calculator.get_normalized_metrics()
#     combined_distortions = calculator.get_combined_distortions()
#     distortion_val=combined_distortions
#     # Display the results
#     print("Normalized Metrics:")
#     for metric_name, metric_dict in normalized_metrics.items():
#         print(f"\n{metric_name}:")
#         for key, value in metric_dict.items():
#             print(f"{key}: {value:.4f}")

#     print("\nCombined Distortions:")
#     for key, value in combined_distortions.items():
#         print(f"{key}: {value:.4f}")

#     # Plot the metrics
#     calculator.plot_metrics()