File size: 2,695 Bytes
61967b1 6382463 56863df 295eb56 6382463 295eb56 6382463 295eb56 6382463 295eb56 6382463 56863df 6382463 295eb56 2b1a240 295eb56 6382463 295eb56 6382463 295eb56 6382463 295eb56 6382463 2b1a240 295eb56 6382463 2b1a240 295eb56 2b1a240 295eb56 5f7a34d 6382463 5f7a34d 1a235fd 2b1a240 5f7a34d 295eb56 1a235fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
from datasets import load_dataset
import torch
import os
# Carica il dataset spam detection da Hugging Face
dataset = load_dataset("tanquangduong/spam-detection-dataset-splits")
# Carica il tokenizer e il modello
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Tokenizzazione del dataset
def tokenize_function(examples):
return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Suddivisione in training e test set
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2000)) # Ridotto per velocizzare l'addestramento
test_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(500))
# Definizione degli argomenti per l'addestramento, inclusa la frequenza di salvataggio dei checkpoint
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
save_strategy="epoch", # Salva un checkpoint alla fine di ogni epoch
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=1, # Ridotto a 1 epoch per evitare timeout
weight_decay=0.01,
save_total_limit=2, # Limita il numero di checkpoint salvati per risparmiare spazio
load_best_model_at_end=True,
)
# Creazione dell'oggetto Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
)
# Avvio dell'addestramento
if os.path.exists("./results/checkpoint-1"): # Verifica se esiste un checkpoint salvato
print("Riprendi l'addestramento dal checkpoint...")
trainer.train(resume_from_checkpoint="./results/checkpoint-1")
else:
print("Inizia l'addestramento da zero...")
trainer.train()
# Definizione della funzione di classificazione usando Gradio
def classify_email(text):
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
result = classifier(text)
label = result[0]['label']
score = result[0]['score']
return {label: score}
# Interfaccia con Gradio
iface = gr.Interface(fn=classify_email,
inputs="text",
outputs="label",
title="ZeroSpam Email Classifier",
description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
# Avvio dell'interfaccia
iface.launch(share=True)
|