|
import argparse |
|
import os |
|
from typing import List, Tuple |
|
|
|
import torch |
|
import numpy as np |
|
from datasets import Dataset |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report |
|
|
|
|
|
def read_conll_2col(path: str) -> Tuple[List[List[str]], List[List[str]]]:
    """Read a 2-column CoNLL file (``TOKEN TAG``) into parallel sentence lists.

    Sentences are separated by blank lines; a line that contains only
    whitespace also counts as a separator. On every other line the first
    whitespace-separated field is taken as the token and the last field as the
    tag, so files with extra middle columns are accepted as well. Lines with
    fewer than two fields are skipped. A leading UTF-8 BOM is ignored.

    Args:
        path: Path to the UTF-8 encoded CoNLL file.

    Returns:
        ``(all_toks, all_labs)``: one token list and one tag list per sentence,
        aligned index by index. Empty sentences are never emitted.
    """
    all_toks: List[List[str]] = []
    all_labs: List[List[str]] = []
    toks: List[str] = []
    labs: List[str] = []

    # "utf-8-sig" behaves like "utf-8" but drops a BOM, which would otherwise
    # be glued onto the very first token.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank or whitespace-only line: close the current sentence.
                if toks:
                    all_toks.append(toks)
                    all_labs.append(labs)
                    toks, labs = [], []
                continue
            if len(parts) < 2:
                # Malformed line (token without a tag): skip it.
                continue
            toks.append(parts[0])
            labs.append(parts[-1])

    # The file need not end with a blank line; flush the last sentence.
    if toks:
        all_toks.append(toks)
        all_labs.append(labs)

    return all_toks, all_labs
|
|
|
|
|
def _word_level_predictions(
    word_ids: List,
    seq_pred_ids: np.ndarray,
    num_words: int,
    id2label: dict,
    default_label: str = "O",
) -> Tuple[List[str], int]:
    """Project sub-token predictions back onto the original words.

    Each word takes the prediction of its first sub-token. A word that has no
    sub-token in the encoding gets ``default_label``. This happens when the
    word was cut off by truncation, or when the tokenizer produced no pieces
    for it. Indexing by word id keeps the output aligned with the gold labels
    in both cases.

    Args:
        word_ids: Per-sub-token word index from a fast tokenizer (``None`` for
            special and padding positions).
        seq_pred_ids: Arg-max label id for each sub-token position.
        num_words: Number of words in the gold sentence.
        id2label: Mapping from label id to label string (model config).
        default_label: Label assigned to words that have no sub-token.

    Returns:
        ``(preds, n_missing)``. ``preds`` has exactly ``num_words`` entries.
        ``n_missing`` counts the words that fell back to ``default_label``.
    """
    preds = [default_label] * num_words
    assigned = [False] * num_words
    for tok_idx, wid in enumerate(word_ids):
        # Skip special/padding positions and every non-first sub-token.
        if wid is None or wid >= num_words or assigned[wid]:
            continue
        preds[wid] = id2label[int(seq_pred_ids[tok_idx])]
        assigned[wid] = True
    return preds, assigned.count(False)


def main():
    """Evaluate a fine-tuned token-classification model on a 2-column CoNLL file.

    The model and tokenizer are loaded from ``--model_dir``. A label is
    predicted for every word of ``--test_path`` using the first-sub-token
    strategy. Entity-level precision, recall and F1 are then printed, followed
    by a seqeval classification report.

    Words that receive no prediction (truncated past ``--max_length``, or
    empty after tokenization) are scored as ``"O"``, so their gold entities
    still count against recall. A warning reports how many there were.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="outputs/bert-base-cased-timeNER",
                        help="Path to the fine-tuned model directory (with config.json, tokenizer files, weights).")
    parser.add_argument("--test_path", type=str, default="data/test.conll",
                        help="Path to 2-column CoNLL test file.")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_length", type=int, default=256,
                        help="Maximum sub-token length per sentence; longer sentences are truncated.")
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if not os.path.exists(args.model_dir):
        parser.error(f"Model dir not found: {args.model_dir}")
    if not os.path.exists(args.test_path):
        parser.error(f"Test file not found: {args.test_path}")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    print(f"Loading model from: {args.model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    if not tokenizer.is_fast:
        # word_ids(), needed to map sub-tokens back to words, is fast-only.
        raise SystemExit(
            f"A fast tokenizer is required, but {args.model_dir} only provides "
            f"a slow one ({type(tokenizer).__name__})."
        )
    model = AutoModelForTokenClassification.from_pretrained(args.model_dir)
    model.eval()

    id2label = model.config.id2label
    label2id = model.config.label2id
    labels_sorted = [id2label[i] for i in range(len(id2label))]
    print(f"Model labels: {labels_sorted}")

    print(f"Reading test set: {args.test_path}")
    tokens_list, tags_list = read_conll_2col(args.test_path)
    num_sents = len(tokens_list)
    num_tokens = sum(len(s) for s in tokens_list)
    print(f"Loaded {num_sents} sentences / {num_tokens} tokens")
    if num_sents == 0:
        print("No sentences found in the test file; nothing to evaluate.")
        return

    uniq_test_labels = sorted({t for seq in tags_list for t in seq})
    missing = [t for t in uniq_test_labels if t not in label2id]
    if missing:
        print(f"⚠️ Warning: test labels not in model: {missing}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    all_preds: List[List[str]] = []
    all_refs: List[List[str]] = []
    n_unpredicted_words = 0
    n_affected_sents = 0

    for start in range(0, num_sents, args.batch_size):
        batch_tokens = tokens_list[start : start + args.batch_size]
        batch_refs = tags_list[start : start + args.batch_size]

        encodings = tokenizer(
            batch_tokens,
            is_split_into_words=True,
            truncation=True,
            max_length=args.max_length,
            return_tensors="pt",
            padding=True,
        )

        model_inputs = {
            "input_ids": encodings["input_ids"].to(device),
            "attention_mask": encodings["attention_mask"].to(device),
        }
        # Only some architectures (e.g. BERT) use segment ids.
        if "token_type_ids" in encodings:
            model_inputs["token_type_ids"] = encodings["token_type_ids"].to(device)

        with torch.no_grad():
            logits = model(**model_inputs).logits
        pred_ids = logits.argmax(dim=-1).cpu().numpy()

        for i, word_labels in enumerate(batch_refs):
            word_preds, n_missing = _word_level_predictions(
                encodings.word_ids(batch_index=i),
                pred_ids[i],
                len(word_labels),
                id2label,
            )
            if n_missing:
                n_unpredicted_words += n_missing
                n_affected_sents += 1
            # Keep the full gold sequence so truncated entities still count.
            all_refs.append(word_labels)
            all_preds.append(word_preds)

    if n_unpredicted_words:
        print(
            f"⚠️ Warning: {n_unpredicted_words} word(s) in {n_affected_sents} "
            f"sentence(s) received no prediction (truncated at "
            f"--max_length={args.max_length} or empty after tokenization); "
            f"they are scored as 'O'."
        )

    p = precision_score(all_refs, all_preds)
    r = recall_score(all_refs, all_preds)
    f1 = f1_score(all_refs, all_preds)

    print("\n Results on test set")
    print(f"Precision: {p:.4f}")
    print(f"Recall   : {r:.4f}")
    print(f"F1       : {f1:.4f}")

    print("\nSeqeval classification report")
    print(classification_report(all_refs, all_preds, digits=4))
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation when executed as a script. main() returns None, so
    # SystemExit(None) exits with status 0, exactly like falling off the end.
    raise SystemExit(main())