kollera commited on
Commit
295eb56
·
verified ·
1 Parent(s): 870b508

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -2,39 +2,40 @@ import gradio as gr
2
  from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
3
  from datasets import load_dataset
4
  import torch
 
5
 
6
- # Carica il dataset di spam detection da Hugging Face
7
  dataset = load_dataset("tanquangduong/spam-detection-dataset-splits")
8
 
9
- # Visualizza i nomi delle colonne per verificare quale contiene il testo delle email
10
- print(dataset['train'].column_names)
11
-
12
- # Carica il tokenizer e il modello pre-addestrato
13
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
14
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
15
 
16
- # Aggiorna il nome della colonna con il nome corretto
17
  def tokenize_function(examples):
18
  return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)
19
 
20
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
21
 
22
- # Suddividi in training e test set
23
- train_dataset = tokenized_datasets["train"]
24
- test_dataset = tokenized_datasets["test"]
25
 
26
- # Definisci gli argomenti per l'addestramento
27
  training_args = TrainingArguments(
28
  output_dir="./results",
29
  evaluation_strategy="epoch",
 
30
  learning_rate=2e-5,
31
  per_device_train_batch_size=16,
32
  per_device_eval_batch_size=16,
33
- num_train_epochs=3,
34
  weight_decay=0.01,
 
 
35
  )
36
 
37
- # Crea l'oggetto Trainer
38
  trainer = Trainer(
39
  model=model,
40
  args=training_args,
@@ -42,10 +43,15 @@ trainer = Trainer(
42
  eval_dataset=test_dataset,
43
  )
44
 
45
- # Avvia il training
46
- trainer.train()
 
 
 
 
 
47
 
48
- # Definisci la funzione di classificazione usando il modello addestrato
49
  def classify_email(text):
50
  classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
51
  result = classifier(text)
@@ -60,5 +66,5 @@ iface = gr.Interface(fn=classify_email,
60
  title="ZeroSpam Email Classifier",
61
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
62
 
63
- # Avvia l'interfaccia
64
  iface.launch(share=True)
 
2
  from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
3
  from datasets import load_dataset
4
  import torch
5
+ import os
6
 
7
+ # Carica il dataset spam detection da Hugging Face
8
  dataset = load_dataset("tanquangduong/spam-detection-dataset-splits")
9
 
10
+ # Carica il tokenizer e il modello
 
 
 
11
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
12
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
13
 
14
+ # Tokenizzazione del dataset
15
  def tokenize_function(examples):
16
  return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)
17
 
18
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
19
 
20
+ # Suddivisione in training e test set
21
+ train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2000)) # Ridotto per velocizzare l'addestramento
22
+ test_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(500))
23
 
24
+ # Definizione degli argomenti per l'addestramento, inclusa la frequenza di salvataggio dei checkpoint
25
  training_args = TrainingArguments(
26
  output_dir="./results",
27
  evaluation_strategy="epoch",
28
+ save_strategy="epoch", # Salva un checkpoint alla fine di ogni epoch
29
  learning_rate=2e-5,
30
  per_device_train_batch_size=16,
31
  per_device_eval_batch_size=16,
32
+ num_train_epochs=1, # Ridotto a 1 epoch per evitare timeout
33
  weight_decay=0.01,
34
+ save_total_limit=2, # Limita il numero di checkpoint salvati per risparmiare spazio
35
+ load_best_model_at_end=True,
36
  )
37
 
38
+ # Creazione dell'oggetto Trainer
39
  trainer = Trainer(
40
  model=model,
41
  args=training_args,
 
43
  eval_dataset=test_dataset,
44
  )
45
 
46
+ # Avvio dell'addestramento
47
+ if os.path.exists("./results/checkpoint-1"): # Verifica se esiste un checkpoint salvato
48
+ print("Riprendi l'addestramento dal checkpoint...")
49
+ trainer.train(resume_from_checkpoint="./results/checkpoint-1")
50
+ else:
51
+ print("Inizia l'addestramento da zero...")
52
+ trainer.train()
53
 
54
+ # Definizione della funzione di classificazione usando Gradio
55
  def classify_email(text):
56
  classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
57
  result = classifier(text)
 
66
  title="ZeroSpam Email Classifier",
67
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
68
 
69
+ # Avvio dell'interfaccia
70
  iface.launch(share=True)