kollera commited on
Commit
2b1a240
·
verified ·
1 Parent(s): 82ee6db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -12
app.py CHANGED
@@ -1,22 +1,54 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- # Inizializza il modello DistilBERT per la classificazione di testo
5
- classifier = pipeline("text-classification", model="distilbert-base-uncased")
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def classify_email(text):
 
8
  result = classifier(text)
9
- # 'result' è una lista con un solo dizionario, prendiamo il primo elemento e restituiamo solo la label
10
  label = result[0]['label']
11
  score = result[0]['score']
12
-
13
- # Modifica le etichette per essere più comprensibili
14
- if label == "LABEL_1":
15
- label_text = "Phishing"
16
- else:
17
- label_text = "Non Phishing"
18
-
19
- return {label_text: score}
20
 
21
  # Interfaccia con Gradio
22
  iface = gr.Interface(fn=classify_email,
@@ -25,4 +57,5 @@ iface = gr.Interface(fn=classify_email,
25
  title="ZeroSpam Email Classifier",
26
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
27
 
 
28
  iface.launch(share=True)
 
1
  import gradio as gr
2
+ from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
3
+ from datasets import load_dataset
4
+ import torch
5
 
6
+ # Carica il dataset Enron Spam da Hugging Face
7
+ dataset = load_dataset("enron_spam")
8
 
9
+ # Carica il tokenizer e il modello
10
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
11
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
12
+
13
+ # Tokenizzazione del dataset
14
+ def tokenize_function(examples):
15
+ return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)
16
+
17
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
18
+
19
+ # Suddivisione in training e test set
20
+ train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(8000))
21
+ test_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(2000))
22
+
23
+ # Definizione degli argomenti per l'addestramento
24
+ training_args = TrainingArguments(
25
+ output_dir="./results",
26
+ evaluation_strategy="epoch",
27
+ learning_rate=2e-5,
28
+ per_device_train_batch_size=16,
29
+ per_device_eval_batch_size=16,
30
+ num_train_epochs=3,
31
+ weight_decay=0.01,
32
+ )
33
+
34
+ # Creazione dell'oggetto Trainer
35
+ trainer = Trainer(
36
+ model=model,
37
+ args=training_args,
38
+ train_dataset=train_dataset,
39
+ eval_dataset=test_dataset,
40
+ )
41
+
42
+ # Avvio dell'addestramento
43
+ trainer.train()
44
+
45
+ # Definizione della funzione di classificazione usando Gradio
46
  def classify_email(text):
47
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
48
  result = classifier(text)
 
49
  label = result[0]['label']
50
  score = result[0]['score']
51
+ return {label: score}
 
 
 
 
 
 
 
52
 
53
  # Interfaccia con Gradio
54
  iface = gr.Interface(fn=classify_email,
 
57
  title="ZeroSpam Email Classifier",
58
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
59
 
60
+ # Avvio dell'interfaccia
61
  iface.launch(share=True)