Update README.md
Browse files
README.md
CHANGED
@@ -38,7 +38,7 @@ def generate_prompts_for_classification(article: str, summary_sentences: List[st
|
|
38 |
return prompts
|
39 |
|
40 |
def predict_with_hf_classification_pipeline(prompts: List[Dict], model_name: str, max_context_length: int = 512,
|
41 |
-
batch_size: int = 2):
|
42 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
text_classification_pipeline = pipeline("text-classification", model=model_name, device=device,
|
44 |
batch_size=batch_size)
|
|
|
38 |
return prompts
|
39 |
|
40 |
def predict_with_hf_classification_pipeline(prompts: List[Dict], model_name: str, max_context_length: int = 512,
|
41 |
+
batch_size: int = 2) -> List[str]:
|
42 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
text_classification_pipeline = pipeline("text-classification", model=model_name, device=device,
|
44 |
batch_size=batch_size)
|