File size: 2,695 Bytes
61967b1
6382463
 
56863df
295eb56
6382463
295eb56
6382463
 
295eb56
6382463
 
 
295eb56
6382463
56863df
6382463
 
 
295eb56
 
 
2b1a240
295eb56
6382463
 
 
295eb56
6382463
 
 
295eb56
6382463
295eb56
 
6382463
2b1a240
295eb56
6382463
 
 
 
 
 
2b1a240
295eb56
 
 
 
 
 
 
2b1a240
295eb56
5f7a34d
6382463
5f7a34d
1a235fd
 
2b1a240
5f7a34d
 
 
 
 
 
 
 
295eb56
1a235fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
from datasets import load_dataset
import torch
import os

# Carica il dataset spam detection da Hugging Face
dataset = load_dataset("tanquangduong/spam-detection-dataset-splits")

# Carica il tokenizer e il modello
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Tokenizzazione del dataset
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Suddivisione in training e test set
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2000))  # Ridotto per velocizzare l'addestramento
test_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(500))

# Definizione degli argomenti per l'addestramento, inclusa la frequenza di salvataggio dei checkpoint
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",  # Salva un checkpoint alla fine di ogni epoch
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,  # Ridotto a 1 epoch per evitare timeout
    weight_decay=0.01,
    save_total_limit=2,  # Limita il numero di checkpoint salvati per risparmiare spazio
    load_best_model_at_end=True,
)

# Creazione dell'oggetto Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Avvio dell'addestramento
if os.path.exists("./results/checkpoint-1"):  # Verifica se esiste un checkpoint salvato
    print("Riprendi l'addestramento dal checkpoint...")
    trainer.train(resume_from_checkpoint="./results/checkpoint-1")
else:
    print("Inizia l'addestramento da zero...")
    trainer.train()

# Definizione della funzione di classificazione usando Gradio
def classify_email(text):
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
    result = classifier(text)
    label = result[0]['label']
    score = result[0]['score']
    return {label: score}

# Interfaccia con Gradio
iface = gr.Interface(fn=classify_email,
                     inputs="text",
                     outputs="label",
                     title="ZeroSpam Email Classifier",
                     description="Inserisci l'email da analizzare per determinare se è spam o phishing.")

# Avvio dell'interfaccia
iface.launch(share=True)