edithram23 committed
Commit f2216e3 · verified · 1 Parent(s): 508f4e9
Files changed (1)
  1. main.py +38 -38
main.py CHANGED
@@ -1,39 +1,39 @@
-import os
-os.environ["TRANSFORMERS_CACHE"] = "/.cache"
-
-import re
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
-model_dir = 'edithram23/Redaction'
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
-
-def mask_generation(text):
-    import re
-    inputs = ["Mask Generation: " + text]
-    inputs = tokenizer(inputs, max_length=500, truncation=True, return_tensors="pt")
-    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text)+10)
-    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
-    predicted_title = decoded_output.strip()
-    pattern = r'\[.*?\]'
-    # Replace all occurrences of the pattern with [redacted]
-    redacted_text = re.sub(pattern, '[redacted]', predicted_title)
-    return redacted_text
-
-from fastapi import FastAPI
-import uvicorn
-
-app = FastAPI()
-
-@app.get("/")
-async def hello():
-    return {"msg" : "Live"}
-
-@app.post("/mask")
-async def mask_input(query):
-    output = mask_generation(query)
-    return {"data" : output}
-
-if __name__ == '__main__':
-    os.environ["TRANSFORMERS_CACHE"] = "/.cache"
+import os
+os.environ["HF_HOME"] = "/.cache"
+
+import re
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+model_dir = 'edithram23/Redaction'
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+
+def mask_generation(text):
+    import re
+    inputs = ["Mask Generation: " + text]
+    inputs = tokenizer(inputs, max_length=500, truncation=True, return_tensors="pt")
+    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text)+10)
+    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+    predicted_title = decoded_output.strip()
+    pattern = r'\[.*?\]'
+    # Replace all occurrences of the pattern with [redacted]
+    redacted_text = re.sub(pattern, '[redacted]', predicted_title)
+    return redacted_text
+
+from fastapi import FastAPI
+import uvicorn
+
+app = FastAPI()
+
+@app.get("/")
+async def hello():
+    return {"msg" : "Live"}
+
+@app.post("/mask")
+async def mask_input(query):
+    output = mask_generation(query)
+    return {"data" : output}
+
+if __name__ == '__main__':
+    os.environ["HF_HOME"] = "/.cache"
     uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True, workers=1)
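
Below is a quick way to exercise the service once this version of main.py is running. It is a client-side sketch, not part of the commit: it assumes a local instance on port 7860 (the port configured above), the third-party `requests` package, and a made-up sample sentence. Because `mask_input` declares `query` without a Pydantic model or `Body(...)`, FastAPI treats it as a required query parameter on the POST request.

# Hypothetical smoke-test client, not part of the commit.
import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment; adjust as needed

# Liveness check against the root endpoint.
print(requests.get(f"{BASE_URL}/").json())  # expected: {"msg": "Live"}

# Redaction request; `query` is sent as a query parameter, not a JSON body.
sample = "John Doe lives at 42 Baker Street."  # made-up example input
resp = requests.post(f"{BASE_URL}/mask", params={"query": sample})
print(resp.json())  # expected shape: {"data": "<text with sensitive spans replaced by [redacted]>"}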