davebulaval commited on
Commit
364be12
Β·
1 Parent(s): 7217d6a

fix inversion of reference and prediction

Browse files
Files changed (1) hide show
  1. meaningbert.py +3 -3
meaningbert.py CHANGED
@@ -64,8 +64,8 @@ _KWARGS_DESCRIPTION = """
64
  MeaningBERT metric for assessing meaning preservation between sentences.
65
 
66
  Args:
67
- references (list of str): References sentences.
68
  predictions (list of str): Predictions sentences (same number of element as documents).
 
69
  device (str): Device to use for model inference. By default, set to "cuda".
70
 
71
  Returns:
@@ -78,7 +78,7 @@ Examples:
78
  >>> references = ["hello there", "general kenobi"]
79
  >>> predictions = ["hello there", "general kenobi"]
80
  >>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
81
- >>> results = meaning_bert.compute(references=references, predictions=predictions)
82
  """
83
 
84
  _HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"
@@ -110,8 +110,8 @@ class MeaningBERT(evaluate.Metric):
110
 
111
  def _compute(
112
  self,
113
- references: List,
114
  predictions: List,
 
115
  device: str = "cuda",
116
  ) -> Dict:
117
  assert len(references) == len(
 
64
  MeaningBERT metric for assessing meaning preservation between sentences.
65
 
66
  Args:
 
67
  predictions (list of str): Predictions sentences (same number of element as documents).
68
+ references (list of str): References sentences.
69
  device (str): Device to use for model inference. By default, set to "cuda".
70
 
71
  Returns:
 
78
  >>> references = ["hello there", "general kenobi"]
79
  >>> predictions = ["hello there", "general kenobi"]
80
  >>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
81
+ >>> results = meaning_bert.compute(predictions=predictions, references=references)
82
  """
83
 
84
  _HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"
 
110
 
111
  def _compute(
112
  self,
 
113
  predictions: List,
114
+ references: List,
115
  device: str = "cuda",
116
  ) -> Dict:
117
  assert len(references) == len(