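"""Evaluate a fine-tuned token-classification (NER) model on a 2-column CoNLL test set.

Predictions are mapped back to words via the fast tokenizer's word_ids, scoring
only the first subword of each word (standard NER evaluation); entity-level
precision/recall/F1 are reported with seqeval.

Usage (script name illustrative; defaults match the argparse flags below):

    python evaluate_ner.py --model_dir outputs/bert-base-cased-timeNER \
        --test_path data/test.conll
"""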
import argparse
import os
from typing import List, Tuple

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report


def read_conll_2col(path: str) -> Tuple[List[List[str]], List[List[str]]]:
    """Reads 2-column CoNLL (TOKEN TAG) with blank lines between sentences."""
    toks, labs = [], []
    all_toks, all_labs = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                if toks:
                    all_toks.append(toks)
                    all_labs.append(labs)
                    toks, labs = [], []
                continue
            parts = line.split()
            if len(parts) < 2:
                # tolerate malformed lines
                continue
            tok, tag = parts[0], parts[-1]
            toks.append(tok)
            labs.append(tag)
    if toks:
        all_toks.append(toks)
        all_labs.append(labs)
    return all_toks, all_labs
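
# Illustrative input for read_conll_2col (tag set assumed here; the real tag
# inventory comes from the model config and the test file, not this script):
#
#   The       O
#   meeting   B-TIME
#   tomorrow  I-TIME
#
#   See       O
#   you       O
#   Friday    B-TIME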


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="outputs/bert-base-cased-timeNER",
                        help="Path to the fine-tuned model directory (with config.json, tokenizer files, weights).")
    parser.add_argument("--test_path", type=str, default="data/test.conll",
                        help="Path to 2-column CoNLL test file.")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_length", type=int, default=256)
    args = parser.parse_args()

    assert os.path.exists(args.model_dir), f"Model dir not found: {args.model_dir}"
    assert os.path.exists(args.test_path), f"Test file not found: {args.test_path}"

    # Load tokenizer & model
    print(f"Loading model from: {args.model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(args.model_dir)
    model.eval()

    # Label maps from the model config (list ordered by label id)
    id2label = model.config.id2label
    label2id = model.config.label2id
    model_labels = [id2label[i] for i in range(len(id2label))]
    print(f"Model labels: {model_labels}")

    # Read test set
    print(f"Reading test set: {args.test_path}")
    tokens_list, tags_list = read_conll_2col(args.test_path)
    num_sents = len(tokens_list)
    num_tokens = sum(len(s) for s in tokens_list)
    print(f"Loaded {num_sents} sentences / {num_tokens} tokens")

    # Sanity check: test labels should be subset of model labels
    uniq_test_labels = sorted({t for seq in tags_list for t in seq})
    missing = [t for t in uniq_test_labels if t not in label2id]
    if missing:
        print(f"⚠️  Warning: test labels not in model: {missing}")

    # Build a simple dataset for batching convenience
    ds = Dataset.from_dict({"tokens": tokens_list, "ner_tags": tags_list})

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Evaluation loop (word-piece alignment; label only first subword)
    all_preds: List[List[str]] = []
    all_refs: List[List[str]] = []

    # iterate in batches
    for start in range(0, len(ds), args.batch_size):
        batch = ds[start : start + args.batch_size]
        batch_tokens = batch["tokens"]  # List[List[str]]
        batch_refs = batch["ner_tags"]  # List[List[str]]

        # Tokenize as split words so we can map back using fast tokenizer encodings
        encodings = tokenizer(
            batch_tokens,
            is_split_into_words=True,
            truncation=True,
            max_length=args.max_length,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            logits = model(
                input_ids=encodings["input_ids"].to(device),
                attention_mask=encodings["attention_mask"].to(device),
                token_type_ids=encodings["token_type_ids"].to(device) if "token_type_ids" in encodings else None,
            ).logits  # (bsz, seq_len, num_labels)

        pred_ids = logits.argmax(dim=-1).cpu().numpy()  # (bsz, seq_len)

        # Recover word-level predictions using encodings.word_ids()
        for i, word_labels in enumerate(batch_refs):
            encoding = encodings.encodings[i]  # fast tokenizer encoding
            word_ids = encoding.word_ids  # list[Optional[int]] aligned to tokens
            seq_pred_ids = pred_ids[i]

            word_level_preds: List[str] = []
            prev_wid = None  # previous word id, to detect the first subword of each word
            for tok_idx, wid in enumerate(word_ids):
                if wid is None or wid == prev_wid:
                    # special token or subsequent subword → skip (standard NER eval)
                    continue
                # first subword of this word → take its prediction
                word_level_preds.append(id2label[int(seq_pred_ids[tok_idx])])
                prev_wid = wid

            # Trim to same length as references (in case of truncation)
            L = min(len(word_labels), len(word_level_preds))
            all_refs.append(word_labels[:L])
            all_preds.append(word_level_preds[:L])

    # Entity-level metrics: seqeval scores full BIO spans, not individual tokens
    p = precision_score(all_refs, all_preds)
    r = recall_score(all_refs, all_preds)
    f1 = f1_score(all_refs, all_preds)

    print("\n Results on test set")
    print(f"Precision: {p:.4f}")
    print(f"Recall   : {r:.4f}")
    print(f"F1       : {f1:.4f}")

    print("\nSeqeval classification report")
    print(classification_report(all_refs, all_preds, digits=4))


if __name__ == "__main__":
    main()