davebulaval committed
Commit 7217d6a · 1 Parent(s): 1982c24

uniformization of interface and add .to for tokenizer output

Files changed (1): meaningbert.py (+14 -16)
meaningbert.py CHANGED
@@ -64,8 +64,8 @@ _KWARGS_DESCRIPTION = """
 MeaningBERT metric for assessing meaning preservation between sentences.

 Args:
-    documents (list of str): Document sentences.
-    simplifications (list of str): Simplification sentences (same number of element as documents).
+    references (list of str): References sentences.
+    predictions (list of str): Predictions sentences (same number of element as documents).
     device (str): Device to use for model inference. By default, set to "cuda".

 Returns:
@@ -75,10 +75,10 @@ Returns:

 Examples:

-    >>> documents = ["hello there", "general kenobi"]
-    >>> simplifications = ["hello there", "general kenobi"]
+    >>> references = ["hello there", "general kenobi"]
+    >>> predictions = ["hello there", "general kenobi"]
     >>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
-    >>> results = meaning_bert.compute(documents=documents, simplifications=simplifications)
+    >>> results = meaning_bert.compute(references=references, predictions=predictions)
 """

 _HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"
@@ -110,19 +110,17 @@ class MeaningBERT(evaluate.Metric):

     def _compute(
         self,
-        documents: List,
-        simplifications: List,
+        references: List,
+        predictions: List,
         device: str = "cuda",
     ) -> Dict:
-        assert len(documents) == len(
-            simplifications
-        ), "The number of document is different of the number of simplifications."
+        assert len(references) == len(
+            predictions
+        ), "The number of references is different of the number of predictions."
         hashcode = _HASH

         # Index of sentence with perfect match between two sentences
-        matching_index = [
-            i for i, item in enumerate(documents) if item in simplifications
-        ]
+        matching_index = [i for i, item in enumerate(references) if item in predictions]

         # We load the MeaningBERT pretrained model
         scorer = AutoModelForSequenceClassification.from_pretrained(
@@ -135,12 +133,12 @@ class MeaningBERT(evaluate.Metric):

         # We tokenize the text as a pair and return Pytorch Tensors
         tokenize_text = tokenizer(
-            documents,
-            simplifications,
+            references,
+            predictions,
             truncation=True,
             padding=True,
             return_tensors="pt",
-        )
+        ).to(device)

         with filter_logging_context():
             # We process the text
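
Below is a minimal usage sketch of the metric after this change, assembled from the updated docstring in the diff above; the contents of the returned results dict are not shown in this commit, so the final print is only illustrative.

import evaluate

# Load MeaningBERT from the Hub; the device argument mirrors the docstring example.
meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")

# The uniformized interface now takes references/predictions
# (previously documents/simplifications).
references = ["hello there", "general kenobi"]
predictions = ["hello there", "general kenobi"]

results = meaning_bert.compute(references=references, predictions=predictions)
print(results)  # The keys of the returned dict are not visible in this diff.

The added .to(device) in _compute moves the tokenizer's BatchEncoding onto the same device as the model, avoiding a CPU/GPU tensor mismatch when device points to a CUDA device.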