singhvaibhav924 commited on
Commit
404dd70
·
1 Parent(s): 89effe3

minor bug fix

Browse files
Files changed (1) hide show
  1. helper.py +4 -3
helper.py CHANGED
@@ -13,8 +13,8 @@ ranker_model_name = "sentence-transformers/all-MiniLM-L6-v2"
13
  class KeyphraseExtractionPipeline(TokenClassificationPipeline):
14
  def __init__(self, model, *args, **kwargs):
15
  super().__init__(
16
- model=AutoModelForTokenClassification.from_pretrained(model),
17
- tokenizer=AutoTokenizer.from_pretrained(model),
18
  *args,
19
  **kwargs
20
  )
@@ -32,8 +32,9 @@ def init_pipeline() :
32
  device_map="cuda",
33
  torch_dtype=torch.float16,
34
  trust_remote_code=True,
 
35
  )
36
- summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name)
37
 
38
  feature_extractor_model = KeyphraseExtractionPipeline(model=feature_extractor_model_name)
39
 
 
13
  class KeyphraseExtractionPipeline(TokenClassificationPipeline):
14
  def __init__(self, model, *args, **kwargs):
15
  super().__init__(
16
+ model=AutoModelForTokenClassification.from_pretrained(model, cache_dir='/temp/cache/'),
17
+ tokenizer=AutoTokenizer.from_pretrained(model, cache_dir='/temp/cache/'),
18
  *args,
19
  **kwargs
20
  )
 
32
  device_map="cuda",
33
  torch_dtype=torch.float16,
34
  trust_remote_code=True,
35
+ cache_dir='/temp/cache/'
36
  )
37
+ summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name, cache_dir='/temp/cache/')
38
 
39
  feature_extractor_model = KeyphraseExtractionPipeline(model=feature_extractor_model_name)
40