Update modeling_provence.py
Browse files- modeling_provence.py +4 -11
modeling_provence.py
CHANGED
@@ -100,8 +100,8 @@ class Provence(DebertaV2PreTrainedModel):
|
|
100 |
def process(
|
101 |
self,
|
102 |
question: Union[List[str], str],
|
103 |
-
context: Union[List[List[str]],
|
104 |
-
title: Optional[Union[List[List[str]],
|
105 |
batch_size=32,
|
106 |
threshold=0.1,
|
107 |
always_select_title=False,
|
@@ -117,14 +117,10 @@ class Provence(DebertaV2PreTrainedModel):
|
|
117 |
queries = question
|
118 |
if type(context) == str:
|
119 |
contexts = [[context]]
|
120 |
-
elif type(context) == list and type(context[0]) == str:
|
121 |
-
contexts = [context]
|
122 |
else:
|
123 |
contexts = context
|
124 |
if type(title) == str:
|
125 |
titles = [[title]]
|
126 |
-
elif type(title) == list and type(title[0]) == str:
|
127 |
-
titles = [title]
|
128 |
else:
|
129 |
titles = title
|
130 |
assert (
|
@@ -235,13 +231,10 @@ class Provence(DebertaV2PreTrainedModel):
|
|
235 |
if type(context) == str:
|
236 |
selected_contexts = selected_contexts[0][0]
|
237 |
reranking_scores = reranking_scores[0][0]
|
238 |
-
elif type(context) == list and type(context[0]) == str:
|
239 |
-
selected_contexts = selected_contexts[0]
|
240 |
-
reranking_scores = reranking_scores[0]
|
241 |
|
242 |
return {
|
243 |
-
"
|
244 |
-
"
|
245 |
}
|
246 |
|
247 |
|
|
|
100 |
def process(
|
101 |
self,
|
102 |
question: Union[List[str], str],
|
103 |
+
context: Union[List[List[str]], str],
|
104 |
+
title: Optional[Union[List[List[str]], str]] = "first_sentence",
|
105 |
batch_size=32,
|
106 |
threshold=0.1,
|
107 |
always_select_title=False,
|
|
|
117 |
queries = question
|
118 |
if type(context) == str:
|
119 |
contexts = [[context]]
|
|
|
|
|
120 |
else:
|
121 |
contexts = context
|
122 |
if type(title) == str:
|
123 |
titles = [[title]]
|
|
|
|
|
124 |
else:
|
125 |
titles = title
|
126 |
assert (
|
|
|
231 |
if type(context) == str:
|
232 |
selected_contexts = selected_contexts[0][0]
|
233 |
reranking_scores = reranking_scores[0][0]
|
|
|
|
|
|
|
234 |
|
235 |
return {
|
236 |
+
"pruned_context": selected_contexts,
|
237 |
+
"reranking_score": reranking_scores
|
238 |
}
|
239 |
|
240 |
|