kollera commited on
Commit
6382463
·
verified ·
1 Parent(s): 61967b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -8
app.py CHANGED
@@ -1,17 +1,49 @@
1
- from datasets import load_dataset
2
- from transformers import pipeline
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Carica il dataset di esempio da Hugging Face (utilizziamo il dataset TREC) con `trust_remote_code=True`
6
- dataset = load_dataset("trec", trust_remote_code=True)
 
 
 
 
 
 
 
 
7
 
8
- # Visualizza i primi esempi per assicurarti che il dataset sia stato caricato correttamente
9
- print(dataset['train'][0])
 
 
 
 
 
10
 
11
- # Inizializza il modello DistilBERT per la classificazione di testo
12
- classifier = pipeline("text-classification", model="distilbert-base-uncased")
13
 
 
14
  def classify_email(text):
 
15
  result = classifier(text)
16
  label = result[0]['label']
17
  score = result[0]['score']
@@ -24,4 +56,5 @@ iface = gr.Interface(fn=classify_email,
24
  title="ZeroSpam Email Classifier",
25
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
26
 
 
27
  iface.launch(share=True)
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, Trainer, TrainingArguments, DistilBertForSequenceClassification, DistilBertTokenizer
3
+ from datasets import load_dataset
4
+
5
+ # Carica il dataset di spam detection da Hugging Face
6
+ dataset = load_dataset("tanquangduong/spam-detection-dataset-splits")
7
+
8
+ # Carica il tokenizer e il modello pre-addestrato
9
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
10
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
11
+
12
+ # Tokenizza il dataset
13
+ def tokenize_function(examples):
14
+ return tokenizer(examples['message'], truncation=True, padding="max_length", max_length=128)
15
+
16
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
17
+
18
+ # Suddividi in training e test set
19
+ train_dataset = tokenized_datasets["train"]
20
+ test_dataset = tokenized_datasets["test"]
21
 
22
+ # Definisci gli argomenti per l'addestramento
23
+ training_args = TrainingArguments(
24
+ output_dir="./results",
25
+ evaluation_strategy="epoch",
26
+ learning_rate=2e-5,
27
+ per_device_train_batch_size=16,
28
+ per_device_eval_batch_size=16,
29
+ num_train_epochs=3,
30
+ weight_decay=0.01,
31
+ )
32
 
33
+ # Crea l'oggetto Trainer
34
+ trainer = Trainer(
35
+ model=model,
36
+ args=training_args,
37
+ train_dataset=train_dataset,
38
+ eval_dataset=test_dataset,
39
+ )
40
 
41
+ # Avvia il training
42
+ trainer.train()
43
 
44
+ # Definisci la funzione di classificazione usando il modello addestrato
45
  def classify_email(text):
46
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, framework="pt")
47
  result = classifier(text)
48
  label = result[0]['label']
49
  score = result[0]['score']
 
56
  title="ZeroSpam Email Classifier",
57
  description="Inserisci l'email da analizzare per determinare se è spam o phishing.")
58
 
59
+ # Avvia l'interfaccia
60
  iface.launch(share=True)