singhvaibhav924 commited on
Commit
4f964a6
·
1 Parent(s): 12a130d

minor bug fix

Browse files
Files changed (2) hide show
  1. app.py +0 -2
  2. helper.py +4 -5
app.py CHANGED
@@ -7,8 +7,6 @@ import os
7
  from dotenv import load_dotenv
8
 
9
  load_dotenv()
10
- os.environ['TRANSFORMERS_CACHE'] = '/temp/cache/'
11
- os.environ['HF_HOME'] = '/temp/cache/'
12
 
13
  app = FastAPI()
14
  app.add_middleware(
 
7
  from dotenv import load_dotenv
8
 
9
  load_dotenv()
 
 
10
 
11
  app = FastAPI()
12
  app.add_middleware(
helper.py CHANGED
@@ -13,8 +13,8 @@ ranker_model_name = "sentence-transformers/all-MiniLM-L6-v2"
13
  class KeyphraseExtractionPipeline(TokenClassificationPipeline):
14
  def __init__(self, model, *args, **kwargs):
15
  super().__init__(
16
- model=AutoModelForTokenClassification.from_pretrained(model, cache_dir='/temp/cache/'),
17
- tokenizer=AutoTokenizer.from_pretrained(model, cache_dir='/temp/cache/'),
18
  *args,
19
  **kwargs
20
  )
@@ -31,10 +31,9 @@ def init_pipeline() :
31
  summarizer_model_name,
32
  device_map="cuda",
33
  torch_dtype=torch.float16,
34
- trust_remote_code=True,
35
- cache_dir='/temp/cache/'
36
  )
37
- summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name, cache_dir='/temp/cache/')
38
 
39
  feature_extractor_model = KeyphraseExtractionPipeline(model=feature_extractor_model_name)
40
 
 
13
  class KeyphraseExtractionPipeline(TokenClassificationPipeline):
14
  def __init__(self, model, *args, **kwargs):
15
  super().__init__(
16
+ model=AutoModelForTokenClassification.from_pretrained(model),
17
+ tokenizer=AutoTokenizer.from_pretrained(model),
18
  *args,
19
  **kwargs
20
  )
 
31
  summarizer_model_name,
32
  device_map="cuda",
33
  torch_dtype=torch.float16,
34
+ trust_remote_code=True
 
35
  )
36
+ summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name)
37
 
38
  feature_extractor_model = KeyphraseExtractionPipeline(model=feature_extractor_model_name)
39