import streamlit as st
import torch
import torch.nn.functional as F
from transformers import XLMRobertaTokenizerFast, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
model.eval()

# Functions

def tokenize_with_indices(text: str):
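    """Tokenize text (with special tokens) and return input ids plus per-token character offsets."""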
    # Truncate defensively in case a joined chunk tokenizes slightly longer than its per-word estimate
    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=True,
                                    truncation=True, max_length=512)
    return encoded['input_ids'], encoded['offset_mapping']

def fetch_and_extract_content(url: str):
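    """Download a page with trafilatura and return its extracted main text, or None on failure."""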
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None

def process_text(inputs: str, confidence_threshold: float):
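    """Split the input into model-sized chunks, run token classification on each chunk,
    and return the text as highlighted HTML along with a per-subtoken results DataFrame."""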
    max_chunk_length = 512 - 2  # leave room for the <s> and </s> special tokens
    words = inputs.split()
    chunk_texts = []
    current_chunk = []
    current_length = 0
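    # Greedily pack whole words into chunks that stay within the model's 512-token input limit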
    for word in words:
        word_token_count = len(tokenizer.tokenize(word))
        if current_length + word_token_count > max_chunk_length and current_chunk:
            chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_token_count
        else:
            current_chunk.append(word)
            current_length += word_token_count
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))

    df_data = {
        'Word': [],
        'Prediction': [],
        'Confidence': [],
        'Start': [],
        'End': []
    }
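    # Highlighted HTML accumulated across chunks; original_position_offset tracks each chunk's start position in it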
    reconstructed_text = ""
    original_position_offset = 0

    for chunk in chunk_texts:
        input_ids, token_offsets = tokenize_with_indices(chunk)

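        # Forward pass: per-token class predictions (argmax) and softmax confidences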
        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_ids_tensor)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze().tolist()

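        # Group subword tokens by the word they belong to, keyed by the word's start offset in the chunk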
        word_info = {}

        for idx, (start, end) in enumerate(token_offsets):
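            # Skip the <s> and </s> special tokens at the start and end of the sequence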
            if idx == 0 or idx == len(token_offsets) - 1:
                continue

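            # Walk back to the preceding space to find the character offset where this word begins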
            word_start = start
            while word_start > 0 and chunk[word_start-1] != ' ':
                word_start -= 1

            if word_start not in word_info:
                word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}

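            # Confidence (as a percentage) of the class predicted for this subtoken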
            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100

            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
                word_info[word_start]['prediction'] = 1

            word_info[word_start]['confidence'] = max(word_info[word_start]['confidence'], confidence_percentage)
            word_info[word_start]['subtokens'].append((start, end, chunk[start:end]))

        # Rebuild the chunk as HTML, wrapping words with a positive prediction in a highlight span
        last_end = 0
        for word_start in sorted(word_info.keys()):
            word_data = word_info[word_start]
            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
                # Escape $ so Streamlit's markdown renderer does not interpret it as LaTeX
                escaped_subtoken_text = subtoken_text.replace('$', '\\$')
                if last_end < subtoken_start:
                    reconstructed_text += chunk[last_end:subtoken_start]
                if word_data['prediction'] == 1:
                    reconstructed_text += f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped_subtoken_text}</span>"
                else:
                    reconstructed_text += escaped_subtoken_text
                last_end = subtoken_end

                df_data['Word'].append(escaped_subtoken_text)
                df_data['Prediction'].append(word_data['prediction'])
                df_data['Confidence'].append(word_data['confidence'])
                df_data['Start'].append(subtoken_start + original_position_offset)
                df_data['End'].append(subtoken_end + original_position_offset)

        # Append any trailing text and advance the offset once per chunk (chunk length plus the joining space)
        reconstructed_text += chunk[last_end:].replace('$', '\\$') + " "
        original_position_offset += len(chunk) + 1

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens

# Streamlit Interface

st.set_page_config(layout="wide")
st.title('SEO by DEJAN: LinkBERT')

confidence_threshold = st.slider('Confidence Threshold', 50, 100, 50)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button('Process Text'):
        if user_input.strip():
            highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.warning("Please enter some text to process.")

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")