import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, EncoderDecoderCache
from datasets import load_dataset
from huggingface_hub import login


# Authenticate with the Hugging Face Hub; the token must be supplied via the
# HF_TOKEN environment variable (an empty string is deliberately accepted —
# only a missing variable aborts).
if (hf_token := os.environ.get("HF_TOKEN")) is None:
    raise ValueError("Il token HF_TOKEN non è impostato nelle variabili d'ambiente")
login(hf_token)



# Load the training dataset from a local JSONL file (one JSON record per
# line; each record is expected to carry a "text" field, which is what
# preprocess_data tokenizes).
dataset = load_dataset("json", data_files="data.jsonl")

# Model and tokenizer to fine-tune.
#model_name = "meta-llama/Llama-2-7b-hf"
model_name = "slarkprime/Llama-2-7b-QLoRA"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move the model to GPU when one is available; `device` is reused below for
# placing inference inputs in answer_question.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Tokenizza il dataset
def preprocess_data(example):
    """Tokenize a batch of raw text records for causal-LM fine-tuning.

    Called via ``dataset.map(..., batched=True)``, so ``example["text"]`` is a
    list of strings and every returned field is a list of per-example lists.

    Fix: the original produced variable-length, unpadded sequences; the
    Trainer's default collator stacks examples into a tensor and crashes on
    ragged batches (batch size is 4 here). Sequences are now padded to
    ``max_length`` and padded positions get the label -100 so the loss
    ignores them.
    """
    # Llama tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    # Labels mirror input_ids, but padding (attention_mask == 0) is masked
    # with -100, the ignore index of the cross-entropy loss.
    inputs["labels"] = [
        [tok if keep == 1 else -100 for tok, keep in zip(seq, mask)]
        for seq, mask in zip(inputs["input_ids"], inputs["attention_mask"])
    ]
    return inputs


# Tokenize every split of the dataset in batches.
tokenized_dataset = dataset.map(preprocess_data, batched=True)

# Training hyperparameters, gathered in one place and unpacked below.
_hyperparams = {
    "output_dir": "./results",
    "eval_strategy": "no",
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 4,
    "num_train_epochs": 3,
    "weight_decay": 0.01,
    # Use mixed precision only when a GPU is available.
    "fp16": torch.cuda.is_available(),
}
training_args = TrainingArguments(**_hyperparams)

# Inizializza il Trainer
def start_training():
    """Run one full fine-tuning pass and upload the result to the Hub.

    Returns a short status string shown in the Gradio output box.
    """
    hf_trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
    )
    hf_trainer.train()
    # NOTE(review): the first positional argument of Trainer.push_to_hub is
    # the *commit message*, not a repository name — confirm that
    # "to_validate_model" is intended as the commit message.
    hf_trainer.push_to_hub("to_validate_model")
    return "Training completato e caricato"

def answer_question(question):
    """Generate an answer to *question* with the loaded causal LM.

    Fix: the original passed ``max_length=50`` to ``generate`` while allowing
    prompts up to 128 tokens — ``max_length`` caps the *total* sequence
    (prompt + continuation), so any prompt longer than 50 tokens left no room
    to generate. ``max_new_tokens=50`` caps only the continuation. The inputs
    are also unpacked with ``**`` so the attention mask is forwarded to
    ``generate`` alongside the input ids.
    """
    inputs = tokenizer(
        question, return_tensors="pt", truncation=True, max_length=128
    ).to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Gradio UI: one tab for Q&A inference, one to launch fine-tuning.
iface = gr.Interface(fn=answer_question, inputs="text", outputs="text")
train_interface = gr.Interface(fn=start_training, inputs=[], outputs="text")
app = gr.TabbedInterface([iface, train_interface], ["Q&A", "Avvia Training"])


# BUG FIX: launch the TabbedInterface (`app`), not `iface` alone — the
# original called iface.launch(), so the "Avvia Training" tab was never
# served and `app` was dead code.
app.launch()