|
import numpy as np |
|
|
|
def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
    """Run *model* on one encoded sequence and return per-token class ids.

    Args:
        model: Callable accepting ``input_ids``, ``bbox`` and
            ``attention_mask`` keyword arguments and returning an object with
            a ``logits`` attribute (presumably a token-classification model
            such as LayoutLM -- confirm against callers).
        input_ids: Token ids, passed straight through to *model*.
        attention_mask: Attention mask, passed straight through to *model*.
        bbox: Bounding boxes, passed straight through to *model*.
        offset_mapping: Unused; kept so existing callers keep working.

    Returns:
        The argmax over the last axis of ``outputs.logits``, with all size-1
        axes squeezed away, converted to a Python list. For a batch of one
        sequence this is a flat list with one class id per token; it is a
        list even when the sequence has a single token.
    """
    # NOTE(review): if *model* is a PyTorch module, callers probably want to
    # invoke this under torch.no_grad() so no autograd graph is built -- confirm.
    outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()

    # squeeze() reduces a (1, 1) result to a 0-d value whose tolist() is a bare
    # int; wrap it so callers can always iterate / zip the predictions.
    if not isinstance(predictions, list):
        predictions = [predictions]

    return predictions
|
|
|
def createEntities(model, predictions, input_ids, offset_mapping):
    """Group per-token BIO predictions into entity spans.

    Only tokens that start a word are considered. A token whose offset mapping
    begins at a non-zero character offset is treated as a subword continuation
    of the previous word and skipped. Word-start labels are decoded as BIO:

    * ``B-X`` opens a new entity of type ``X``, closing any open entity
      (including one of the same type);
    * ``I-X`` extends an open entity of type ``X``;
    * any other label (``O``, or ``I-Y`` for a different type) closes the
      open entity without opening a new one.

    Args:
        model: Object whose ``config.id2label`` maps class ids to label
            strings such as ``"B-HEADER"``, ``"I-HEADER"`` or ``"O"``.
        predictions: Sequence of class ids, one per token (e.g. the output of
            ``classifyTokens``).
        input_ids: Token ids with a leading batch axis; only ``input_ids[0]``
            is used, so a batch of one sequence is assumed.
        offset_mapping: Per-token ``(start, end)`` character offsets for that
            single sequence (leading batch axis of size one); anything
            ``np.asarray`` accepts.

    Returns:
        A list of ``(start, end, entity_type, entity_id)`` tuples. ``start``
        and ``end`` are token indices with ``end`` exclusive, and
        ``entity_id`` is 0/1/2 for HEADER/QUESTION/ANSWER.

    Raises:
        KeyError: if a prediction is missing from ``model.config.id2label`` or
            an entity type is not HEADER, QUESTION or ANSWER.
    """
    # Entity type -> numeric id (the inverse direction of model id2label).
    label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    model_id2label = model.config.id2label

    # A token starting at a non-zero character offset continues the previous
    # word. reshape(-1, 2) keeps a (tokens, 2) array even for a single token,
    # where squeeze() would also drop the token axis.
    is_subword = np.asarray(offset_mapping).reshape(-1, 2)[:, 0] != 0

    # Same iteration length as zipping token ids with predictions.
    n_tokens = min(len(input_ids[0]), len(predictions))

    entities = []
    current_entity = None
    start = None

    for idx, pred in enumerate(predictions[:n_tokens]):
        if is_subword[idx]:
            continue

        prefix, _, entity_type = model_id2label[pred].partition("-")

        # Close the open entity unless this word is an I- continuation of it.
        continues_current = prefix == "I" and entity_type == current_entity
        if current_entity is not None and not continues_current:
            entities.append((start, idx, current_entity, label2id[current_entity]))
            current_entity = None

        # Open after closing so a B- label right after another entity is kept.
        if prefix == "B":
            current_entity = entity_type
            start = idx

    # An entity still open at the end of the sequence runs to the last token.
    if current_entity is not None:
        entities.append((start, n_tokens, current_entity, label2id[current_entity]))

    return entities