|
import logging |
|
import os |
|
from typing import List, TextIO, Union |
|
|
|
from conllu import parse_incr |
|
from utils_ner import InputExample, Split, TokenClassificationTask |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class NER(TokenClassificationTask):
    """Named-entity-recognition task over CoNLL-2003-style files.

    Each data line is space-separated columns ("token label ..."); sentences
    are separated by blank lines and documents by "-DOCSTART-" markers.
    """

    def __init__(self, label_idx=-1):
        # Column index (within a space-split line) holding the training label;
        # -1 selects the last column, matching the CoNLL-2003 layout.
        self.label_idx = label_idx

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        """Parse ``{data_dir}/{mode}.txt`` into a list of ``InputExample``s.

        Blank lines and "-DOCSTART-" lines terminate the current sentence.
        Lines with a single column (e.g. unlabeled test data) get label "O".
        """
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    # Strip the trailing newline from the token as well as the
                    # label; previously an unlabeled line left "\n" glued to
                    # the word (only the label column was stripped).
                    words.append(splits[0].replace("\n", ""))
                    if len(splits) > 1:
                        labels.append(splits[self.label_idx].replace("\n", ""))
                    else:
                        # Examples can have no label, e.g. for mode = "test".
                        labels.append("O")
            # Flush the final sentence if the file does not end with a blank line.
            if words:
                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        """Copy the test input to ``writer``, appending one prediction per token.

        ``preds_list`` holds one list of predictions per example; entries are
        consumed (popped) as tokens are written.
        """
        example_id = 0
        for line in test_input_reader:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                writer.write(line)
                # Move to the next example once the current one is exhausted.
                # The bounds check prevents an IndexError on trailing blank
                # lines after the final example has been consumed.
                if example_id < len(preds_list) and not preds_list[example_id]:
                    example_id += 1
            elif preds_list[example_id]:
                output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
                writer.write(output_line)
            else:
                # Token had no prediction (sequence was truncated by the model).
                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    def get_labels(self, path: str) -> List[str]:
        """Return the label set from ``path`` (one label per line), ensuring
        "O" is present; fall back to the CoNLL-2003 defaults when no path is
        given."""
        if path:
            # Use utf-8 explicitly, consistent with read_examples_from_file.
            with open(path, "r", encoding="utf-8") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
|
|
|
|
|
class Chunk(NER):
    """Syntactic-chunking task over CoNLL-2000-style files.

    Shares the NER file format, but chunking files carry "token POS chunk"
    columns, so the label is read from the second-to-last column.
    """

    def __init__(self):
        # In CoNLL-2000 data the chunk tag is the second-to-last column.
        super().__init__(label_idx=-2)

    def get_labels(self, path: str) -> List[str]:
        """Return the label set from ``path`` (one label per line), ensuring
        "O" is present; fall back to the CoNLL-2000 chunk tags when no path
        is given."""
        if path:
            # Use utf-8 explicitly, consistent with read_examples_from_file.
            with open(path, "r", encoding="utf-8") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return [
                "O",
                "B-ADVP",
                "B-INTJ",
                "B-LST",
                "B-PRT",
                "B-NP",
                "B-SBAR",
                "B-VP",
                "B-ADJP",
                "B-CONJP",
                "B-PP",
                "I-ADVP",
                "I-INTJ",
                "I-LST",
                "I-PRT",
                "I-NP",
                "I-SBAR",
                "I-VP",
                "I-ADJP",
                "I-CONJP",
                "I-PP",
            ]
|
|
|
|
|
class POS(TokenClassificationTask):
    """Part-of-speech tagging task over CoNLL-U files, parsed with ``conllu``."""

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        """Parse ``{data_dir}/{mode}.txt`` (CoNLL-U) into ``InputExample``s,
        using each token's surface form as the word and its UPOS tag as the
        label."""
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            for sentence in parse_incr(f):
                # Built in lockstep over the same tokens, so len(words) ==
                # len(labels) by construction; no runtime assert needed
                # (the old `assert` would be stripped under `python -O`).
                words = [token["form"] for token in sentence]
                labels = [token["upos"] for token in sentence]
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                    guid_index += 1
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        """Write one line per sentence of ``form (gold_upos|predicted)`` pairs.

        ``preds_list`` holds one list of predictions per sentence; entries are
        consumed (popped) as tokens are written.
        """
        for example_id, sentence in enumerate(parse_incr(test_input_reader)):
            sentence_preds = preds_list[example_id]
            out = ""
            for token in sentence:
                out += f'{token["form"]} ({token["upos"]}|{sentence_preds.pop(0)}) '
            out += "\n"
            writer.write(out)

    def get_labels(self, path: str) -> List[str]:
        """Return the label set from ``path`` (one label per line), or the
        17 universal POS tags when no path is given."""
        if path:
            # Use utf-8 explicitly, consistent with read_examples_from_file.
            with open(path, "r", encoding="utf-8") as f:
                return f.read().splitlines()
        else:
            return [
                "ADJ",
                "ADP",
                "ADV",
                "AUX",
                "CCONJ",
                "DET",
                "INTJ",
                "NOUN",
                "NUM",
                "PART",
                "PRON",
                "PROPN",
                "PUNCT",
                "SCONJ",
                "SYM",
                "VERB",
                "X",
            ]
|
|