mtc committed on
Commit
d1056da
·
verified ·
1 Parent(s): f19a32c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -17
README.md CHANGED
@@ -16,50 +16,44 @@ This model is a finetuned version of the [google-bert/bert-base-multilingual-cas
16
 
17
  Below is a minimal code snippet to run classification for a given article and summary sentences. The model outputs either Faithful, Intrinsic Hallucination or Extrinsic Hallucination:
18
 
19
- ```
20
  from typing import List, Dict
21
  import torch
22
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
23
 
24
- def preprocess_prompts_for_classification(article: str, summary_sentences: List[str]) -> List[Dict]:
25
  prompts = []
26
  for sentence in summary_sentences:
27
  prompt = {"text": article, "text_pair": sentence}
28
  prompts.append(prompt)
29
  return prompts
30
 
31
- def predict_with_hf_classification(prompts: List[Dict], model_name: str, max_context_length: int = 512,
32
- batch_size: int = 2):
33
- # Create the text generation pipeline
34
- if torch.cuda.is_available():
35
- device = "cuda"
36
- else:
37
- device = "cpu"
38
- tokenizer = AutoTokenizer.from_pretrained(model_name)
39
- model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device=device)
40
- text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer,
41
  batch_size=batch_size)
42
 
43
- # Generate text for batched inputs
44
- batch_output = text_classification_pipeline(prompts, truncation=True, max_length=max_context_length, )
45
  predictions = [result['label'] for result in batch_output]
46
  return predictions
47
 
48
  def main():
49
 
50
  model_name = "mtc/mbert-absinth-3-epochs"
 
51
  max_context_length = 512
52
  # Adjust batch_size according to your local gpu memory
53
  batch_size = 2
54
 
55
- article = "Ein neuer Zirkus ist gestern in Zürich angekommen. Viele Familien besuchten das große Zelt, um die Vorstellung zu sehen. Es gab Akrobaten, Clowns und Tiere, die das Publikum begeisterten. Der Zirkus bleibt noch eine Woche in der Stadt und bietet täglich Vorstellungen an."
56
 
57
  summary_sentences = [
58
  "Ein Zirkus ist in Basel angekommen.",
59
  "Der Zirkus, der in 1950 gegründet wurde, wird von vielen Familien besucht."]
60
 
61
- prompts = preprocess_prompts_for_classification(article=article, summary_sentences=summary_sentences)
62
- predictions = predict_with_hf_classification(prompts=prompts_for_bert, model_name=model_name,
63
  max_context_length=max_context_length, batch_size=batch_size)
64
  print(predictions)
65
 
 
16
 
17
  Below is a minimal code snippet to run classification for a given article and summary sentences. The model outputs either Faithful, Intrinsic Hallucination or Extrinsic Hallucination:
18
 
19
+ ```python
20
  from typing import List, Dict
21
  import torch
22
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
23
 
24
+ def generate_prompts_for_classification(article: str, summary_sentences: List[str]) -> List[Dict]:
25
  prompts = []
26
  for sentence in summary_sentences:
27
  prompt = {"text": article, "text_pair": sentence}
28
  prompts.append(prompt)
29
  return prompts
30
 
31
+ def predict_with_hf_classification_pipeline(prompts: List[Dict], model_name: str, max_context_length: int = 512,
32
+ batch_size: int = 2):
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ text_classification_pipeline = pipeline("text-classification", model=model_name, device=device,
 
 
 
 
 
 
35
  batch_size=batch_size)
36
 
37
+ batch_output = text_classification_pipeline(prompts, truncation=True, max_length=max_context_length)
 
38
  predictions = [result['label'] for result in batch_output]
39
  return predictions
40
 
41
  def main():
42
 
43
  model_name = "mtc/mbert-absinth-3-epochs"
44
+ # Articles longer than 512 tokens will be truncated
45
  max_context_length = 512
46
  # Adjust batch_size according to your local gpu memory
47
  batch_size = 2
48
 
49
+ article = "Ein neuer Zirkus ist gestern in Zürich angekommen. Viele Familien besuchten das grosse Zelt, um die Vorstellung zu sehen. Es gab Akrobaten, Clowns und Tiere, die das Publikum begeisterten. Der Zirkus bleibt noch eine Woche in der Stadt und bietet täglich Vorstellungen an."
50
 
51
  summary_sentences = [
52
  "Ein Zirkus ist in Basel angekommen.",
53
  "Der Zirkus, der in 1950 gegründet wurde, wird von vielen Familien besucht."]
54
 
55
+ prompts = generate_prompts_for_classification(article=article, summary_sentences=summary_sentences)
56
+ predictions = predict_with_hf_classification_pipeline(prompts=prompts, model_name=model_name,
57
  max_context_length=max_context_length, batch_size=batch_size)
58
  print(predictions)
59