Spaces:
Sleeping
Sleeping
File size: 3,410 Bytes
dd07c8d f56d461 dd07c8d e8b6f39 dd07c8d e8b6f39 dd07c8d e8b6f39 dd07c8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
from email import message_from_file
from email.errors import HeaderParseError
from email.header import decode_header
from textwrap import dedent

import gradio as gr
import torch
from datasets import load_dataset
from transformers import pipeline
# Run inference on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the Hugging Face pipeline once at import time so every request reuses
# the same loaded model. No task is given, so transformers infers it from the
# Hub model. NOTE(review): label names/semantics come from the model's config
# on the Hub - confirm there before relying on them.
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)
# fn to predict from text
def classify_raw(text):
return pipe(text, top_k=2)
# helper to render form inputs as the header text fed to the model
def build_header_text(mailfrom, x_mailfrom, to, reply_to, subject):
    """Render the five header fields as newline-separated ``Name: value`` lines.

    Lines are emitted in the order From, X-MailFrom, To, Reply-To, Subject,
    and the joined text is stripped of leading/trailing whitespace.

    Each value is flattened onto a single line (its lines are joined with
    spaces), so a value containing line breaks cannot add extra lines or
    disturb the layout of the other fields. ``None`` is rendered as an empty
    field rather than the text ``"None"``.

    Returns:
        The header text as a single string.
    """
    fields = (
        ("From", mailfrom),
        ("X-MailFrom", x_mailfrom),
        ("To", to),
        ("Reply-To", reply_to),
        ("Subject", subject),
    )
    lines = []
    for name, value in fields:
        value = "" if value is None else str(value)
        flat = " ".join(value.splitlines())
        lines.append(f"{name}: {flat}")
    return "\n".join(lines).strip()


# fn to predict from form inputs
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    """Classify an email described by its individual header fields.

    Args:
        mailfrom, x_mailfrom, to, reply_to, subject: header values, in the
            same order as the form's textboxes.

    Returns:
        ``classify_raw`` applied to the text produced by ``build_header_text``.
    """
    # The text is built line by line rather than by dedent()-ing an
    # interpolated, indented template. Interpolation runs before dedent(), so
    # a value containing a newline would defeat dedent() and leave the header
    # lines indented.
    text = build_header_text(mailfrom, x_mailfrom, to, reply_to, subject)
    return classify_raw(text)
# helper to extract header from email
def get_header(message, header_name: str) -> str:
    """Return one header of ``message`` decoded to a single-line string.

    Every chunk produced by ``email.header.decode_header`` is decoded with the
    charset it declares, and all chunks are concatenated. For example, a value
    of ``=?utf-8?q?Caf=C3=A9?= au lait`` becomes ``"Café au lait"``. Line
    breaks (e.g. from folded headers) are replaced by spaces and the result is
    stripped.

    Args:
        message: a parsed email message, indexed by header name.
        header_name: the header to read, e.g. ``"Subject"``.

    Returns:
        The decoded header text, or ``""`` if the header is absent or its
        encoded words cannot be parsed.
    """
    raw = message[header_name]
    if raw is None:
        # email.message.Message returns None for a missing header.
        return ""
    try:
        chunks = decode_header(raw)
    except HeaderParseError:
        # e.g. malformed base64 inside an encoded word: stay best-effort.
        return ""

    parts = []
    for payload, charset in chunks:
        if isinstance(payload, bytes):
            # Unencoded runs come back with charset None; decode_header()
            # converted them to bytes via 'raw-unicode-escape', so undo that.
            codec = charset or "raw-unicode-escape"
            try:
                payload = payload.decode(codec, errors="ignore")
            except LookupError:
                # Unknown charset label (e.g. 'unknown-8bit'): assume UTF-8.
                payload = payload.decode("utf-8", errors="ignore")
        parts.append(payload)

    return " ".join("".join(parts).splitlines()).strip()
# fn to predict from email file
def classify_file(file):
    """Classify an uploaded email file from its headers.

    Args:
        file: the upload coming from the ``gr.File`` input. This is either an
            object exposing the on-disk path as ``.name``, or the path itself
            as a ``str`` / ``os.PathLike``.

    Returns:
        ``classify_form`` applied to the message's From, X-MailFrom, To,
        Reply-To and Subject headers (each read via ``get_header``).
    """
    # NOTE(review): which shape arrives presumably depends on the Gradio
    # version / gr.File(type=...); both are accepted here.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    # Use a context manager so the handle is closed once parsing is done.
    # Decode explicitly as UTF-8 instead of the platform default. Undecodable
    # bytes (e.g. an 8-bit body in another charset) are replaced rather than
    # aborting the whole parse.
    with open(path, encoding="utf-8", errors="replace") as fp:
        message = message_from_file(fp)
    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )
# App title and a short description text.
title = "Email Spam Classifier"
# NOTE(review): `description` is not referenced anywhere in this file as shown.
description = """
Spam or ham ?
"""
# Outer Blocks container; the tabbed interfaces are mounted into it below.
# Pass the app title so the page is not labelled with Gradio's default title.
demo = gr.Blocks(title=title)
# "Raw Text" tab: a multi-line textbox of pre-formatted header lines, classified
# by classify_raw and exposed under the API name "predict_raw_text".
# NOTE(review): the "[email protected]" placeholders look like e-mail
# obfuscation residue from a web scrape; the original addresses are unknown.
raw_interface = gr.Interface(
    fn=classify_raw,
    inputs=gr.Textbox(
        label="Formatted Email Header",
        lines=5,
        placeholder="\n".join(
            (
                "From: Laurent Fainsin <[email protected]>",
                "X-MailFrom: Laurent Fainsin <[email protected]>",
                "To: net7 <[email protected]>",
                "Reply-To: Laurent Fainsin <[email protected]>",
                "Subject: Re: Demande d'un H24 net7",
            )
        ),
    ),
    outputs="json",
    api_name="predict_raw_text",
)
# "Form" tab: one textbox per header field. The textboxes are listed in the
# same order as classify_form's parameters, because Gradio maps inputs to
# arguments by position.
form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(label=label, placeholder=placeholder)
        for label, placeholder in (
            ("From", "Laurent Fainsin <[email protected]>"),
            ("X-MailFrom", "Laurent Fainsin <[email protected]>"),
            ("To", "net7 <[email protected]>"),
            ("Reply-To", "Laurent Fainsin <[email protected]>"),
            ("Subject", "Re: Demande d'un H24 net7"),
        )
    ],
    outputs="json",
    api_name="predict_form",
)
# "File" tab: upload a whole .eml message; classify_file reads its headers.
file_interface = gr.Interface(
    fn=classify_file,
    inputs=gr.File(file_types=[".eml"], label="Email File"),
    outputs="json",
    api_name="predict_file",
)
# Mount the three interfaces as tabs inside the outer Blocks app.
with demo:
    gr.TabbedInterface(
        interface_list=[raw_interface, form_interface, file_interface],
        tab_names=["Raw Text", "Form", "File"],
    )

# Start the server only when this file is run as a script. Importing the
# module (e.g. to reuse the classify_* functions) then does not block on
# launch().
if __name__ == "__main__":
    demo.launch()
|