Update README.md
README.md
This model is a finetuned version of the [google-bert/bert-base-multilingual-cased](https://huggingface.co/google-bert/bert-base-multilingual-cased) model.
Below is a minimal code snippet that runs classification for a given article and its summary sentences. The model outputs one of Faithful, Intrinsic Hallucination, or Extrinsic Hallucination:
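
The snippet assumes `torch` and `transformers` are installed (for example via `pip install torch transformers`); the model weights are downloaded from the Hugging Face Hub on first use.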

```python
from typing import List, Dict

import torch
from transformers import pipeline


def generate_prompts_for_classification(article: str, summary_sentences: List[str]) -> List[Dict]:
    # Pair the full article with each summary sentence; the classifier scores each pair.
    prompts = []
    for sentence in summary_sentences:
        prompt = {"text": article, "text_pair": sentence}
        prompts.append(prompt)
    return prompts


def predict_with_hf_classification_pipeline(prompts: List[Dict], model_name: str, max_context_length: int = 512,
                                            batch_size: int = 2) -> List[str]:
    # Use the GPU if one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_classification_pipeline = pipeline("text-classification", model=model_name, device=device,
                                            batch_size=batch_size)
    batch_output = text_classification_pipeline(prompts, truncation=True, max_length=max_context_length)
    predictions = [result['label'] for result in batch_output]
    return predictions


def main():
    model_name = "mtc/mbert-absinth-3-epochs"
    # Articles longer than 512 tokens will be truncated
    max_context_length = 512
    # Adjust batch_size according to your local gpu memory
    batch_size = 2

    article = "Ein neuer Zirkus ist gestern in Zürich angekommen. Viele Familien besuchten das grosse Zelt, um die Vorstellung zu sehen. Es gab Akrobaten, Clowns und Tiere, die das Publikum begeisterten. Der Zirkus bleibt noch eine Woche in der Stadt und bietet täglich Vorstellungen an."

    summary_sentences = [
        "Ein Zirkus ist in Basel angekommen.",
        "Der Zirkus, der in 1950 gegründet wurde, wird von vielen Familien besucht."]

    prompts = generate_prompts_for_classification(article=article, summary_sentences=summary_sentences)
    predictions = predict_with_hf_classification_pipeline(prompts=prompts, model_name=model_name,
                                                          max_context_length=max_context_length, batch_size=batch_size)
    print(predictions)


if __name__ == "__main__":
    main()
```
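
For the example article above, the first summary sentence contradicts the article (Basel instead of Zürich) and the second introduces a founding year the article never mentions. Running the script prints one label per summary sentence; the exact labels depend on the model's predictions, but the output has this shape (illustrative values only):

```
['Intrinsic Hallucination', 'Extrinsic Hallucination']
```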