File size: 3,410 Bytes
dd07c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f56d461
dd07c8d
 
 
 
 
 
 
 
 
e8b6f39
dd07c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8b6f39
dd07c8d
 
 
 
 
 
 
 
 
e8b6f39
dd07c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import gradio as gr
from datasets import load_dataset
from transformers import pipeline
from textwrap import dedent
from email import message_from_file
from email.header import decode_header

# Run inference on the first CUDA device when PyTorch can see one, else on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the fine-tuned multilingual DistilBERT email-spam model once, at module
# import; classify_raw and classify_form both score text through this object.
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)

# fn to predict from text
def classify_raw(text):
    """Run the spam classifier on already-formatted header text.

    Args:
        text: Email header text, e.g. one ``Name: value`` header per line as
            in the raw-text tab's placeholder.

    Returns:
        The pipeline's predictions for the two best labels (``top_k=2``).
    """
    top_two = pipe(text, top_k=2)
    return top_two

# fn to predict from form inputs
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    """Classify an email from its individual header fields.

    The fields are rendered as ``Name: value`` lines in a fixed order
    (From, X-MailFrom, To, Reply-To, Subject) and scored with the same
    classifier call as the raw-text tab.

    Args:
        mailfrom: Value of the ``From`` header.
        x_mailfrom: Value of the ``X-MailFrom`` header.
        to: Value of the ``To`` header.
        reply_to: Value of the ``Reply-To`` header.
        subject: Value of the ``Subject`` header.

    Returns:
        The classifier's predictions, as returned by ``classify_raw``.
    """
    # Build the lines explicitly instead of interpolating into an indented
    # f-string and calling ``dedent`` afterwards: a value containing a newline
    # would defeat dedent's common-prefix detection and leak the source
    # indentation into the text the model sees.
    lines = (
        f"From: {mailfrom}",
        f"X-MailFrom: {x_mailfrom}",
        f"To: {to}",
        f"Reply-To: {reply_to}",
        f"Subject: {subject}",
    )
    text = "\n".join(lines).strip()
    # Reuse the raw-text entry point so both tabs share the same pipeline call.
    return classify_raw(text)

# helper to extract header from email
def get_header(message, header_name: str) -> str:
    """Return the decoded, single-line value of a header, or ``""`` if absent.

    RFC 2047 encoded words (``=?charset?q?...?=``) are decoded with their
    declared charset, and *all* decoded chunks are concatenated, so a header
    mixing encoded and plain text is returned in full. Line breaks (e.g. from
    header folding) are replaced by spaces and the result is stripped.

    Args:
        message: A parsed email message (from the stdlib ``email`` parser).
        header_name: Name of the header to read, e.g. ``"Subject"``.

    Returns:
        The decoded header text, or ``""`` when the header is missing.
    """
    from email.errors import HeaderParseError

    raw = message.get(header_name)
    if raw is None:
        return ""

    try:
        chunks = decode_header(raw)
    except (HeaderParseError, ValueError):
        # Malformed encoded word (e.g. bad base64): keep the undecoded text
        # rather than dropping the header entirely.
        chunks = [(str(raw), None)]

    parts = []
    for payload, charset in chunks:
        if isinstance(payload, bytes):
            try:
                payload = payload.decode(charset or "utf-8", errors="ignore")
            except LookupError:
                # Unknown or bogus charset label (including "unknown-8bit"):
                # fall back to UTF-8, dropping undecodable bytes.
                payload = payload.decode("utf-8", errors="ignore")
        parts.append(payload)

    header = "".join(parts)
    return " ".join(header.splitlines()).strip()

# fn to predict from email file
def classify_file(file):
    """Classify an uploaded ``.eml`` file from its headers.

    Args:
        file: The value handed over by ``gr.File``: either a path string or a
            tempfile-like object exposing the path as ``.name`` (which one
            depends on the Gradio version; both are accepted).

    Returns:
        The classifier's predictions, as returned by ``classify_form``.

    Raises:
        gr.Error: If no file was provided.
    """
    import os
    from email import message_from_binary_file

    if file is None:
        raise gr.Error("Please upload an .eml file.")

    path = file if isinstance(file, (str, os.PathLike)) else file.name

    # Parse in binary mode so the email parser, not the platform's default
    # text encoding, deals with any raw 8-bit header bytes; ``with`` closes
    # the handle as soon as parsing is done.
    with open(path, "rb") as fh:
        message = message_from_binary_file(fh)

    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )
    

title = "Email Spam Classifier"
description = """
Spam or ham ?
"""

# Tab 1: paste the header block as plain text.
raw_interface = gr.Interface(
    fn=classify_raw,
    inputs=gr.Textbox(
        label="Formatted Email Header",
        lines=5,
        placeholder=dedent("""
          From: Laurent Fainsin <[email protected]>
          X-MailFrom: Laurent Fainsin <[email protected]>
          To: net7 <[email protected]>
          Reply-To: Laurent Fainsin <[email protected]>
          Subject: Re: Demande d'un H24 net7
        """).strip(),
    ),
    outputs="json",
    description=description,
    api_name="predict_raw_text",
)

# Tab 2: one textbox per header, in classify_form's parameter order.
form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(
            label="From",
            placeholder="Laurent Fainsin <[email protected]>",
        ),
        gr.Textbox(
            label="X-MailFrom",
            placeholder="Laurent Fainsin <[email protected]>",
        ),
        gr.Textbox(
            label="To",
            placeholder="net7 <[email protected]>",
        ),
        gr.Textbox(
            label="Reply-To",
            placeholder="Laurent Fainsin <[email protected]>",
        ),
        gr.Textbox(
            label="Subject",
            placeholder="Re: Demande d'un H24 net7",
        ),
    ],
    outputs="json",
    description=description,
    api_name="predict_form",
)

# Tab 3: upload a raw .eml file; headers are extracted server-side.
file_interface = gr.Interface(
    fn=classify_file,
    inputs=gr.File(
        label="Email File",
        file_types=[".eml"],
    ),
    outputs="json",
    description=description,
    api_name="predict_file",
)

# TabbedInterface is itself a Blocks app, so it serves directly as the
# top-level demo; ``title`` is displayed as the app's heading.
demo = gr.TabbedInterface(
    interface_list=[
        raw_interface,
        form_interface,
        file_interface,
    ],
    tab_names=[
        "Raw Text",
        "Form",
        "File",
    ],
    title=title,
)

# Only start the server when executed as a script, so the module can be
# imported (e.g. by tests or tooling) without launching the app.
if __name__ == "__main__":
    demo.launch()