edithram23 committed
Commit a914a34 · verified · 1 Parent(s): 2e30d38

Update main.py

Files changed (1)
  1. main.py +13 -5
main.py CHANGED
@@ -4,11 +4,16 @@ os.environ["HF_HOME"] = "/.cache"
 import re
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-model_dir = 'edithram23/Redaction_Personal_info_v1'
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+model_dir_small = 'edithram23/Redaction'
+tokenizer_small = AutoTokenizer.from_pretrained(model_dir)
+model_small = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
 
-def mask_generation(text):
+
+model_dir_large = 'edithram23/Redaction_Personal_info_v1'
+tokenizer_large = AutoTokenizer.from_pretrained(model_dir)
+model_large = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+
+def mask_generation(text,model=model_small,tokenizer=tokenizer_small):
     import re
     inputs = ["Mask Generation: " + text]
     inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
@@ -31,7 +36,10 @@ async def hello():
 
 @app.post("/mask")
 async def mask_input(query):
-    output = mask_generation(query)
+    if(len(query)<90):
+        output = mask_generation(query)
+    else:
+        output = mask_generation(query,model_large,tokenizer_large)
     return {"data" : output}
 
 if __name__ == '__main__':
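
For context, below is a sketch of how main.py reads after this commit, with the parts the diff leaves implicit filled in. It is an illustration, not the repository's exact file: the FastAPI setup, the generate/decode step in mask_generation (the hunk stops at the tokenizer call), and the uvicorn port are assumptions, and the sketch passes model_dir_small / model_dir_large to from_pretrained, whereas the committed lines still reference the removed model_dir variable.

# Sketch of main.py after commit a914a34 (assumptions marked in comments).
import os
os.environ["HF_HOME"] = "/.cache"

import re  # kept from the original; used by post-processing not shown in the diff
from fastapi import FastAPI  # assumed: the diff shows @app.post and async handlers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Small model for short inputs.
# NOTE: the committed diff passes the removed `model_dir` variable to from_pretrained;
# `model_dir_small` is the presumed intent.
model_dir_small = 'edithram23/Redaction'
tokenizer_small = AutoTokenizer.from_pretrained(model_dir_small)
model_small = AutoModelForSeq2SeqLM.from_pretrained(model_dir_small)

# Larger model for longer inputs (same note: the diff still passes `model_dir`).
model_dir_large = 'edithram23/Redaction_Personal_info_v1'
tokenizer_large = AutoTokenizer.from_pretrained(model_dir_large)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_dir_large)

app = FastAPI()

def mask_generation(text, model=model_small, tokenizer=tokenizer_small):
    # Prefix the task, tokenize, generate, and decode.
    # The generate/decode step is not visible in the hunk; parameters here are illustrative.
    inputs = ["Mask Generation: " + text]
    inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
    output = model.generate(**inputs, max_length=512)
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

@app.post("/mask")
async def mask_input(query: str):
    # Route by input length: under 90 characters uses the small model, otherwise the large one.
    if len(query) < 90:
        output = mask_generation(query)
    else:
        output = mask_generation(query, model_large, tokenizer_large)
    return {"data": output}

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=7860)  # port assumed (Hugging Face Spaces default)

The routing rule itself is unchanged from the diff: queries shorter than 90 characters go to the lighter edithram23/Redaction model, longer ones to edithram23/Redaction_Personal_info_v1.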