Nikhil Singh commited on
Commit
9fe2871
·
1 Parent(s): 767cd38

previous working version

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
 
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  from mailparser import parse_from_file
5
  from bs4 import BeautifulSoup
6
  from gliner import GLiNER
@@ -12,9 +11,6 @@ import os
12
  import en_core_web_sm
13
  nlp = en_core_web_sm.load()
14
 
15
- t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
16
- t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
17
-
18
  _MODEL = {}
19
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
20
 
@@ -42,49 +38,57 @@ def get_sentences(further_cleaned_text):
42
  def get_model(model_name: str = None, multilingual: bool = False):
43
  if model_name is None:
44
  model_name = "urchade/gliner_base" if not multilingual else "urchade/gliner_multilingual"
45
- if model_name not in _MODEL:
 
 
 
46
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
 
47
  return _MODEL[model_name]
48
 
49
- def parse_query(sentences, labels, threshold=0.3, nested_ner=False, model_name=None, multilingual=False):
50
- model = get_model(model_name, multilingual)
 
51
  results = []
 
52
  for sentence in sentences:
53
  _entities = model.predict_entities(sentence, labels, threshold=threshold)
54
- results.extend([{"text": entity["text"], "label": entity["label"]} for entity in _entities])
55
- return results
56
 
57
- def refine_entities_with_t5(entities):
58
- inputs = "refine entities: " + " ; ".join([f"{entity['text']} as {entity['label']}" for entity in entities])
59
- input_ids = t5_tokenizer.encode(inputs, return_tensors="pt", add_special_tokens=True)
60
- outputs = t5_model.generate(input_ids)
61
- result = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
62
- return result
63
 
64
  def present(email_file, labels, multilingual=False):
65
  email = accept_mail(email_file)
66
  cleaned_text = clean_email(email)
67
  further_cleaned_text = remove_special_characters(cleaned_text)
68
  sentence_list = get_sentences(further_cleaned_text)
 
69
  entities = parse_query(sentence_list, labels, threshold=0.3, nested_ner=False, model_name="urchade/gliner_base", multilingual=multilingual)
70
- refined_entities = refine_entities_with_t5(entities)
 
 
 
71
  email_info = {
72
  "Subject": email.subject,
73
  "From": email.from_,
74
  "To": email.to,
75
  "Date": email.date,
76
- "Extracted Entities": entities, # Prepare entities for DataFrame if needed
77
- "Refined Entities": refined_entities
78
  }
79
- return [email_info[key] for key in ["Subject", "From", "To", "Date", "Extracted Entities", "Refined Entities"]]
80
 
81
  labels = ["PERSON", "PRODUCT", "DEAL", "ORDER", "ORDER PAYMENT METHOD", "STORE", "LEGAL ENTITY", "MERCHANT", "FINANCIAL TRANSACTION", "UNCATEGORIZED", "DATE"]
82
 
83
  demo = gr.Interface(
84
- fn=present,
85
  inputs=[
86
  gr.components.File(label="Upload Email (.eml file)"),
87
- gr.components.CheckboxGroup(choices=labels, label="Labels to Detect", value=labels),
 
 
 
 
88
  gr.components.Checkbox(label="Use Multilingual Model")
89
  ],
90
  outputs=[
@@ -92,8 +96,7 @@ demo = gr.Interface(
92
  gr.components.Textbox(label="From"),
93
  gr.components.Textbox(label="To"),
94
  gr.components.Textbox(label="Date"),
95
- gr.components.Dataframe(headers=["Text", "Label"], label="Extracted Entities"),
96
- gr.components.Textbox(label="Refined Entities")
97
  ],
98
  title="Email Info Extractor",
99
  description="Upload an email file (.eml) to extract its details and detected entities."
 
1
  import gradio as gr
2
 
 
3
  from mailparser import parse_from_file
4
  from bs4 import BeautifulSoup
5
  from gliner import GLiNER
 
11
  import en_core_web_sm
12
  nlp = en_core_web_sm.load()
13
 
 
 
 
14
  _MODEL = {}
15
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
16
 
 
38
  def get_model(model_name: str = None, multilingual: bool = False):
39
  if model_name is None:
40
  model_name = "urchade/gliner_base" if not multilingual else "urchade/gliner_multilingual"
41
+
42
+ global _MODEL
43
+
44
+ if _MODEL.get(model_name) is None:
45
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
46
+
47
  return _MODEL[model_name]
48
 
49
+ def parse_query(sentences: List[str], labels: List[str], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None, multilingual: bool = False) -> List[Dict[str, Union[str, list]]]:
50
+ model = get_model(model_name, multilingual=multilingual)
51
+
52
  results = []
53
+
54
  for sentence in sentences:
55
  _entities = model.predict_entities(sentence, labels, threshold=threshold)
56
+ entities = [{"text": entity["text"], "label": entity["label"]} for entity in _entities]
57
+ results.extend(entities)
58
 
59
+ return results
 
 
 
 
 
60
 
61
  def present(email_file, labels, multilingual=False):
62
  email = accept_mail(email_file)
63
  cleaned_text = clean_email(email)
64
  further_cleaned_text = remove_special_characters(cleaned_text)
65
  sentence_list = get_sentences(further_cleaned_text)
66
+
67
  entities = parse_query(sentence_list, labels, threshold=0.3, nested_ner=False, model_name="urchade/gliner_base", multilingual=multilingual)
68
+
69
+ # Format entities for DataFrame: Convert list of dicts to list of lists
70
+ entities_data = [[entity['text'], entity['label']] for entity in entities]
71
+
72
  email_info = {
73
  "Subject": email.subject,
74
  "From": email.from_,
75
  "To": email.to,
76
  "Date": email.date,
77
+ "Extracted Entities": entities_data # Adjusted for DataFrame
 
78
  }
79
+ return [email_info[key] for key in ["Subject", "From", "To", "Date"]] + [entities_data]
80
 
81
  labels = ["PERSON", "PRODUCT", "DEAL", "ORDER", "ORDER PAYMENT METHOD", "STORE", "LEGAL ENTITY", "MERCHANT", "FINANCIAL TRANSACTION", "UNCATEGORIZED", "DATE"]
82
 
83
  demo = gr.Interface(
84
+ fn=present,
85
  inputs=[
86
  gr.components.File(label="Upload Email (.eml file)"),
87
+ gr.components.CheckboxGroup(
88
+ choices=labels,
89
+ label="Labels to Detect",
90
+ value=labels, # Default all selected
91
+ ),
92
  gr.components.Checkbox(label="Use Multilingual Model")
93
  ],
94
  outputs=[
 
96
  gr.components.Textbox(label="From"),
97
  gr.components.Textbox(label="To"),
98
  gr.components.Textbox(label="Date"),
99
+ gr.components.Dataframe(headers=["Text", "Label"], label="Extracted Entities")
 
100
  ],
101
  title="Email Info Extractor",
102
  description="Upload an email file (.eml) to extract its details and detected entities."