import gradio as gr
import spaces
import torch
import difflib
from threading import Thread
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
# Identifier passed to `from_pretrained` when loading the tokenizer and model.
model_id = "textcleanlm/textcleanlm-1-4b"

# Module-level cache filled in by load_model(); both stay None until then.
model = tokenizer = None
def load_model():
    """Load the model and tokenizer for ``model_id`` once and cache them.

    On the first call the tokenizer is loaded, then each candidate model
    class (seq2seq, causal LM, generic ``AutoModel``) is tried in order until
    one loads successfully. The results are stored in the module-level
    ``model`` and ``tokenizer`` globals; later calls return the cached objects.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If none of the candidate model classes could load
            ``model_id``. The last loading error is chained as ``__cause__``.
    """
    global model, tokenizer
    if model is not None:
        return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Add padding token if needed
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Try different model classes, remembering why the most recent one failed
    # so the final error is diagnosable.
    last_error = None
    for model_class in (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel):
        try:
            model = model_class.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        except Exception as exc:
            # Catch Exception only: a bare ``except`` would also swallow
            # KeyboardInterrupt/SystemExit during a long model load.
            last_error = exc
            continue
        break

    if model is None:
        raise ValueError(f"Could not load model {model_id}") from last_error

    return model, tokenizer
def create_diff_html(original, cleaned):
    """Create HTML diff visualization"""
    original_lines = original.splitlines(keepends=True)
    cleaned_lines = cleaned.splitlines(keepends=True)
    
    differ = difflib.unified_diff(original_lines, cleaned_lines, fromfile='Original', tofile='Cleaned', lineterm='')
    
    html_diff = '
'
    
    for line in differ:
        if line.startswith('+++') or line.startswith('---'):
            html_diff += f'
{line}
'
        elif line.startswith('@@'):
            html_diff += f'
{line}
'
        elif line.startswith('+'):
            html_diff += f'
{line}
'
        elif line.startswith('-'):
            html_diff += f'
{line}
'
        else:
            html_diff += f'
{line}
'
    
    html_diff += '