import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Preprocess reviews
reviews_path = "data_reviews.txt"
with open(reviews_path, "r", encoding="utf-8") as reviews_raw:
    reviews = reviews_raw.readlines()
# Ensure the task designator is set off by spaces so it tokenizes consistently
reviews = [review.strip().replace("TL;DR", " TL;DR ") for review in reviews]
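# Each line of data_reviews.txt is assumed to look like
#   <full review text> TL;DR <short summary>
# with "TL;DR" acting as the task designator between review and summary.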

max_length = 200

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Number of tokens the " TL;DR " designator occupies in the GPT-2 vocabulary
extra_length = len(tokenizer.encode(" TL;DR "))
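# Illustrative check: the designator should cost only a handful of BPE tokens
print(f"' TL;DR ' occupies {extra_length} tokens")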

from peft import LoraConfig, get_peft_model

# LoRA adapter config: rank r stays small relative to GPT-2's 768-dim hidden
# size; peft's built-in defaults for gpt2 target the c_attn projection
config = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)

# Build the optimizer only after wrapping, so it sees the LoRA parameters;
# the frozen base weights (requires_grad=False) are skipped
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=3e-4
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
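# Sanity check via peft's built-in helper: only the adapter weights
# should be listed as trainable
model.print_trainable_parameters()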

class ReviewDataset(Dataset):
    def __init__(self, tokenizer, reviews, max_len):
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.eos = self.tokenizer.eos_token
        self.eos_id = self.tokenizer.eos_token_id
        self.reviews = reviews
        self.result = []

        for review in self.reviews:
            # Encode the text using tokenizer.encode(). Add EOS at the end
            tokenized = self.tokenizer.encode(review + self.eos)

            # Padding/truncating the encoded sequence to max_len
            padded = self.pad_truncate(tokenized)

            # Creating a tensor and adding to the result
            self.result.append(torch.tensor(padded))

    def __len__(self):
        return len(self.result)

    def __getitem__(self, item):
        return self.result[item]

    def pad_truncate(self, tokens):
        # Content length, not counting the " TL;DR " designator tokens
        content_length = len(tokens) - extra_length
        if content_length < self.max_len:
            # Pad short sequences with EOS so every example has the same length
            difference = self.max_len - content_length
            result = tokens + [self.eos_id] * difference
        elif content_length > self.max_len:
            # Truncate long sequences and terminate with a single EOS, landing
            # on the same total length (max_len + extra_length) as padded ones
            result = tokens[: self.max_len + extra_length - 1] + [self.eos_id]
        else:
            result = tokens
        return result

dataset = ReviewDataset(tokenizer, reviews, max_length)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
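# Illustrative sanity check: with the padding scheme above every example is
# max_length + extra_length tokens long, so default collation stacks cleanly
sample_batch = next(iter(dataloader))
print(sample_batch.shape)  # expected: torch.Size([32, max_length + extra_length])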
epochs = 2
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        batch = batch.to(device)
        # Standard causal-LM objective: labels are the inputs themselves,
        # shifted internally by the model for next-token prediction
        output = model(batch, labels=batch)
        loss = output.loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader):.4f}")
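# Persist only the trained adapter weights via peft's save_pretrained;
# the output directory name is an arbitrary choice for illustration
model.save_pretrained("gpt2-tldr-lora")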

import numpy as np
def topk(logits, n=9):
    # Softmax the raw logits to obtain a probability distribution
    probs = torch.softmax(logits, dim=-1)

    # PyTorch has its own topk method, which we use here
    tokensProb, topIx = torch.topk(probs, k=n)

    # The new selection pool (n choices) is renormalized to sum to 1
    tokensProb = tokensProb / torch.sum(tokensProb)

    # Send to CPU for numpy handling
    tokensProb = tokensProb.cpu().detach().numpy()

    # Make a random choice from the pool based on the new prob distribution
    choice = np.random.choice(n, 1, p=tokensProb)
    tokenId = topIx[choice][0]

    return int(tokenId)

def model_infer(model, tokenizer, review, max_length=15):
    # Encode the prompt: the review text ending in the task designator
    review_encoded = tokenizer.encode(review)
    result = review_encoded
    initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        # Feed the prompt to the model
        output = model(initial_input)

        # Take the logits at the final time step
        logits = output.logits[0, -1]

        # Make a top-k choice and append it to the result
        result.append(topk(logits))

        # Generate up to max_length further tokens
        for _ in range(max_length):
            # Feed the current sequence to the model and sample the next token
            input_ids = torch.tensor(result).unsqueeze(0).to(device)
            output = model(input_ids)
            logits = output.logits[0, -1]
            res_id = topk(logits)

            # If the chosen token is EOS, return the decoded result early
            if res_id == tokenizer.eos_token_id:
                return tokenizer.decode(result)
            # Otherwise append it to the sequence
            result.append(res_id)

    # If no EOS was generated, return after max_length tokens
    return tokenizer.decode(result)
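# Example invocation; the review text below is made up for illustration
sample_review = "Great battery life and a sharp screen, but the speakers are weak."
print(model_infer(model, tokenizer, sample_review + " TL;DR "))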


import gradio as gr

def summarize(review):
    summary = model_infer(model, tokenizer, review + " TL;DR ")
    # Split on the bare designator: decoding does not always reproduce the
    # exact surrounding whitespace, so matching " TL;DR " could miss
    return summary.split("TL;DR")[-1].strip()

iface = gr.Interface(fn=summarize, inputs="text", outputs="text")
iface.launch()