import spaces
import gradio as gr
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from datasets import load_dataset, concatenate_datasets, load_from_disk
import traceback
from sklearn.metrics import accuracy_score
import numpy as np

import os
from huggingface_hub import login
from peft import get_peft_model, LoraConfig

# NOTE(review): transformers and huggingface_hub are imported above this line and
# probably read HF_HOME when first imported, so this assignment may not relocate
# their cache. Move it above the imports if that is the intent -- TODO confirm.
os.environ["HF_HOME"] = "/data/.huggingface"

# LoRA adapter settings. Currently not applied to the model: PEFT wrapping and
# gradient checkpointing are disabled in this script.
lora_config = LoraConfig(
    r=16,              # rank of the low-rank adaptation
    lora_alpha=32,     # scaling factor for the LoRA update
    lora_dropout=0.1,  # dropout applied in the LoRA layers
    bias="none",       # bias handling: no bias terms are adapted
)

# Base model loaded at import time.
# NOTE(review): the global `model` is reassigned to 'google/t5-efficient-tiny-nh8'
# later in this file, before the UI launches, so this download looks unused.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5-efficient-tiny",
    num_labels=2,
    force_download=True,
)

def _logits_to_token_ids(logits, labels):
    """Reduce one eval batch of logits to predicted token ids.

    Passed as ``Trainer(preprocess_logits_for_metrics=...)``. The Trainer keeps
    whatever this returns for the whole eval set, so taking the argmax here
    avoids accumulating the full logits tensor.

    Args:
        logits: Model outputs for one eval batch. For seq2seq models this may be
            a tuple whose first element is the LM logits.
        labels: Unused; required by the ``preprocess_logits_for_metrics`` signature.

    Returns:
        The argmax of the logits over their last axis.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def _compute_metrics(eval_pred):
    """Compute token-level accuracy for seq2seq evaluation.

    Args:
        eval_pred: ``EvalPrediction`` or a ``(predictions, labels)`` pair.
            Predictions may be a tuple whose first element is the real
            prediction array. They may be logits (one more axis than labels,
            reduced with argmax here) or already-reduced token ids.

    Returns:
        ``{'eval_accuracy': float}``: the fraction of label positions whose label
        is not -100 that were predicted exactly. The value is 0.0 when there are
        no such positions. ``eval_loss`` is not returned because the Trainer
        reports it itself.
    """
    predictions, labels = eval_pred[0], eval_pred[1]
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    if predictions.ndim == labels.ndim + 1:
        predictions = predictions.argmax(axis=-1)
    # -100 marks ignored positions: DataCollatorForSeq2Seq pads labels with it,
    # and the Trainer pads gathered eval arrays with it.
    mask = labels != -100
    if not mask.any():
        return {'eval_accuracy': 0.0}
    return {'eval_accuracy': float(np.mean(predictions[mask] == labels[mask]))}


@spaces.GPU(duration=120)
def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
    """Fine-tune a seq2seq model on a Hub dataset and push the result to the Hub.

    The dataset must have ``train`` and ``test`` splits with a ``text`` (input)
    column and a ``target`` (expected output) column. Tokenized splits are
    cached under ``/data/<hub_id>_train_dataset`` and
    ``/data/<hub_id>_test_dataset`` and reused on later calls.

    If ``/data/results`` contains ``checkpoint-*`` folders, training resumes
    from the newest one. Otherwise, if a saved model (``config.json``) sits
    directly in that directory, it replaces ``model``.

    Args:
        model: Seq2seq model to train.
        dataset_name: Dataset id passed to ``datasets.load_dataset``.
        hub_id: Hub repo id to push to; also names the tokenized-data cache.
        api_key: Hugging Face token used for ``login``.
        num_epochs: Number of training epochs (cast to int).
        batch_size: Per-device train and eval batch size (cast to int).
        lr: Learning rate in units of 1e-5.
        grad: Gradient-accumulation steps. Currently unused because
            accumulation is disabled.

    Returns:
        ``'DONE!'`` on success, otherwise a string with the error and traceback.
    """
    try:
        login(api_key.strip())
        hub_id = hub_id.strip()
        max_length = 128

        training_args = TrainingArguments(
            output_dir='/data/results',
            eval_strategy="steps",
            # Left unset, eval_steps falls back to logging_steps (10), i.e. a full
            # evaluation every 10 steps. Evaluate whenever we checkpoint instead.
            eval_steps=100,
            save_strategy='steps',
            learning_rate=lr * 0.00001,  # the UI value is in units of 1e-5
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            num_train_epochs=int(num_epochs),
            weight_decay=0.01,
            # gradient_accumulation_steps=int(grad) stays disabled, so `grad` is unused.
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",  # looked up as 'eval_accuracy'
            greater_is_better=True,
            logging_dir='/data/logs',
            logging_steps=10,
            hub_model_id=hub_id,
            fp16=True,
            save_steps=100,  # save a checkpoint every 100 steps
            save_total_limit=3,
        )

        # Resume from the newest checkpoint-* folder, which restores weights,
        # optimizer, scheduler and step count. from_pretrained() cannot read the
        # output dir when it only contains checkpoint sub-folders, so fall back to
        # loading it only when a saved model sits at its top level.
        from transformers.trainer_utils import get_last_checkpoint

        output_dir = training_args.output_dir
        last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
        if last_checkpoint is not None:
            print(f"Resuming training from {last_checkpoint}...")
        elif os.path.isfile(os.path.join(output_dir, 'config.json')):
            print("Loading model from checkpoint...")
            model = AutoModelForSeq2SeqLM.from_pretrained(output_dir)

        tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
        train_cache = f'/data/{hub_id}_train_dataset'
        test_cache = f'/data/{hub_id}_test_dataset'
        try:
            train_dataset = load_from_disk(train_cache)
            eval_dataset = load_from_disk(test_cache)
        except FileNotFoundError:
            # No cached copy yet: download, tokenize and cache both splits.
            dataset = load_dataset(dataset_name.strip())

            def tokenize_function(examples):
                # No padding here: DataCollatorForSeq2Seq pads each training batch.
                model_inputs = tokenizer(
                    examples['text'],
                    max_length=max_length,
                    truncation=True,
                )
                labels = tokenizer(
                    text_target=examples['target'],
                    max_length=max_length,
                    truncation=True,
                )
                model_inputs["labels"] = labels["input_ids"]
                return model_inputs

            tokenized = dataset.map(tokenize_function, batched=True)
            train_dataset = tokenized['train']
            eval_dataset = tokenized['test']
            train_dataset.save_to_disk(train_cache)
            eval_dataset.save_to_disk(test_cache)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # Pads inputs per batch and pads labels with -100, so padding is
            # ignored by the loss. The default collator cannot stack rows of
            # different lengths.
            data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
            compute_metrics=_compute_metrics,
            preprocess_logits_for_metrics=_logits_to_token_ids,
        )

        trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.push_to_hub(commit_message="Training complete!")
    except Exception as e:
        # Report failures to the Gradio UI as text instead of raising.
        return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
    return 'DONE!'
# Disabled inference helper, kept as a bare string literal so it never executes.
# NOTE(review): it would fail if re-enabled as written. `model_name` is not defined
# anywhere in the visible source. `model(inputs)` passes the tokenizer output
# positionally; a seq2seq model likely needs `model.generate(**inputs)` plus
# decoding instead of `.logits.argmax` -- confirm before restoring.
'''
# Define Gradio interface
def predict(text):
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name.strip(), num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(inputs)
    predictions = outputs.logits.argmax(dim=-1)
    return predictions.item()
'''

def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
    """Gradio callback: fine-tune the module-level ``model`` with the UI inputs.

    The global ``model`` is looked up at call time. All other arguments are
    forwarded unchanged to ``fine_tune_model``, and its status string is
    returned for display.
    """
    return fine_tune_model(
        model,
        dataset_name,
        hub_id,
        api_key,
        num_epochs,
        batch_size,
        lr,
        grad,
    )
# Build and launch the training UI. Any startup failure is printed with its
# traceback instead of propagating.
try:
    # Rebinds the global `model` that run_train() reads when a job is submitted.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        'google/t5-efficient-tiny-nh8',
        num_labels=2,
        force_download=True,
    )

    # One component per run_train() parameter after `model`, in call order.
    _train_inputs = [
        gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
        gr.Textbox(label="HF hub to push to after training"),
        gr.Textbox(label="HF API token"),
        gr.Slider(minimum=1, maximum=10, value=3, label="Number of Epochs", step=1),
        gr.Slider(minimum=1, maximum=2000, value=1, label="Batch Size", step=1),
        gr.Slider(minimum=1, maximum=1000, value=1, label="Learning Rate (e-5)", step=1),
        gr.Slider(minimum=1, maximum=100, value=1, label="Gradient accumulation", step=1),
    ]
    iface = gr.Interface(
        fn=run_train,
        inputs=_train_inputs,
        outputs="text",
        title="Fine-Tune Hugging Face Model",
        description="This interface allows you to fine-tune a Hugging Face model on a specified dataset.",
    )
    # Alternative inference-only interface, disabled (kept as a string literal).
    '''
    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Query"),
        ],
        outputs="text",
        title="Fine-Tune Hugging Face Model",
        description="This interface allows you to test a fine-tune Hugging Face model."
    )
    '''
    iface.launch()
except Exception as e:
    print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")