Commit 404dd70
1 Parent(s): 89effe3
minor bug fix

helper.py CHANGED
@@ -13,8 +13,8 @@ ranker_model_name = "sentence-transformers/all-MiniLM-L6-v2"
 class KeyphraseExtractionPipeline(TokenClassificationPipeline):
     def __init__(self, model, *args, **kwargs):
         super().__init__(
-            model=AutoModelForTokenClassification.from_pretrained(model),
-            tokenizer=AutoTokenizer.from_pretrained(model),
+            model=AutoModelForTokenClassification.from_pretrained(model, cache_dir='/temp/cache/'),
+            tokenizer=AutoTokenizer.from_pretrained(model, cache_dir='/temp/cache/'),
             *args,
             **kwargs
         )
@@ -32,8 +32,9 @@ def init_pipeline() :
         device_map="cuda",
         torch_dtype=torch.float16,
         trust_remote_code=True,
+        cache_dir='/temp/cache/'
     )
-    summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name)
+    summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name, cache_dir='/temp/cache/')
 
     feature_extractor_model = KeyphraseExtractionPipeline(model=feature_extractor_model_name)
 
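
For context, the change adds a cache_dir argument to each from_pretrained call so downloaded checkpoints are stored under /temp/cache/ instead of the default Hugging Face cache location, which is presumably the directory this Space can write to. A minimal sketch of the pattern, using a publicly available token-classification checkpoint as a stand-in (the Space's actual model names are defined in helper.py, not shown here):

from transformers import AutoModelForTokenClassification, AutoTokenizer

CACHE_DIR = "/temp/cache/"
# Illustrative stand-in checkpoint; not the checkpoint this Space actually loads.
model_name = "dslim/bert-base-NER"

# Files are downloaded once and reused from CACHE_DIR on later loads,
# instead of the default ~/.cache/huggingface directory.
model = AutoModelForTokenClassification.from_pretrained(model_name, cache_dir=CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)

An equivalent approach is to leave the from_pretrained calls untouched and redirect the whole cache by setting the HF_HOME (or the older TRANSFORMERS_CACHE) environment variable before transformers is imported.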