nadiinchi commited on
Commit
0aef913
·
verified ·
1 Parent(s): 30d70a0

Update modeling_provence.py

Browse files
Files changed (1) hide show
  1. modeling_provence.py +4 -11
modeling_provence.py CHANGED
@@ -100,8 +100,8 @@ class Provence(DebertaV2PreTrainedModel):
100
  def process(
101
  self,
102
  question: Union[List[str], str],
103
- context: Union[List[List[str]], List[str], str],
104
- title: Optional[Union[List[List[str]], List[str], str]] = "first_sentence",
105
  batch_size=32,
106
  threshold=0.1,
107
  always_select_title=False,
@@ -117,14 +117,10 @@ class Provence(DebertaV2PreTrainedModel):
117
  queries = question
118
  if type(context) == str:
119
  contexts = [[context]]
120
- elif type(context) == list and type(context[0]) == str:
121
- contexts = [context]
122
  else:
123
  contexts = context
124
  if type(title) == str:
125
  titles = [[title]]
126
- elif type(title) == list and type(title[0]) == str:
127
- titles = [title]
128
  else:
129
  titles = title
130
  assert (
@@ -235,13 +231,10 @@ class Provence(DebertaV2PreTrainedModel):
235
  if type(context) == str:
236
  selected_contexts = selected_contexts[0][0]
237
  reranking_scores = reranking_scores[0][0]
238
- elif type(context) == list and type(context[0]) == str:
239
- selected_contexts = selected_contexts[0]
240
- reranking_scores = reranking_scores[0]
241
 
242
  return {
243
- "pruned_contexts": selected_contexts,
244
- "reranking_scores": reranking_scores
245
  }
246
 
247
 
 
100
  def process(
101
  self,
102
  question: Union[List[str], str],
103
+ context: Union[List[List[str]], str],
104
+ title: Optional[Union[List[List[str]], str]] = "first_sentence",
105
  batch_size=32,
106
  threshold=0.1,
107
  always_select_title=False,
 
117
  queries = question
118
  if type(context) == str:
119
  contexts = [[context]]
 
 
120
  else:
121
  contexts = context
122
  if type(title) == str:
123
  titles = [[title]]
 
 
124
  else:
125
  titles = title
126
  assert (
 
231
  if type(context) == str:
232
  selected_contexts = selected_contexts[0][0]
233
  reranking_scores = reranking_scores[0][0]
 
 
 
234
 
235
  return {
236
+ "pruned_context": selected_contexts,
237
+ "reranking_score": reranking_scores
238
  }
239
 
240