https://github.com/ArneBinder/pie-document-level/pull/452

Changed files:
- src/analysis/combine_job_returns.py +15 -11
- src/analysis/show_score_distribution.py +99 -0
- src/data/calc_iaa_for_brat.py +1 -0
- src/datamodules/datamodule.py +12 -2
- src/demo/annotation_utils.py +1 -0
- src/demo/retrieve_and_dump_all_relevant.py +196 -0
- src/demo/retriever_utils.py +51 -11
- src/document/processing.py +247 -91
- src/document/types.py +46 -0
- src/metrics/__init__.py +1 -0
- src/metrics/score_distribution.py +345 -0
- src/models/__init__.py +1 -0
- src/models/sequence_classification.py +94 -0
- src/pipeline/ner_re_pipeline.py +99 -21
- src/predict.py +8 -34
- src/serializer/interface.py +2 -2
- src/serializer/json.py +13 -7
- src/start_demo.py +35 -14
- src/taskmodules/cross_text_binary_coref_nli.py +31 -4
- src/train.py +71 -10
- src/utils/__init__.py +2 -1
- src/utils/inference_utils.py +74 -0
- src/utils/span_utils.py +14 -0
src/analysis/combine_job_returns.py
CHANGED

@@ -47,6 +47,7 @@ def main(
     transpose: bool = False,
     unpack_multirun_results: bool = False,
     in_percent: bool = False,
+    reset_index: bool = False,
 ):
     file_paths = get_file_paths(
         paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated

@@ -97,9 +98,6 @@ def main(
     data = data.unstack(index_name)
     data = data.T

-    if transpose:
-        data = data.T
-
     # needs to happen before rounding, otherwise the rounding will be off
     if in_percent:
         data = data * 100

@@ -107,20 +105,23 @@ def main(
     if round_precision is not None:
         data = data.round(round_precision)

-
-
-    elif format == "markdown_mean_and_std":
-        if transpose:
-            data = data.T
+    # needs to happen before transposing
+    if format == "markdown_mean_and_std":
         if "mean" not in data.columns or "std" not in data.columns:
             raise ValueError("Columns 'mean' and 'std' are required for this format.")
         # create a single column with mean and std in the format: mean ± std
         data = pd.DataFrame(
             data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
         )
-
-
-
+
+    if transpose:
+        data = data.T
+
+    if reset_index:
+        data = data.reset_index()
+
+    if format in ["markdown", "markdown_mean_and_std"]:
+        print(data.to_markdown(index=not reset_index))
     elif format == "json":
         print(data.to_json())
     else:

@@ -156,6 +157,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--in-percent", action="store_true", help="Show the values in percent (multiply by 100)"
     )
+    parser.add_argument(
+        "--reset-index", action="store_true", help="Reset the index of the combined job returns"
+    )
     parser.add_argument(
         "--format",
         type=str,

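Usage sketch (not part of the diff; only the flag names are taken from the changes above, any other required input arguments stay as the script already expects them): the new flag drops the pandas index from the printed markdown table, e.g.

    python src/analysis/combine_job_returns.py --format markdown --in-percent --reset-index
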
src/analysis/show_score_distribution.py
ADDED

@@ -0,0 +1,99 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=False,
+)
+
+import argparse
+from typing import List, Optional
+
+import pandas as pd
+import plotly.figure_factory as ff
+from pie_datasets import DatasetDict
+
+pd.options.plotting.backend = "plotly"
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="Show score distribution of annotations per layer"
+    )
+    # --data-dir predictions/default/2025-02-26_14-28-17
+    parser.add_argument(
+        "--data-dir", type=str, required=True, help="Path to the dataset directory"
+    )
+    parser.add_argument("--split", type=str, default="test", help="Dataset split to use")
+    parser.add_argument(
+        "--layers",
+        nargs="+",
+        default=["labeled_spans", "binary_relations"],
+        help="Annotation layers to use",
+    )
+    # --layer-captions ADUs "Argumentative Relations"
+    parser.add_argument(
+        "--layer-captions", nargs="+", help="Captions for the figure traces per layer"
+    )
+    # --layer-colors "rgb(31,119,180)" "rgb(255,127,14)"
+    parser.add_argument("--layer-colors", nargs="+", help="Colors for the figure traces per layer")
+
+    args = parser.parse_args()
+
+    # Load the dataset
+    ds = DatasetDict.from_json(data_dir=args.data_dir)[args.split]
+
+    # get scores per annotation layer and label
+    layers = args.layers
+    all_scores = []
+    all_scores_idx = []
+    for doc in ds:
+        for layer in layers:
+            for ann in doc[layer].predictions:
+                all_scores.append(ann.score)
+                all_scores_idx.append((doc.id, layer, getattr(ann, "label", None)))
+    scores = pd.Series(
+        all_scores,
+        index=pd.MultiIndex.from_tuples(all_scores_idx, names=["doc_id", "layer", "label"]),
+        name="score",
+    )
+
+    if args.layer_captions is not None:
+        if len(args.layer_captions) < len(layers):
+            raise ValueError("Not enough captions provided for all layers")
+        name_mapping = dict(zip(layers, args.layer_captions))
+    else:
+        name_mapping = dict(zip(layers, layers))
+
+    colors: Optional[List[str]] = None
+    if args.layer_colors is not None:
+        if len(args.layer_colors) < len(layers):
+            raise ValueError("Not enough colors provided for all layers")
+        color_mapping = dict(zip(layers, args.layer_colors))
+        colors = [color_mapping[layer] for layer in layers]
+    else:
+        colors = None
+
+    score_groups = {layer: scores.xs(layer, level="layer").to_numpy() for layer in layers}
+    group_labels, hist_data = zip(*score_groups.items())
+    group_labels_renamed = [name_mapping[label] for label in group_labels]
+    fig = ff.create_distplot(
+        hist_data,
+        group_labels=group_labels_renamed,
+        show_hist=True,
+        colors=colors,
+        bin_size=0.025,
+    )
+
+    fig.update_layout(
+        height=600,
+        width=800,
+        title_text="Score Distribution per Annotation Layer",
+        title_x=0.5,
+        barmode="group",
+    )
+    fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01))
+
+    fig.show()
+    print("done")

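The argument comments inside the new script already give example values; put together, an invocation could look like the following (the predictions directory is the example path from the comment, not a guaranteed location):

    python src/analysis/show_score_distribution.py \
        --data-dir predictions/default/2025-02-26_14-28-17 \
        --layers labeled_spans binary_relations \
        --layer-captions ADUs "Argumentative Relations" \
        --layer-colors "rgb(31,119,180)" "rgb(255,127,14)"
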
src/data/calc_iaa_for_brat.py
CHANGED

@@ -92,6 +92,7 @@ def calc_brat_iaas(
            create_multi_spans=True,
            result_document_type=BratDocument,
            result_field_mapping={"spans": "spans", "relations": "relations"},
+            combine_scores_method="product",
        )
    else:
        merger = None

src/datamodules/datamodule.py
CHANGED

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

 from pytorch_ie.core import Document

@@ -21,6 +22,8 @@ DatasetType: TypeAlias = Union[
     IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
 ]

+logger = logging.getLogger(__name__)
+

 class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
     """A simple LightningDataModule for PIE document datasets.

@@ -49,6 +52,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
        test_split: Optional[str] = "test",
        show_progress_for_encode: bool = False,
        train_sampler: Optional[str] = None,
+        dont_shuffle_train: bool = False,
        **dataloader_kwargs,
    ):
        super().__init__()

@@ -62,6 +66,7 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
        self.show_progress_for_encode = show_progress_for_encode
        self.train_sampler_name = train_sampler
        self.dataloader_kwargs = dataloader_kwargs
+        self.dont_shuffle_train = dont_shuffle_train

        self._data: Dict[str, DatasetType] = {}

@@ -128,12 +133,17 @@ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, Ta
            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
        else:
            sampler = None
+        # don't shuffle streamed datasets or if we use a sampler or if we explicitly set dont_shuffle_train
+        shuffle = not self.dont_shuffle_train and not (
+            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
+        )
+        if not shuffle:
+            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
-
-            shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+            shuffle=shuffle,
            **self.dataloader_kwargs,
        )

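A minimal sketch of how the new switch would be passed through; only dont_shuffle_train is defined in this PR, the remaining constructor arguments are assumptions about the existing setup:

    # assumed arguments, as the datamodule was already constructed before this change
    datamodule = PieDataModule(
        taskmodule=taskmodule,      # hypothetical: existing taskmodule instance
        dataset=dataset,            # hypothetical: existing PIE dataset dict
        dont_shuffle_train=True,    # new: forces shuffle=False and logs "not shuffling train dataloader"
        batch_size=32,              # forwarded via **dataloader_kwargs
    )
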
src/demo/annotation_utils.py
CHANGED

@@ -37,6 +37,7 @@ def get_merger() -> SpansViaRelationMerger:
            "binary_relations": "binary_relations",
            "labeled_partitions": "labeled_partitions",
        },
+        combine_scores_method="product",
    )

src/demo/retrieve_and_dump_all_relevant.py
CHANGED

@@ -10,9 +10,17 @@ root = pyrootutils.setup_root(
 import argparse
 import logging
 import os
+from typing import Dict, List, Optional, Tuple

 import pandas as pd
+from pie_datasets import Dataset, DatasetDict
+from pytorch_ie import Annotation
+from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span

+from document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
 from src.demo.retriever_utils import (
     retrieve_all_relevant_spans,
     retrieve_all_relevant_spans_for_all_documents,

@@ -23,6 +31,168 @@ from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
 logger = logging.getLogger(__name__)


+def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
+    original_doc_id, middle, start_end, ext = doc_id.split(".")
+    if middle == "remaining":
+        return original_doc_id, int(start_end), None
+    elif middle == "abstract":
+        start, end = start_end.split("_")
+        return original_doc_id, int(start), int(end)
+    else:
+        raise ValueError(f"unexpected doc_id format: {doc_id}")
+
+
+def add_base_annotations(
+    documents: Dict[
+        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
+    ],
+    retrieved_doc_ids: List[str],
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
+    # (retrieved_doc_id, retrieved_annotation) -> (original_doc_id, original_annotation)
+    annotation_mapping = {}
+    for retrieved_doc_id in retrieved_doc_ids:
+        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
+        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
+        document = documents[original_doc_id]
+        span_mapping = {}
+        for span in pie_doc.labeled_multi_spans.predictions:
+            if isinstance(span, MultiSpan):
+                new_span = span.copy(
+                    slices=[(start + offset, end + offset) for start, end in span.slices]
+                )
+            elif isinstance(span, Span):
+                new_span = span.copy(start=span.start + offset, end=span.end + offset)
+            else:
+                raise ValueError(f"unexpected span type: {span}")
+            span_mapping[span] = new_span
+        document.labeled_multi_spans.predictions.extend(span_mapping.values())
+        for relation in pie_doc.binary_relations.predictions:
+            new_relation = relation.copy(
+                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
+            )
+            document.binary_relations.predictions.append(new_relation)
+        for old_ann, new_ann in span_mapping.items():
+            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)
+
+    return annotation_mapping
+
+
+def get_doc_and_span_id2annotation_mapping(
+    span_ids: pd.Series,
+    doc_ids: pd.Series,
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
+) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
+    if len(doc_ids) != len(span_ids):
+        raise ValueError("doc_ids and span_ids must have the same length")
+    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
+    return {
+        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
+        for doc_id, span_id in set(doc_and_span_ids)
+    }
+
+
+def add_result_to_gold_data(
+    result: pd.DataFrame,
+    gold_dataset_dir: str,
+    dataset_out_dir: str,
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+    split: Optional[str] = None,
+    link_relation_label: str = "semantically_same",
+    reversed_relation_suffix: str = "_reversed",
+):
+
+    if not os.path.exists(gold_dataset_dir):
+        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")
+
+    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
+    if split is None and len(dataset_dict) == 1:
+        split = list(dataset_dict.keys())[0]
+    if split is None:
+        raise ValueError("need to provide split name to add results to gold dataset")
+
+    dataset = dataset_dict[split]
+
+    doc_id2doc = {doc.id: doc for doc in dataset}
+    retriever_doc_ids = (
+        result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist()
+    )
+    base_annotation_mapping = add_base_annotations(
+        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
+    )
+    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
+    doc_and_span_id2annotation = {}
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["span_id"],
+            doc_ids=result["doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["ref_span_id"],
+            doc_ids=result["doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_and_span_id2annotation.update(
+        get_doc_and_span_id2annotation_mapping(
+            span_ids=result["query_span_id"],
+            doc_ids=result["query_doc_id"],
+            retriever=retriever,
+            base_annotation_mapping=base_annotation_mapping,
+        )
+    )
+    doc_id2head_tail2relation = {}
+    for doc_id, doc in doc_id2doc.items():
+        head_and_tail2relation = {}
+        for relation in doc.binary_relations.predictions:
+            head_and_tail2relation[(relation.head, relation.tail)] = relation
+        doc_id2head_tail2relation[doc_id] = head_and_tail2relation
+
+    for row in result.itertuples():
+        query_doc_id, query_span = doc_and_span_id2annotation[
+            (row.query_doc_id, row.query_span_id)
+        ]
+        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
+        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
+        if doc_id != query_doc_id:
+            raise ValueError("doc_id and query_doc_id must be the same")
+        if doc_id != doc_id2:
+            raise ValueError("doc_id and ref_doc_id must be the same")
+        doc = doc_id2doc[doc_id]
+        link_rel = BinaryRelation(
+            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
+        )
+        doc.binary_relations.predictions.append(link_rel)
+        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
+        related_rel_label = row.type
+        if related_rel_label.endswith(reversed_relation_suffix):
+            base_rel = head_and_tail2relation[(span, ref_span)]
+        else:
+            base_rel = head_and_tail2relation[(ref_span, span)]
+        related_rel = RelatedRelation(
+            head=query_span,
+            tail=span,
+            link_relation=link_rel,
+            relation=base_rel,
+            label=related_rel_label,
+            score=link_rel.score * base_rel.score,
+        )
+        doc.related_relations.predictions.append(related_rel)
+
+    dataset = Dataset.from_documents(list(doc_id2doc.values()))
+    dataset_dict = DatasetDict({split: dataset})
+    if not os.path.exists(dataset_out_dir):
+        os.makedirs(dataset_out_dir, exist_ok=True)
+
+    dataset_dict.to_json(dataset_out_dir)
+
+
 if __name__ == "__main__":

     parser = argparse.ArgumentParser()

@@ -81,6 +251,19 @@ if __name__ == "__main__":
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
+    parser.add_argument(
+        "--gold_dataset_dir",
+        type=str,
+        default=None,
+        help="If provided, add the spans and base relations from the retriever data as well "
+        "as the related relations to the gold dataset.",
+    )
+    parser.add_argument(
+        "--dataset_out_dir",
+        type=str,
+        default=None,
+        help="If provided, save the enriched gold dataset to this directory.",
+    )
    args = parser.parse_args()

    logging.basicConfig(

@@ -157,4 +340,17 @@ if __name__ == "__main__":
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        all_spans_for_all_documents.to_json(args.output_path)

+    if args.gold_dataset_dir is not None:
+        logger.info(
+            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
+        )
+        if args.dataset_out_dir is None:
+            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
+        add_result_to_gold_data(
+            all_spans_for_all_documents,
+            gold_dataset_dir=args.gold_dataset_dir,
+            dataset_out_dir=args.dataset_out_dir,
+            retriever=retriever,
+        )
+
    logger.info("done")

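A small illustration of the document-id format that get_original_doc_id_and_offsets expects; the ids and the ".json" suffix are placeholders for the fourth, unused component:

    # doc_id is split into <original_id>.<part>.<offsets>.<ext>
    get_original_doc_id_and_offsets("1234.abstract.0_250.json")   # -> ("1234", 0, 250)
    get_original_doc_id_and_offsets("1234.remaining.250.json")    # -> ("1234", 250, None)
    # any other <part> value raises a ValueError
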
src/demo/retriever_utils.py
CHANGED

@@ -51,6 +51,7 @@ def load_retriever(
 def retrieve_similar_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():

@@ -60,21 +61,42 @@ def retrieve_similar_spans(
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
            span_ann = metadata["attached_span"]
+            sim_score = metadata["relevance_score"]
+            span_score = span_ann.score
+            score = query_span_score * sim_score * span_score
            records.append(
                {
+                    "score": score,
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
-                    "
+                    "sim_score": sim_score,
+                    "query_span_score": query_span_score,
+                    "span_score": span_score,
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
-
-        pd.DataFrame(
+        result = (
+            pd.DataFrame(
+                records,
+                columns=[
+                    "score",
+                    "text",
+                    "label",
+                    "sim_score",
+                    "span_score",
+                    "query_span_score",
+                    "doc_id",
+                    "span_id",
+                ],
+            )
            .sort_values(by="score", ascending=False)
            .round(3)
        )
+        return result[result["score"] >= min_score]
    except Exception as e:
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}")

@@ -83,6 +105,7 @@ def retrieve_relevant_spans(
     retriever: DocumentAwareSpanRetriever,
     query_span_id: str,
     relation_label_mapping: Optional[dict[str, str]] = None,
+    min_score: float = 0.0,
     **kwargs,
 ) -> pd.DataFrame:
     if not query_span_id.strip():

@@ -98,40 +121,57 @@ def retrieve_relevant_spans(
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
+
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_score = query_span.score
+            sim_score = metadata["relevance_score"]
+            ref_span_score = span_ann.score
+            rel_score = metadata["relation_score"]
+            span_score = tail_span_ann.score
+            score = query_span_score * sim_score * ref_span_score * rel_score * span_score
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
-                    "
+                    "score": score,
+                    "rel_score": rel_score,
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
+                    "span_score": span_score,
                    "label": tail_span_ann.label,
-                    "
+                    "sim_score": sim_score,
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
+                    "ref_span_score": ref_span_score,
+                    "query_span_score": query_span_score,
                }
            )
-
+        result = (
            pd.DataFrame(
                records,
                columns=[
+                    "score",
                    "type",
-                    # omitted for now, we get no valid relation scores for the generative model
-                    # "rel_score",
-                    "ref_score",
-                    "label",
                    "text",
+                    "label",
+                    "rel_score",
+                    "sim_score",
                    "ref_label",
+                    "ref_span_score",
                    "ref_text",
                    "doc_id",
                    "span_id",
+                    "span_score",
                    "ref_span_id",
+                    "query_span_score",
                ],
            )
-            .sort_values(by=["
+            .sort_values(by=["score"], ascending=False)
            .round(3)
        )
+        return result[result["score"] >= min_score]
+
    except Exception as e:
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")

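Reading aid for the new ranking (the numbers are only an illustration): the retrieved rows are now sorted by a combined score, which is the product of the individual scores shown in the new columns. For similar spans, with query_span_score=0.9, sim_score=0.8 and span_score=0.95 the combined score is 0.9 * 0.8 * 0.95 = 0.684, so a min_score of 0.7 would filter that row out; for relevant spans the relation score and the reference span score enter the product as well.
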
src/document/processing.py
CHANGED

@@ -1,16 +1,20 @@
 from __future__ import annotations

 import logging
-from
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

-from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
-from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
 from pie_modules.utils.span import have_overlap
 from pytorch_ie import AnnotationLayer
+from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
 from pytorch_ie.core import Document
 from pytorch_ie.core.document import Annotation, _enumerate_dependencies

-from src.
+from src.document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+)
+from src.utils import distance, distance_slices
 from src.utils.span_utils import get_overlap_len

 logger = logging.getLogger(__name__)

@@ -68,58 +72,6 @@ def remove_overlapping_entities(
     return new_doc


-# TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
-def merge_spans_via_relation(
-    document: D,
-    relation_layer: str,
-    link_relation_label: str,
-    use_predicted_spans: bool = False,
-    process_predictions: bool = True,
-    create_multi_spans: bool = False,
-) -> D:
-
-    rel_layer = document[relation_layer]
-    span_layer = rel_layer.target_layer
-    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
-        spans=span_layer,
-        relations=rel_layer,
-        link_relation_label=link_relation_label,
-        create_multi_spans=create_multi_spans,
-    )
-    if process_predictions:
-        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
-            spans=span_layer.predictions if use_predicted_spans else span_layer,
-            relations=rel_layer.predictions,
-            link_relation_label=link_relation_label,
-            create_multi_spans=create_multi_spans,
-        )
-    else:
-        assert not use_predicted_spans
-        new_pred_spans = set(span_layer.predictions.clear())
-        new_pred_relations = set(rel_layer.predictions.clear())
-
-    relation_layer_name = relation_layer
-    span_layer_name = document[relation_layer].target_name
-    if create_multi_spans:
-        doc_dict = document.asdict()
-        for f in document.annotation_fields():
-            doc_dict.pop(f.name)
-
-        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
-        result.labeled_multi_spans.extend(new_gold_spans)
-        result.labeled_multi_spans.predictions.extend(new_pred_spans)
-        result.binary_relations.extend(new_gold_relations)
-        result.binary_relations.predictions.extend(new_pred_relations)
-    else:
-        result = document.copy(with_annotations=False)
-        result[span_layer_name].extend(new_gold_spans)
-        result[span_layer_name].predictions.extend(new_pred_spans)
-        result[relation_layer_name].extend(new_gold_relations)
-        result[relation_layer_name].predictions.extend(new_pred_relations)
-
-    return result
-
-
 def remove_partitions_by_labels(
     document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
 ) -> D:

@@ -249,31 +201,19 @@ def relabel_annotations(
 DWithSpans = TypeVar("DWithSpans", bound=Document)


-def
-
-
-
-
-
-
-
-    predicted span and the gold span have an overlap of at least half of the maximum length
-    of the two spans, the predicted span is aligned with the gold span.
-
-    Args:
-        document: The document to process.
-        span_layer: The name of the span layer.
-        distance_type: The type of distance to calculate. One of: center, inner, outer
-        verbose: Whether to print debug information.
+def get_start_end(span: Union[Span, MultiSpan]) -> Tuple[int, int]:
+    if isinstance(span, Span):
+        return span.start, span.end
+    elif isinstance(span, MultiSpan):
+        starts, ends = zip(*span.slices)
+        return min(starts), max(ends)
+    else:
+        raise ValueError(f"Unsupported span type: {type(span)}")

-    Returns:
-        The processed document.
-    """
-    gold_spans = document[span_layer]
-    if len(gold_spans) == 0:
-        return document.copy()

-
+def _get_aligned_span_mappings(
+    gold_spans: Iterable[Span], pred_spans: Iterable[Span], distance_type: str
+) -> Tuple[Dict[int, Span], Dict[int, Span]]:
     old2new_pred_span = {}
     span_id2gold_span = {}
     for pred_span in pred_spans:

@@ -282,29 +222,32 @@ def align_predicted_span_annotations(
            (
                gold_span,
                distance(
-                    start_end=(pred_span
-                    other_start_end=(gold_span
+                    start_end=get_start_end(pred_span),
+                    other_start_end=get_start_end(gold_span),
                    distance_type=distance_type,
                ),
            )
            for gold_span in gold_spans
        ]
+        if len(gold_spans_with_distance) == 0:
+            continue

        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
        # if the closest gold span is the same as the predicted span, we don't need to align
        if min_distance == 0.0:
            continue

+        pred_start_end = get_start_end(pred_span)
+        closest_gold_start_end = get_start_end(closest_gold_span)
+
        if have_overlap(
-            start_end=
-            other_start_end=
+            start_end=pred_start_end,
+            other_start_end=closest_gold_start_end,
        ):
-            overlap_len = get_overlap_len(
-                (pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
-            )
-            # get the maximum length of the two spans
+            overlap_len = get_overlap_len(pred_start_end, closest_gold_start_end)
            l_max = max(
-
+                pred_start_end[1] - pred_start_end[0],
+                closest_gold_start_end[1] - closest_gold_start_end[0],
            )
            # if the overlap is at least half of the maximum length, we consider it a valid match for alignment
            valid_match = overlap_len >= (l_max / 2)

@@ -312,12 +255,140 @@
            valid_match = False

        if valid_match:
-
-
-
+            if isinstance(pred_span, Span):
+                aligned_pred_span = pred_span.copy(
+                    start=closest_gold_span.start, end=closest_gold_span.end
+                )
+            elif isinstance(pred_span, MultiSpan):
+                aligned_pred_span = pred_span.copy(slices=closest_gold_span.slices)
+            else:
+                raise ValueError(f"Unsupported span type: {type(pred_span)}")
            old2new_pred_span[pred_span._id] = aligned_pred_span
            span_id2gold_span[pred_span._id] = closest_gold_span

+    return old2new_pred_span, span_id2gold_span
+
+
+def get_spans2multi_spans_mapping(multi_spans: Iterable[MultiSpan]) -> Dict[Span, MultiSpan]:
+    result = {}
+    for multi_span in multi_spans:
+        for start, end in multi_span.slices:
+            span_kwargs = dict(start=start, end=end, score=multi_span.score)
+            if isinstance(multi_span, LabeledMultiSpan):
+                result[LabeledSpan(label=multi_span.label, **span_kwargs)] = multi_span
+            else:
+                result[Span(**span_kwargs)] = multi_span
+
+    return result
+
+
+def align_predicted_span_annotations(
+    document: DWithSpans,
+    span_layer: str,
+    distance_type: str = "center",
+    simple_multi_span: bool = False,
+    verbose: bool = False,
+) -> DWithSpans:
+    """
+    Aligns predicted span annotations with the closest gold spans in a document.
+
+    First, calculates the distance between each predicted span and each gold span. Then,
+    for each predicted span, the gold span with the smallest distance is selected. If the
+    predicted span and the gold span have an overlap of at least half of the maximum length
+    of the two spans, the predicted span is aligned with the gold span.
+
+    This also works for MultiSpan annotations, where the slices of the MultiSpan are used
+    to align the predicted spans. If any of the slices is aligned with a gold slice,
+    the MultiSpan is aligned with the respective gold MultiSpan. However, this may result in
+    the predicted MultiSpan being aligned with multiple gold MultiSpans, in which case the
+    closest gold MultiSpan is selected. A simplified version of this alignment can be achieved
+    by setting `simple_multi_span=True`, which treats MultiSpan annotations as simple Spans
+    by using their maximum and minimum start and end indices.
+
+    Args:
+        document: The document to process.
+        span_layer: The name of the span layer.
+        distance_type: The type of distance to calculate. One of: center, inner, outer
+        simple_multi_span: Whether to treat MultiSpan annotations as simple Spans by using their
+            maximum and minimum start and end indices.
+        verbose: Whether to print debug information.
+
+    Returns:
+        The processed document.
+    """
+    gold_spans = document[span_layer]
+    if len(gold_spans) == 0:
+        return document.copy()
+
+    pred_spans = document[span_layer].predictions
+    span_annotation_type = document.annotation_types()[span_layer]
+    if issubclass(span_annotation_type, Span) or simple_multi_span:
+        old2new_pred_span, span_id2gold_span = _get_aligned_span_mappings(
+            gold_spans=gold_spans, pred_spans=pred_spans, distance_type=distance_type
+        )
+    elif issubclass(span_annotation_type, MultiSpan):
+        # create Span objects from MultiSpan slices
+        gold_single_spans2multi_spans = get_spans2multi_spans_mapping(gold_spans)
+        pred_single_spans2multi_spans = get_spans2multi_spans_mapping(pred_spans)
+        # create the alignment mappings for the single spans
+        single_old2new_pred_span, single_span_id2gold_span = _get_aligned_span_mappings(
+            gold_spans=gold_single_spans2multi_spans.keys(),
+            pred_spans=pred_single_spans2multi_spans.keys(),
+            distance_type=distance_type,
+        )
+        # collect all Spans that are part of the same MultiSpan
+        pred_multi_span2single_spans: Dict[MultiSpan, List[Span]] = defaultdict(list)
+        for pred_span, multi_span in pred_single_spans2multi_spans.items():
+            pred_multi_span2single_spans[multi_span].append(pred_span)
+
+        # create the new mappings for the MultiSpans
+        old2new_pred_span = {}
+        span_id2gold_span = {}
+        for pred_multi_span, pred_single_spans in pred_multi_span2single_spans.items():
+            # if any of the single spans is aligned with a gold span, align the multi span
+            if any(
+                pred_single_span._id in single_old2new_pred_span
+                for pred_single_span in pred_single_spans
+            ):
+                # get aligned gold multi spans
+                aligned_gold_multi_spans = set()
+                for pred_single_span in pred_single_spans:
+                    if pred_single_span._id in single_old2new_pred_span:
+                        aligned_gold_single_span = single_span_id2gold_span[pred_single_span._id]
+                        aligned_gold_multi_span = gold_single_spans2multi_spans[
+                            aligned_gold_single_span
+                        ]
+                        aligned_gold_multi_spans.add(aligned_gold_multi_span)
+
+                # calculate distances between the predicted multi span and the aligned gold multi spans
+                gold_multi_spans_with_distance = [
+                    (
+                        gold_multi_span,
+                        distance_slices(
+                            slices=pred_multi_span.slices,
+                            other_slices=gold_multi_span.slices,
+                            distance_type=distance_type,
+                        ),
+                    )
+                    for gold_multi_span in aligned_gold_multi_spans
+                ]
+
+                if len(aligned_gold_multi_spans) > 1:
+                    logger.warning(
+                        f"Multiple gold multi spans aligned with predicted multi span ({pred_multi_span}): "
+                        f"{aligned_gold_multi_spans}"
+                    )
+                # get the closest gold multi span
+                closest_gold_multi_span, min_distance = min(
+                    gold_multi_spans_with_distance, key=lambda x: x[1]
+                )
+                old2new_pred_span[pred_multi_span._id] = pred_multi_span.copy(
+                    slices=closest_gold_multi_span.slices
+                )
+                span_id2gold_span[pred_multi_span._id] = closest_gold_multi_span
+    else:
+        raise ValueError(f"Unsupported span annotation type: {span_annotation_type}")
+
     result = document.copy(with_annotations=False)

     # multiple predicted spans can be aligned with the same gold span,

@@ -356,3 +427,88 @@
     )

     return result
+
+
+def add_related_relations_from_binary_relations(
+    document: TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
+    link_relation_label: str,
+    link_partition_whitelist: Optional[List[List[str]]] = None,
+    relation_label_whitelist: Optional[List[str]] = None,
+    reversed_relation_suffix: str = "_reversed",
+    symmetric_relations: Optional[List[str]] = None,
+) -> TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations:
+    span2partition = {}
+    for multi_span in document.labeled_multi_spans:
+        found_partition = False
+        for partition in document.labeled_partitions or [
+            LabeledSpan(start=0, end=len(document.text), label="ALL")
+        ]:
+            starts, ends = zip(*multi_span.slices)
+            if partition.start <= min(starts) and max(ends) <= partition.end:
+                span2partition[multi_span] = partition
+                found_partition = True
+                break
+        if not found_partition:
+            raise ValueError(f"No partition found for multi_span {multi_span}")
+
+    rel_head2rels = defaultdict(list)
+    rel_tail2rels = defaultdict(list)
+    for rel in document.binary_relations:
+        rel_head2rels[rel.head].append(rel)
+        rel_tail2rels[rel.tail].append(rel)
+
+    link_partition_whitelist_tuples = None
+    if link_partition_whitelist is not None:
+        link_partition_whitelist_tuples = {tuple(pair) for pair in link_partition_whitelist}
+
+    skipped_labels = []
+    for link_rel in document.binary_relations:
+        if link_rel.label == link_relation_label:
+            head_partition = span2partition[link_rel.head]
+            tail_partition = span2partition[link_rel.tail]
+            if link_partition_whitelist_tuples is None or (
+                (head_partition.label, tail_partition.label) in link_partition_whitelist_tuples
+            ):
+                # link_head -> link_tail == rel_head -> rel_tail
+                for rel in rel_head2rels.get(link_rel.tail, []):
+                    label = rel.label
+                    if relation_label_whitelist is None or label in relation_label_whitelist:
+                        new_rel = RelatedRelation(
+                            head=link_rel.head,
+                            tail=rel.tail,
+                            link_relation=link_rel,
+                            relation=rel,
+                            label=label,
+                        )
+                        document.related_relations.append(new_rel)
+                    else:
+                        skipped_labels.append(label)
+
+                # link_head -> link_tail == rel_tail -> rel_head
+                if reversed_relation_suffix is not None:
+                    for reversed_rel in rel_tail2rels.get(link_rel.tail, []):
+                        label = reversed_rel.label
+                        if not (symmetric_relations is not None and label in symmetric_relations):
+                            label = f"{label}{reversed_relation_suffix}"
+                        if relation_label_whitelist is None or label in relation_label_whitelist:
+                            new_rel = RelatedRelation(
+                                head=link_rel.head,
+                                tail=reversed_rel.head,
+                                link_relation=link_rel,
+                                relation=reversed_rel,
+                                label=label,
+                            )
+                            document.related_relations.append(new_rel)
+                        else:
+                            skipped_labels.append(label)
+
+            else:
+                logger.warning(
+                    f"Skipping related relation because of partition whitelist ({[head_partition.label, tail_partition.label]}): {link_rel.resolve()}"
+                )
+    if len(skipped_labels) > 0:
+        logger.warning(
+            f"Skipped relations with labels not in whitelist: {sorted(set(skipped_labels))}"
+        )
+
+    return document

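To make the alignment rule above concrete (the offsets are only an illustration): a predicted span (0, 10) whose closest gold span is (3, 11) overlaps it by 7 characters; the longer of the two spans has length 10, and since 7 >= 10 / 2, the prediction is snapped to (3, 11). If the closest gold span were (8, 30) instead, the overlap would only be 2 against a maximum length of 22, so the prediction would be left unchanged.
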
src/document/types.py
ADDED

@@ -0,0 +1,46 @@
+import dataclasses
+
+from pytorch_ie import AnnotationLayer, annotation_field
+from pytorch_ie.annotations import BinaryRelation
+from pytorch_ie.documents import (
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+)
+
+
+@dataclasses.dataclass(eq=True, frozen=True)
+class RelatedRelation(BinaryRelation):
+    link_relation: BinaryRelation = dataclasses.field(default=None, compare=False)
+    relation: BinaryRelation = dataclasses.field(default=None, compare=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+        # check if the reference_span is correct
+        self.reference_span
+
+    @property
+    def reference_span(self):
+        if self.link_relation is None:
+            raise ValueError(
+                "No semantically_same_relation available, cannot return reference_span"
+            )
+        if self.link_relation.head == self.head:
+            return self.link_relation.tail
+        elif self.link_relation.tail == self.head:
+            return self.link_relation.head
+        elif self.link_relation.head == self.tail:
+            return self.link_relation.tail
+        elif self.link_relation.tail == self.tail:
+            return self.link_relation.head
+        else:
+            raise ValueError(
+                "The semantically_same_relation is neither linked to head nor tail of the current relation"
+            )
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations(
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+):
+    related_relations: AnnotationLayer[RelatedRelation] = annotation_field(
+        targets=["labeled_multi_spans", "binary_relations"]
+    )

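Reading aid for RelatedRelation.reference_span (the span names A, B and C are placeholders, not identifiers from the code): if the link relation connects spans A and B and the related relation goes from head A to tail C, reference_span resolves to B, that is, the span on the other end of the link relation from which the related relation was derived; if neither the head nor the tail touches the link relation, the property raises a ValueError.
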
src/metrics/__init__.py
CHANGED

@@ -1,2 +1,3 @@
 from .coref_sklearn import CorefMetricsSKLearn
 from .coref_torchmetrics import CorefMetricsTorchmetrics
+from .score_distribution import ScoreDistribution

src/metrics/score_distribution.py
ADDED

@@ -0,0 +1,345 @@
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple
+
+import pandas as pd
+from pytorch_ie import Document, DocumentMetric
+
+
+class ScoreDistribution(DocumentMetric):
+    """Computes the distribution of prediction scores for annotations in a layer. The scores are
+    separated into true positives (TP) and false positives (FP) based on the gold annotations.
+
+    Args:
+        layer: The name of the annotation layer to analyze.
+        per_label: If True, the scores are separated per label. Default is False.
+        label_field: The field name of the label to use for separating the scores per label. Default is "label".
+        equal_sample_size_binning: If True, the scores are binned into equal sample sizes. If False,
+            the scores are binned into equal width. The former is useful when the distribution of scores is skewed.
+            Default is True.
+        show_plot: If True, a plot of the score distribution is shown. Default is False.
+        plotting_backend: The plotting backend to use. Default is "plotly".
+        plotting_caption_mapping: A mapping to rename any caption entries for plotting, i.e., the layer name,
+            labels, or TP/FP. Default is None.
+        plotting_colors: A dictionary mapping from gold scores to colors for plotting. Default is None.
+    """
+
+    def __init__(
+        self,
+        layer: str,
|
| 29 |
+
label_field: str = "label",
|
| 30 |
+
per_label: bool = False,
|
| 31 |
+
show_plot: bool = False,
|
| 32 |
+
equal_sample_size_binning: bool = True,
|
| 33 |
+
plotting_backend: str = "plotly",
|
| 34 |
+
plotting_caption_mapping: Optional[Dict[str, str]] = None,
|
| 35 |
+
plotting_colors: Optional[Dict[str, str]] = None,
|
| 36 |
+
plotly_use_create_distplot: bool = True,
|
| 37 |
+
plotly_barmode: Optional[str] = None,
|
| 38 |
+
plotly_marginal: Optional[str] = "violin",
|
| 39 |
+
plotly_font_size: int = 18,
|
| 40 |
+
plotly_font_family: Optional[str] = None,
|
| 41 |
+
plotly_background_color: Optional[str] = None,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.layer = layer
|
| 45 |
+
self.label_field = label_field
|
| 46 |
+
self.per_label = per_label
|
| 47 |
+
self.equal_sample_size_binning = equal_sample_size_binning
|
| 48 |
+
self.plotting_backend = plotting_backend
|
| 49 |
+
self.show_plot = show_plot
|
| 50 |
+
self.plotting_caption_mapping = plotting_caption_mapping or {}
|
| 51 |
+
self.plotting_colors = plotting_colors
|
| 52 |
+
self.plotly_use_create_distplot = plotly_use_create_distplot
|
| 53 |
+
self.plotly_barmode = plotly_barmode
|
| 54 |
+
self.plotly_marginal = plotly_marginal
|
| 55 |
+
self.plotly_font_size = plotly_font_size
|
| 56 |
+
self.plotly_font_family = plotly_font_family
|
| 57 |
+
self.plotly_background_color = plotly_background_color
|
| 58 |
+
self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
|
| 59 |
+
|
| 60 |
+
def reset(self):
|
| 61 |
+
self.scores = defaultdict(lambda: defaultdict(list))
|
| 62 |
+
|
| 63 |
+
def _update(self, document: Document):
|
| 64 |
+
|
| 65 |
+
gold_annotations = set(document[self.layer])
|
| 66 |
+
for ann in document[self.layer].predictions:
|
| 67 |
+
if self.per_label:
|
| 68 |
+
label = getattr(ann, self.label_field)
|
| 69 |
+
else:
|
| 70 |
+
label = "ALL"
|
| 71 |
+
if ann in gold_annotations:
|
| 72 |
+
self.scores[label]["TP"].append(ann.score)
|
| 73 |
+
else:
|
| 74 |
+
self.scores[label]["FP"].append(ann.score)
|
| 75 |
+
|
| 76 |
+
def _combine_scores(
|
| 77 |
+
self,
|
| 78 |
+
scores_tp: List[float],
|
| 79 |
+
score_fp: List[float],
|
| 80 |
+
col_name_pred: str = "prediction",
|
| 81 |
+
col_name_gold: str = "gold",
|
| 82 |
+
) -> pd.DataFrame:
|
| 83 |
+
scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
|
| 84 |
+
scores_tp_df[col_name_gold] = 1.0
|
| 85 |
+
scores_fp_df = pd.DataFrame(score_fp, columns=[col_name_pred])
|
| 86 |
+
scores_fp_df[col_name_gold] = 0.0
|
| 87 |
+
scores_df = pd.concat([scores_tp_df, scores_fp_df])
|
| 88 |
+
return scores_df
|
| 89 |
+
|
| 90 |
+
def _get_calibration_data_and_metrics(
|
| 91 |
+
self, scores: pd.DataFrame, q: int = 20
|
| 92 |
+
) -> Tuple[pd.DataFrame, pd.Series]:
|
| 93 |
+
from sklearn.metrics import brier_score_loss
|
| 94 |
+
|
| 95 |
+
if self.equal_sample_size_binning:
|
| 96 |
+
# Create bins with equal number of samples.
|
| 97 |
+
scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
|
| 98 |
+
else:
|
| 99 |
+
# Create bins with equal width.
|
| 100 |
+
scores["bin"] = pd.cut(
|
| 101 |
+
scores["prediction"],
|
| 102 |
+
bins=q,
|
| 103 |
+
include_lowest=True,
|
| 104 |
+
right=True,
|
| 105 |
+
labels=False,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
calibration_data = (
|
| 109 |
+
scores.groupby("bin")
|
| 110 |
+
.apply(
|
| 111 |
+
lambda x: pd.Series(
|
| 112 |
+
{
|
| 113 |
+
"avg_score": x["prediction"].mean(),
|
| 114 |
+
"fraction_positive": x["gold"].mean(),
|
| 115 |
+
"count": len(x),
|
| 116 |
+
}
|
| 117 |
+
)
|
| 118 |
+
)
|
| 119 |
+
.reset_index()
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
total_count = scores.shape[0]
|
| 123 |
+
calibration_data["bin_weight"] = calibration_data["count"] / total_count
|
| 124 |
+
|
| 125 |
+
# Calculate the absolute differences and squared differences.
|
| 126 |
+
calibration_data["abs_diff"] = abs(
|
| 127 |
+
calibration_data["avg_score"] - calibration_data["fraction_positive"]
|
| 128 |
+
)
|
| 129 |
+
calibration_data["squared_diff"] = (
|
| 130 |
+
calibration_data["avg_score"] - calibration_data["fraction_positive"]
|
| 131 |
+
) ** 2
|
| 132 |
+
|
| 133 |
+
# Compute Expected Calibration Error (ECE): weighted average of absolute differences.
|
| 134 |
+
ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()
|
| 135 |
+
|
| 136 |
+
# Compute Maximum Calibration Error (MCE): maximum absolute difference.
|
| 137 |
+
mce = calibration_data["abs_diff"].max()
|
| 138 |
+
|
| 139 |
+
# Compute Mean Squared Error (MSE): weighted average of squared differences.
|
| 140 |
+
mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
|
| 141 |
+
|
| 142 |
+
# Compute the Brier Score on the raw predictions.
|
| 143 |
+
brier = brier_score_loss(scores["gold"], scores["prediction"])
|
| 144 |
+
|
| 145 |
+
values = {
|
| 146 |
+
"ece": ece,
|
| 147 |
+
"mce": mce,
|
| 148 |
+
"mse": mse,
|
| 149 |
+
"brier": brier,
|
| 150 |
+
}
|
| 151 |
+
return calibration_data, pd.Series(values)
|
| 152 |
+
|
| 153 |
+
def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:
|
| 154 |
+
|
| 155 |
+
calibration_data_dict = {}
|
| 156 |
+
calibration_metrics_dict = {}
|
| 157 |
+
for label, current_scores in scores_combined.groupby("label"):
|
| 158 |
+
calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
|
| 159 |
+
current_scores, q=20
|
| 160 |
+
)
|
| 161 |
+
calibration_data_dict[label] = calibration_data
|
| 162 |
+
calibration_metrics_dict[label] = calibration_metrics
|
| 163 |
+
all_calibration_data = pd.concat(
|
| 164 |
+
calibration_data_dict, names=["label", "idx"]
|
| 165 |
+
).reset_index(level=0)
|
| 166 |
+
all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T
|
| 167 |
+
|
| 168 |
+
if self.show_plot:
|
| 169 |
+
self.plot_calibration_data(calibration_data=all_calibration_data)
|
| 170 |
+
|
| 171 |
+
return all_calibration_metrics
|
| 172 |
+
|
| 173 |
+
def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
|
| 174 |
+
result_dict = {}
|
| 175 |
+
for label, current_scores in scores.groupby("label"):
|
| 176 |
+
result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]
|
| 177 |
+
|
| 178 |
+
return pd.Series(result_dict, name="correlation")
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def mapped_layer(self):
|
| 182 |
+
return self.plotting_caption_mapping.get(self.layer, self.layer)
|
| 183 |
+
|
| 184 |
+
def plot_score_distribution(self, scores: pd.DataFrame):
|
| 185 |
+
if self.plotting_backend == "plotly":
|
| 186 |
+
for label in scores["label"].unique():
|
| 187 |
+
description = f"Distribution of Predicted Scores for {self.mapped_layer}"
|
| 188 |
+
if self.per_label:
|
| 189 |
+
label_mapped = self.plotting_caption_mapping.get(label, label)
|
| 190 |
+
description += f" ({label_mapped})"
|
| 191 |
+
if self.plotly_use_create_distplot:
|
| 192 |
+
import plotly.figure_factory as ff
|
| 193 |
+
|
| 194 |
+
current_scores = scores[scores["label"] == label]
|
| 195 |
+
# group by gold score
|
| 196 |
+
scores_dict = (
|
| 197 |
+
current_scores.groupby("gold")["prediction"].apply(list).to_dict()
|
| 198 |
+
)
|
| 199 |
+
group_labels, hist_data = zip(*scores_dict.items())
|
| 200 |
+
group_labels_renamed = [
|
| 201 |
+
self.plotting_caption_mapping.get(label, label) for label in group_labels
|
| 202 |
+
]
|
| 203 |
+
if self.plotting_colors is not None:
|
| 204 |
+
colors = [
|
| 205 |
+
self.plotting_colors[group_label] for group_label in group_labels
|
| 206 |
+
]
|
| 207 |
+
else:
|
| 208 |
+
colors = None
|
| 209 |
+
fig = ff.create_distplot(
|
| 210 |
+
hist_data,
|
| 211 |
+
group_labels=group_labels_renamed,
|
| 212 |
+
show_hist=True,
|
| 213 |
+
colors=colors,
|
| 214 |
+
bin_size=0.025,
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
import plotly.express as px
|
| 218 |
+
|
| 219 |
+
fig = px.histogram(
|
| 220 |
+
scores,
|
| 221 |
+
x="prediction",
|
| 222 |
+
color="gold",
|
| 223 |
+
marginal=self.plotly_marginal, # "violin", # or box, violin, rug
|
| 224 |
+
hover_data=scores.columns,
|
| 225 |
+
color_discrete_map=self.plotting_colors,
|
| 226 |
+
nbins=50,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
fig.update_layout(
|
| 230 |
+
height=600,
|
| 231 |
+
width=800,
|
| 232 |
+
title_text=description,
|
| 233 |
+
title_x=0.5,
|
| 234 |
+
font=dict(size=self.plotly_font_size),
|
| 235 |
+
legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
|
| 236 |
+
)
|
| 237 |
+
if self.plotly_barmode is not None:
|
| 238 |
+
fig.update_layout(barmode=self.plotly_barmode)
|
| 239 |
+
if self.plotly_font_family is not None:
|
| 240 |
+
fig.update_layout(font_family=self.plotly_font_family)
|
| 241 |
+
if self.plotly_background_color is not None:
|
| 242 |
+
fig.update_layout(
|
| 243 |
+
plot_bgcolor=self.plotly_background_color,
|
| 244 |
+
paper_bgcolor=self.plotly_background_color,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
fig.show()
|
| 248 |
+
else:
|
| 249 |
+
raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")
|
| 250 |
+
|
| 251 |
+
def plot_calibration_data(self, calibration_data: pd.DataFrame):
|
| 252 |
+
import plotly.express as px
|
| 253 |
+
import plotly.graph_objects as go
|
| 254 |
+
|
| 255 |
+
color = "label" if self.per_label else None
|
| 256 |
+
x_col = "avg_score"
|
| 257 |
+
y_col = "fraction_positive"
|
| 258 |
+
fig = px.scatter(
|
| 259 |
+
calibration_data,
|
| 260 |
+
x=x_col,
|
| 261 |
+
y=y_col,
|
| 262 |
+
color=color,
|
| 263 |
+
trendline="ols",
|
| 264 |
+
labels=self.plotting_caption_mapping,
|
| 265 |
+
)
|
| 266 |
+
if not self.per_label:
|
| 267 |
+
fig["data"][1]["name"] = "prediction vs. gold"
|
| 268 |
+
|
| 269 |
+
# show legend only for trendlines
|
| 270 |
+
for idx, trace_data in enumerate(fig["data"]):
|
| 271 |
+
if idx % 2 == 0:
|
| 272 |
+
trace_data["showlegend"] = False
|
| 273 |
+
else:
|
| 274 |
+
trace_data["showlegend"] = True
|
| 275 |
+
|
| 276 |
+
# add the optimal line
|
| 277 |
+
minimum = calibration_data[x_col].min()
|
| 278 |
+
maximum = calibration_data[x_col].max()
|
| 279 |
+
fig.add_trace(
|
| 280 |
+
go.Scatter(
|
| 281 |
+
x=[minimum, maximum],
|
| 282 |
+
y=[minimum, maximum],
|
| 283 |
+
mode="lines",
|
| 284 |
+
name="optimal",
|
| 285 |
+
line=dict(color="black", dash="dash"),
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
fig.update_layout(
|
| 289 |
+
height=600,
|
| 290 |
+
width=800,
|
| 291 |
+
title_text=f"Mean Binned Scores for {self.mapped_layer}",
|
| 292 |
+
title_x=0.5,
|
| 293 |
+
font=dict(size=self.plotly_font_size),
|
| 294 |
+
)
|
| 295 |
+
fig.update_layout(
|
| 296 |
+
legend=dict(
|
| 297 |
+
yanchor="top",
|
| 298 |
+
y=0.99,
|
| 299 |
+
xanchor="left",
|
| 300 |
+
x=0.01,
|
| 301 |
+
title="OLS trendline" + ("s" if self.per_label else ""),
|
| 302 |
+
),
|
| 303 |
+
)
|
| 304 |
+
if self.plotly_background_color is not None:
|
| 305 |
+
fig.update_layout(
|
| 306 |
+
plot_bgcolor=self.plotly_background_color,
|
| 307 |
+
paper_bgcolor=self.plotly_background_color,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
if self.plotly_font_family is not None:
|
| 311 |
+
fig.update_layout(font_family=self.plotly_font_family)
|
| 312 |
+
|
| 313 |
+
fig.show()
|
| 314 |
+
|
| 315 |
+
def _compute(self) -> Dict[str, Dict[str, Any]]:
|
| 316 |
+
scores_combined = pd.concat(
|
| 317 |
+
{
|
| 318 |
+
label: self._combine_scores(scores["TP"], scores["FP"])
|
| 319 |
+
for label, scores in self.scores.items()
|
| 320 |
+
},
|
| 321 |
+
names=["label", "idx"],
|
| 322 |
+
).reset_index(level=0)
|
| 323 |
+
|
| 324 |
+
result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
|
| 325 |
+
if self.show_plot:
|
| 326 |
+
self.plot_score_distribution(scores=scores_combined)
|
| 327 |
+
|
| 328 |
+
calibration_metrics = self.calculate_calibration_metrics(scores_combined)
|
| 329 |
+
calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)
|
| 330 |
+
|
| 331 |
+
result_df = pd.concat(
|
| 332 |
+
{"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if not self.per_label:
|
| 336 |
+
result = result_df.xs("ALL")
|
| 337 |
+
else:
|
| 338 |
+
result = result_df.T.stack().unstack()
|
| 339 |
+
|
| 340 |
+
result_dict = {
|
| 341 |
+
main_key: result.xs(main_key).T.to_dict()
|
| 342 |
+
for main_key in result.index.get_level_values(0).unique()
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
return result_dict
|
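For context, a small usage sketch of the new metric; the layer name and the way documents are obtained are illustrative, and the two hooks defined above are called directly for clarity (normally the `DocumentMetric` base class drives them):

```python
# Illustrative usage sketch, not part of the PR.
from src.metrics import ScoreDistribution

metric = ScoreDistribution(layer="binary_relations", per_label=True)
for document in documents_with_gold_and_scored_predictions:  # hypothetical iterable
    metric._update(document)
result = metric._compute()
# result contains per-label "prediction" stats (mean/std/count) and
# "prediction vs. gold" calibration metrics (ece, mce, mse, brier, correlation).
```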
src/models/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
 from .sequence_classification_with_pooler import (
     SequencePairSimilarityModelWithMaxCosineSim,
     SequencePairSimilarityModelWithPooler2,
src/models/sequence_classification.py
ADDED
@@ -0,0 +1,94 @@
+from typing import Optional
+
+from pie_modules.models import SimpleSequenceClassificationModel
+from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType
+from pytorch_ie import PyTorchIEModel
+from torch import nn
+from transformers import BertModel
+from transformers.utils import is_accelerate_available
+
+if is_accelerate_available():
+    from accelerate.hooks import add_hook_to_module
+
+
+@PyTorchIEModel.register()
+class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):
+
+    def __init__(
+        self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.num_token_type_ids = num_token_type_ids
+        self.token_type_ids_key = use_as_token_type_ids
+        self.resize_type_embeddings(num_token_type_ids)
+
+    def get_input_type_embeddings(self) -> nn.Module:
+        base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
+        if base_model is None:
+            raise ValueError("Model has no base model.")
+        return base_model.embeddings.token_type_embeddings
+
+    def set_input_type_embeddings(self, value):
+        base_model: BertModel = getattr(self.model, self.model.base_model_prefix)
+        if base_model is None:
+            raise ValueError("Model has no base model.")
+        base_model.embeddings.token_type_embeddings = value
+
+    def _resize_type_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+        old_embeddings = self.get_input_type_embeddings()
+        new_embeddings = self.model._get_resized_embeddings(
+            old_embeddings, new_num_tokens, pad_to_multiple_of
+        )
+        if hasattr(old_embeddings, "_hf_hook"):
+            hook = old_embeddings._hf_hook
+            add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
+        self.set_input_type_embeddings(new_embeddings)
+
+        return self.get_input_type_embeddings()
+
+    def resize_type_embeddings(
+        self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+    ) -> nn.Embedding:
+        """
+        Same as resize_token_embeddings but for the token type embeddings.
+
+        Resizes input token type embeddings matrix of the model if `new_num_types != config.type_vocab_size`.
+
+        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_types (`int`, *optional*):
+                The number of new token types in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
+        """
+        model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of)
+        if new_num_types is None and pad_to_multiple_of is None:
+            return model_embeds
+
+        # Update base model and current model config
+        self.model.config.type_vocab_size = model_embeds.weight.shape[0]
+
+        # Tie weights again if needed
+        self.model.tie_weights()
+
+        return model_embeds
+
+    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
+        kwargs = {**inputs, **(targets or {})}
+        # rename key to input_type_ids
+        kwargs["token_type_ids"] = kwargs.pop(self.token_type_ids_key)
+        return self.model(**kwargs)
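A sketch of how the extra id tensor is routed through the model; the model name, the batch contents, and the base-class keyword names (`model_name_or_path`, `num_classes`) are assumptions, not taken from this diff:

```python
# Sketch under assumptions: constructor kwargs of the base SimpleSequenceClassificationModel
# and the batch layout are illustrative only.
import torch

from src.models import SimpleSequenceClassificationModelWithInputTypeIds

model = SimpleSequenceClassificationModelWithInputTypeIds(
    model_name_or_path="bert-base-cased",
    num_classes=5,
    num_token_type_ids=10,                    # resizes the BERT type embedding to 10 rows
    use_as_token_type_ids="entity_tag_ids",   # batch key that is renamed to token_type_ids
)
inputs = {
    "input_ids": torch.randint(0, 1000, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
    "entity_tag_ids": torch.randint(0, 10, (2, 16)),
}
output = model(inputs)  # forward() pops "entity_tag_ids" and passes it as token_type_ids
```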
src/pipeline/ner_re_pipeline.py
CHANGED
@@ -15,6 +15,7 @@ from typing import (
     overload,
 )
 
+from pie_datasets import Dataset
 from pie_modules.utils import resolve_type
 from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
 from pytorch_ie.core import Document
@@ -53,31 +54,105 @@ def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
         doc[layer_name].predictions.extend(annotations)
 
 
+def _add_annotations_from_other_document(
+    doc: D,
+    from_predictions: bool,
+    to_predictions: bool,
+    clear_before: bool,
+    other_doc: Optional[D] = None,
+    other_docs_dict: Optional[Dict[str, D]] = None,
+    layer_names: Optional[List[str]] = None,
+) -> D:
+    if other_doc is None:
+        if other_docs_dict is None:
+            raise ValueError("Either other_doc or other_docs_dict must be provided")
+        other_doc = other_docs_dict.get(doc.id)
+        if other_doc is None:
+            logger.warning(f"Document with ID {doc.id} not found in other_docs")
+            return doc
+
+    # copy to not modify the input
+    other_doc_copy = type(other_doc).fromdict(other_doc.asdict())
+
+    if layer_names is None:
+        layer_names = [field.name for field in doc.annotation_fields()]
+
+    for layer_name in layer_names:
+        layer = doc[layer_name]
+        if to_predictions:
+            layer = layer.predictions
+        if clear_before:
+            layer.clear()
+        other_layer = other_doc_copy[layer_name]
+        if from_predictions:
+            other_layer = other_layer.predictions
+        other_annotations = list(other_layer)
+        other_layer.clear()
+        layer.extend(other_annotations)
+
+    return doc
+
+
 def add_annotations_from_other_documents(
     docs: Iterable[D],
     other_docs: Sequence[Document],
+    get_other_doc_by_id: bool = False,
+    **kwargs,
+) -> Sequence[D]:
+    other_id2doc = None
+    if get_other_doc_by_id:
+        other_id2doc = {doc.id: doc for doc in other_docs}
+
+    if isinstance(docs, Dataset):
+        if other_id2doc is None:
+            raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
+        result = docs.map(
+            _add_annotations_from_other_document,
+            fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
+        )
+    elif isinstance(docs, list):
+        result = []
+        for i, doc in enumerate(docs):
+            if other_id2doc is not None:
+                other_doc = other_id2doc.get(doc.id)
+                if other_doc is None:
+                    logger.warning(f"Document with ID {doc.id} not found in other_docs")
+                    continue
             else:
+                other_doc = other_docs[i]
+
+            # check if the IDs of the documents match
+            doc_id = getattr(doc, "id", None)
+            other_doc_id = getattr(other_doc, "id", None)
+            if doc_id is not None and doc_id != other_doc_id:
+                raise ValueError(
+                    f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
+                )
+
+            current_result = _add_annotations_from_other_document(
+                doc, other_doc=other_doc, **kwargs
+            )
+            result.append(current_result)
+    else:
+        raise ValueError(f"Unsupported type: {type(docs)}")
+
+    return result
+
+
+DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])
+
+
+def add_annotations_from_other_documents_dict(
+    docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
+) -> DM:
+    if set(docs.keys()) != set(other_docs.keys()):
+        raise ValueError("Keys of the documents do not match")
+
+    result_dict = {
+        key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
+        for key, doc_list in docs.items()
+    }
+    return type(docs)(result_dict)
 
 
 def process_pipeline_steps(
@@ -227,6 +302,9 @@ class NerRePipeline:
            "re_add_gold_data": partial(
                add_annotations_from_other_documents,
                other_docs=original_docs,
+               from_predictions=False,
+               to_predictions=False,
+               clear_before=False,
                layer_names=[self.entity_layer, self.relation_layer],
                **self.processor_kwargs.get("re_add_gold_data", {}),
            ),
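A short sketch of how the reworked `add_annotations_from_other_documents` is meant to be called; the variable names and layer names are illustrative:

```python
# Illustrative call only: copy gold annotations from the original documents back onto
# the predicted documents, matching them by document id.
docs_with_gold = add_annotations_from_other_documents(
    docs=predicted_docs,        # list (or pie_datasets Dataset) carrying predictions
    other_docs=original_docs,   # documents that still hold the gold annotations
    get_other_doc_by_id=True,   # match by doc.id instead of relying on list order
    from_predictions=False,     # read from the gold layers of other_docs
    to_predictions=False,       # write into the gold layers of docs
    clear_before=False,         # keep annotations that are already present
    layer_names=["labeled_spans", "binary_relations"],  # hypothetical layer names
)
```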
src/predict.py
CHANGED
@@ -34,14 +34,13 @@ root = pyrootutils.setup_root(
 # ------------------------------------------------------------------------------------ #
 
 import os
-import timeit
 from collections.abc import Iterable, Sequence
 from typing import Any, Dict, Optional, Tuple, Union
 
 import hydra
 import pytorch_lightning as pl
 from omegaconf import DictConfig, OmegaConf
-from pie_datasets import
+from pie_datasets import DatasetDict
 from pie_modules.models import *  # noqa: F403
 from pie_modules.taskmodules import *  # noqa: F403
 from pytorch_ie import Document, Pipeline
@@ -132,38 +131,13 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
         "pipeline": pipeline,
         "serializer": serializer,
     }
-    document_batch_size = cfg.get("document_batch_size", None)
-    for docs_batch in (
-        document_batch_iter(dataset_predict, document_batch_size)
-        if document_batch_size
-        else [dataset_predict]
-    ):
-        if pipeline is not None:
-            t_start = timeit.default_timer()
-            docs_batch = pipeline(docs_batch, inplace=False)
-            prediction_time += timeit.default_timer() - t_start  # type: ignore
-
-        # serialize the documents
-        if serializer is not None:
-            # the serializer should not return the serialized documents, but write them to disk
-            # and instead return some metadata such as the path to the serialized documents
-            serializer_result = serializer(docs_batch)
-            if "serializer" in result and result["serializer"] != serializer_result:
-                log.warning(
-                    f"serializer result changed from {result['serializer']} to {serializer_result}"
-                    " during prediction. Only the last result is returned."
-                )
-            result["serializer"] = serializer_result
-
-        if prediction_time is not None:
-            result["prediction_time"] = prediction_time
+    # predict and serialize
+    result: Dict[str, Any] = utils.predict_and_serialize(
+        pipeline=pipeline,
+        serializer=serializer,
+        dataset=dataset_predict,
+        document_batch_size=cfg.get("document_batch_size", None),
+    )
 
     # serialize config with resolved paths
     if cfg.get("config_out_path"):
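The removed loop appears to have been factored out into `utils.predict_and_serialize` (added in `src/utils/inference_utils.py`, which is not part of this excerpt). A sketch of what such a helper could look like, reconstructed from the removed code; the actual implementation may differ:

```python
# Sketch under assumptions: reconstructed from the code removed above, not from
# src/utils/inference_utils.py itself.
import timeit
from typing import Any, Dict, Optional


def predict_and_serialize(pipeline, serializer, dataset, document_batch_size: Optional[int] = None) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    prediction_time = 0.0
    batches = (
        document_batch_iter(dataset, document_batch_size)  # helper referenced by the removed code
        if document_batch_size
        else [dataset]
    )
    for docs_batch in batches:
        if pipeline is not None:
            t_start = timeit.default_timer()
            docs_batch = pipeline(docs_batch, inplace=False)
            prediction_time += timeit.default_timer() - t_start
        if serializer is not None:
            # the serializer writes the documents to disk and returns metadata such as the output path
            result["serializer"] = serializer(docs_batch)
    result["prediction_time"] = prediction_time
    return result
```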
src/serializer/interface.py
CHANGED
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any,
+from typing import Any, Iterable
 
 from pytorch_ie.core import Document
 
@@ -12,5 +12,5 @@ class DocumentSerializer(ABC):
     """
 
     @abstractmethod
-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document]) -> Any:
         pass
src/serializer/json.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Dict, List, Optional, Sequence, Type, TypeVar
+from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar
 
 from pie_datasets import Dataset, DatasetDict, IterableDataset
 from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
@@ -8,7 +8,7 @@ from pytorch_ie.core import Document
 from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
 
 from src.serializer.interface import DocumentSerializer
-from src.utils import get_pylogger
+from src.utils.logging_utils import get_pylogger
 
 log = get_pylogger(__name__)
 
@@ -31,7 +31,7 @@ class JsonSerializer(DocumentSerializer):
     @classmethod
     def write(
         cls,
-        documents:
+        documents: Iterable[Document],
         path: str,
         file_name: str = "documents.jsonl",
         metadata_file_name: str = METADATA_FILE_NAME,
@@ -42,6 +42,9 @@ class JsonSerializer(DocumentSerializer):
         log.info(f'serialize documents to "{realpath}" ...')
         os.makedirs(realpath, exist_ok=True)
 
+        if not isinstance(documents, Sequence):
+            documents = list(documents)
+
         # dump metadata including the document_type
         if len(documents) == 0:
             raise Exception("cannot serialize empty list of documents")
@@ -130,7 +133,7 @@ class JsonSerializer(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)
 
-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
         return self.write_with_defaults(documents=documents, **kwargs)
 
 
@@ -141,12 +144,15 @@ class JsonSerializer2(DocumentSerializer):
     @classmethod
     def write(
         cls,
-        documents:
+        documents: Iterable[Document],
         path: str,
         split: str = "train",
     ) -> Dict[str, str]:
         if not isinstance(documents, (Dataset, IterableDataset)):
+            if not isinstance(documents, Sequence):
+                documents = IterableDataset.from_documents(documents)
+            else:
+                documents = Dataset.from_documents(documents)
         dataset_dict = DatasetDict({split: documents})
         dataset_dict.to_json(path=path)
         return {"path": path, "split": split}
@@ -175,5 +181,5 @@ class JsonSerializer2(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)
 
-    def __call__(self, documents:
+    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
         return self.write_with_defaults(documents=documents, **kwargs)
src/start_demo.py
CHANGED
@@ -99,6 +99,7 @@ def main(cfg: DictConfig) -> None:
     render_caption2mode = {v: k for k, v in render_mode2caption.items()}
     default_min_similarity = cfg["default_min_similarity"]
     default_top_k = cfg["default_top_k"]
+    default_min_score = cfg["default_min_score"]
     layer_caption_mapping = cfg["layer_caption_mapping"]
     relation_name_mapping = cfg["relation_name_mapping"]
 
@@ -287,6 +288,13 @@ def main(cfg: DictConfig) -> None:
                         step=1,
                         value=default_top_k,
                     )
+                    min_score = gr.Slider(
+                        label="Minimum Score",
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=default_min_score,
+                    )
                     retrieve_similar_adus_btn = gr.Button(
                         "Retrieve *similar* ADUs for *selected* ADU"
                     )
@@ -361,18 +369,23 @@ def main(cfg: DictConfig) -> None:
             load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
 
     render_event_kwargs = dict(
-        fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id:
+        fn=lambda _rendered_output, _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: (
+            render_annotated_document(
+                retriever=_retriever[0],
+                document_id=_document_id,
+                render_with=render_caption2mode[_render_as],
+                render_kwargs_json=_render_kwargs,
+                highlight_span_ids=(
+                    _all_relevant_adus_df["query_span_id"].tolist()
+                    if _document_id == _all_relevant_adus_query_doc_id
+                    else None
+                ),
+            )
+            if _document_id.strip() != ""
+            else _rendered_output
         ),
        inputs=[
+            rendered_output,
            retriever_state,
            selected_document_id,
            render_as,
@@ -583,10 +596,11 @@ def main(cfg: DictConfig) -> None:
     ).success(**show_stats_kwargs)
 
     retrieve_relevant_adus_event_kwargs = dict(
-        fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
+        fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_relevant_spans(
            retriever=_retriever[0],
            query_span_id=_selected_adu_id,
            k=_top_k,
+            min_score=_min_score,
            score_threshold=_min_similarity,
            relation_label_mapping=relation_name_mapping,
            # columns=relevant_adus.headers
@@ -596,6 +610,7 @@ def main(cfg: DictConfig) -> None:
            selected_adu_id,
            min_similarity,
            top_k,
+            min_score,
        ],
        outputs=[relevant_adus_df],
    )
@@ -614,10 +629,11 @@ def main(cfg: DictConfig) -> None:
     ).success(**retrieve_relevant_adus_event_kwargs)
 
     retrieve_similar_adus_btn.click(
-        fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans(
+        fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k, _min_score: retrieve_similar_spans(
            retriever=_retriever[0],
            query_span_id=_selected_adu_id,
            k=_tok_k,
+            min_score=_min_score,
            score_threshold=_min_similarity,
        ),
        inputs=[
@@ -625,6 +641,7 @@ def main(cfg: DictConfig) -> None:
            selected_adu_id,
            min_similarity,
            top_k,
+            min_score,
        ],
        outputs=[similar_adus_df],
    )
@@ -635,10 +652,11 @@ def main(cfg: DictConfig) -> None:
     )
 
     retrieve_all_similar_adus_btn.click(
-        fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans(
+        fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: retrieve_all_similar_spans(
            retriever=_retriever[0],
            query_doc_id=_document_id,
            k=_tok_k,
+            min_score=_min_score,
            score_threshold=_min_similarity,
            query_span_id_column="query_span_id",
        ),
@@ -647,16 +665,18 @@ def main(cfg: DictConfig) -> None:
            selected_document_id,
            min_similarity,
            top_k,
+            min_score,
        ],
        outputs=[all_similar_adus_df],
    )
 
     retrieve_all_relevant_adus_btn.click(
-        fn=lambda _retriever, _document_id, _min_similarity, _tok_k: (
+        fn=lambda _retriever, _document_id, _min_similarity, _tok_k, _min_score: (
            retrieve_all_relevant_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_tok_k,
+                min_score=_min_score,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
                query_span_text_column="query_span_text",
@@ -668,6 +688,7 @@ def main(cfg: DictConfig) -> None:
            selected_document_id,
            min_similarity,
            top_k,
+            min_score,
        ],
        outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
    )
src/taskmodules/cross_text_binary_coref_nli.py
CHANGED
@@ -62,6 +62,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         tokenizer_name_or_path: str,
         labels: List[str],
         entailment_label: str,
+        combine_score_method: str = "average",
+        keep_all_relations: bool = False,
+        as_text_pair: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -69,6 +72,9 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
 
         self.labels = labels
         self.entailment_label = entailment_label
+        self.combine_score_method = combine_score_method
+        self.keep_all_relations = keep_all_relations
+        self.as_text_pair = as_text_pair
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
 
     def _post_prepare(self):
@@ -118,9 +124,18 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         for task_encoding in task_encodings:
             all_texts.extend(task_encoding.inputs["text"])
             all_texts_pair.extend(task_encoding.inputs["text_pair"])
+        if self.as_text_pair:
+            text = all_texts
+            text_pair = all_texts_pair
+        else:
+            text = [
+                f"{text}{self.tokenizer.sep_token}{text_pair}"
+                for text, text_pair in zip(all_texts, all_texts_pair)
+            ]
+            text_pair = None
         inputs = self.tokenizer(
-            text=
-            text_pair=
+            text=text,
+            text_pair=text_pair,
             truncation=True,
             padding=True,
             return_tensors="pt",
@@ -159,8 +174,20 @@ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleTyp
         task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
         task_output: TaskOutputType,
     ) -> Iterator[Tuple[str, Annotation]]:
+        if (
+            all(label == self.entailment_label for label in task_output["label_pair"])
+            or self.keep_all_relations
+        ):
            probs = task_output["entailment_probability_pair"]
+            if self.combine_score_method == "average":
+                score = (probs[0] + probs[1]) / 2
+            elif self.combine_score_method == "min":
+                score = min(probs)
+            elif self.combine_score_method == "max":
+                score = max(probs)
+            elif self.combine_score_method == "product":
+                score = probs[0] * probs[1]
+            else:
+                raise ValueError(f"Unsupported combine_score_method: {self.combine_score_method}")
            new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
            yield "binary_coref_relations", new_coref_rel
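For clarity, the four score-combination modes applied to an invented probability pair:

```python
# Illustrative numbers only: combining the two directional entailment probabilities.
probs = (0.9, 0.6)

combined = {
    "average": (probs[0] + probs[1]) / 2,  # 0.75
    "min": min(probs),                     # 0.6
    "max": max(probs),                     # 0.9
    "product": probs[0] * probs[1],        # 0.54
}
```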
src/train.py
CHANGED
|
@@ -38,13 +38,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
| 38 |
|
| 39 |
import hydra
|
| 40 |
import pytorch_lightning as pl
|
| 41 |
-
from omegaconf import DictConfig
|
| 42 |
from pie_datasets import DatasetDict
|
| 43 |
from pie_modules.models import * # noqa: F403
|
| 44 |
from pie_modules.models import SimpleGenerativeModel
|
| 45 |
from pie_modules.models.interface import RequiresTaskmoduleConfig
|
| 46 |
from pie_modules.taskmodules import * # noqa: F403
|
| 47 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
|
|
|
| 48 |
from pytorch_ie.core import PyTorchIEModel, TaskModule
|
| 49 |
from pytorch_ie.models import * # noqa: F403
|
| 50 |
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
|
|
@@ -56,6 +57,7 @@ from pytorch_lightning.loggers import Logger
|
|
| 56 |
from src import utils
|
| 57 |
from src.datamodules import PieDataModule
|
| 58 |
from src.models import * # noqa: F403
|
|
|
|
| 59 |
from src.taskmodules import * # noqa: F403
|
| 60 |
|
| 61 |
log = utils.get_pylogger(__name__)
|
|
@@ -81,6 +83,27 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
|
|
| 81 |
return metric_value
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
@utils.task_wrapper
|
| 85 |
def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
| 86 |
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
|
|
@@ -179,6 +202,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
|
| 179 |
)
|
| 180 |
additional_model_kwargs["base_model_config"] = base_model_config
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
# initialize the model
|
| 183 |
model: PyTorchIEModel = hydra.utils.instantiate(
|
| 184 |
cfg.model, _convert_="partial", **additional_model_kwargs
|
|
@@ -207,9 +235,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
|
| 207 |
log.info("Logging hyperparameters!")
|
| 208 |
utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
|
| 209 |
|
| 210 |
-
if cfg.model_save_dir is not None:
|
| 211 |
-
log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
|
| 212 |
-
taskmodule.save_pretrained(
|
|
|
|
|
|
|
| 213 |
else:
|
| 214 |
log.warning("the taskmodule is not saved because no save_dir is specified")
|
| 215 |
|
|
@@ -238,15 +268,17 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
|
| 238 |
f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
|
| 239 |
)
|
| 240 |
|
| 241 |
-
if not cfg.trainer.get("fast_dev_run"):
|
| 242 |
-
if cfg.model_save_dir is not None:
|
| 243 |
if best_ckpt_path == "":
|
| 244 |
log.warning("Best ckpt not found! Using current weights for saving...")
|
| 245 |
else:
|
| 246 |
model = type(model).load_from_checkpoint(best_ckpt_path)
|
| 247 |
|
| 248 |
-
log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
|
| 249 |
-
model.save_pretrained(
|
|
|
|
|
|
|
| 250 |
else:
|
| 251 |
log.warning("the model is not saved because no save_dir is specified")
|
| 252 |
|
|
@@ -275,8 +307,36 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
|
| 275 |
|
| 276 |
# add model_save_dir to the result so that it gets dumped to job_return_value.json
|
| 277 |
# if we use hydra_callbacks.SaveJobReturnValueCallback
|
| 278 |
-
if cfg.get("model_save_dir") is not None:
|
| 279 |
-
metric_dict["model_save_dir"] = cfg.model_save_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
return metric_dict, object_dict
|
| 282 |
|
|
@@ -301,4 +361,5 @@ def main(cfg: DictConfig) -> Optional[float]:
|
|
| 301 |
if __name__ == "__main__":
|
| 302 |
utils.replace_sys_args_with_values_from_files()
|
| 303 |
utils.prepare_omegaconf()
|
|
|
|
| 304 |
main()
|
|
|
|
| 38 |
|
| 39 |
import hydra
|
| 40 |
import pytorch_lightning as pl
|
| 41 |
+
from omegaconf import DictConfig, OmegaConf
|
| 42 |
from pie_datasets import DatasetDict
|
| 43 |
from pie_modules.models import * # noqa: F403
|
| 44 |
from pie_modules.models import SimpleGenerativeModel
|
| 45 |
from pie_modules.models.interface import RequiresTaskmoduleConfig
|
| 46 |
from pie_modules.taskmodules import * # noqa: F403
|
| 47 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
| 48 |
+
from pytorch_ie import Pipeline
|
| 49 |
from pytorch_ie.core import PyTorchIEModel, TaskModule
|
| 50 |
from pytorch_ie.models import * # noqa: F403
|
| 51 |
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
|
|
|
|
| 57 |
from src import utils
|
| 58 |
from src.datamodules import PieDataModule
|
| 59 |
from src.models import * # noqa: F403
|
| 60 |
+
from src.serializer.interface import DocumentSerializer
|
| 61 |
from src.taskmodules import * # noqa: F403
|
| 62 |
|
| 63 |
log = utils.get_pylogger(__name__)
|
|
|
|
| 83 |
return metric_value
|
| 84 |
|
| 85 |
|
| 86 |
+
def flatten_nested_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
|
| 87 |
+
"""Flatten a nested dictionary.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
d (Dict[str, Any]): The dictionary to flatten.
|
| 91 |
+
parent_key (str): The parent key.
|
| 92 |
+
sep (str): The separator.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
Dict[str, Any]: The flattened dictionary.
|
| 96 |
+
"""
|
| 97 |
+
items: List[Tuple[str, Any]] = []
|
| 98 |
+
for k, v in d.items():
|
| 99 |
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
| 100 |
+
if isinstance(v, dict):
|
| 101 |
+
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
| 102 |
+
else:
|
| 103 |
+
items.append((new_key, v))
|
| 104 |
+
return dict(items)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
@utils.task_wrapper
|
| 108 |
def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
| 109 |
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
|
|
|
|
| 202 |
)
|
| 203 |
additional_model_kwargs["base_model_config"] = base_model_config
|
| 204 |
|
| 205 |
+
if issubclass(model_cls, SimpleSequenceClassificationModelWithInputTypeIds): # noqa: F405
|
| 206 |
+
# add the number of input type ids to the model:
|
| 207 |
+
# 2 for B- and I-labels for each entity type, 1 for O labels, 1 for padding
|
| 208 |
+
additional_model_kwargs["num_token_type_ids"] = len(taskmodule.entity_labels) * 2 + 1 + 1
|
| 209 |
+
|
| 210 |
# initialize the model
|
| 211 |
model: PyTorchIEModel = hydra.utils.instantiate(
|
| 212 |
cfg.model, _convert_="partial", **additional_model_kwargs
|
|
|
|
| 235 |     log.info("Logging hyperparameters!")
| 236 |     utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
| 237 |
| 238 | +     if cfg.paths.model_save_dir is not None:
| 239 | +         log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
| 240 | +         taskmodule.save_pretrained(
| 241 | +             save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
| 242 | +         )
| 243 |     else:
| 244 |         log.warning("the taskmodule is not saved because no save_dir is specified")
| 245 |
| 268 |             f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
| 269 |         )
| 270 |
| 271 | +     if not cfg.trainer.get("fast_dev_run") or cfg.get("predict", False):
| 272 | +         if cfg.paths.model_save_dir is not None:
| 273 |             if best_ckpt_path == "":
| 274 |                 log.warning("Best ckpt not found! Using current weights for saving...")
| 275 |             else:
| 276 |                 model = type(model).load_from_checkpoint(best_ckpt_path)
| 277 |
| 278 | +             log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
| 279 | +             model.save_pretrained(
| 280 | +                 save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
| 281 | +             )
| 282 |         else:
| 283 |             log.warning("the model is not saved because no save_dir is specified")
| 284 |
| 307 |
| 308 |     # add model_save_dir to the result so that it gets dumped to job_return_value.json
| 309 |     # if we use hydra_callbacks.SaveJobReturnValueCallback
| 310 | +     if cfg.paths.get("model_save_dir") is not None:
| 311 | +         metric_dict["model_save_dir"] = cfg.paths.model_save_dir
| 312 | +
| 313 | +     if cfg.get("predict"):
| 314 | +         # Init the inference pipeline
| 315 | +         pipeline: Optional[Pipeline] = None
| 316 | +         if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
| 317 | +             log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>")
| 318 | +             pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
| 319 | +         # Init the serializer
| 320 | +         serializer: Optional[DocumentSerializer] = None
| 321 | +         if cfg.get("serializer") and cfg.serializer.get("_target_"):
| 322 | +             log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
| 323 | +             serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
| 324 | +         # predict and serialize
| 325 | +         predict_metrics: Dict[str, Any] = utils.predict_and_serialize(
| 326 | +             pipeline=pipeline,
| 327 | +             serializer=serializer,
| 328 | +             dataset=dataset[cfg.dataset_split],
| 329 | +             document_batch_size=cfg.get("document_batch_size", None),
| 330 | +         )
| 331 | +         # flatten the predict_metrics dict
| 332 | +         predict_metrics_flat = flatten_nested_dict(predict_metrics, sep="/")
| 333 | +         metric_dict.update(predict_metrics_flat)
| 334 | +
| 335 | +     if cfg.get("delete_model_dir"):
| 336 | +         import shutil
| 337 | +
| 338 | +         log.info(f"Deleting model directory {cfg.paths.model_save_dir}")
| 339 | +         shutil.rmtree(cfg.paths.model_save_dir)
| 340 |
| 341 |     return metric_dict, object_dict
| 342 |
@@ -301,4 +361,5 @@ def main(cfg: DictConfig) -> Optional[float]:
| 361 | if __name__ == "__main__":
| 362 |     utils.replace_sys_args_with_values_from_files()
| 363 |     utils.prepare_omegaconf()
| 364 | +     OmegaConf.register_new_resolver("eval", eval)
| 365 |     main()
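For context (illustrative snippet, not from the diff): registering the "eval" resolver lets Hydra/OmegaConf config values evaluate small Python expressions when they are resolved, roughly like this:

    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("eval", eval)
    cfg = OmegaConf.create({"num_epochs": "${eval:'2 ** 3'}"})
    print(cfg.num_epochs)  # 8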
src/utils/__init__.py
CHANGED
@@ -5,7 +5,8 @@ from .config_utils import (
| 5 |     prepare_omegaconf,
| 6 | )
| 7 | from .data_utils import download_and_unzip, filter_dataframe_and_get_column
| 8 | + from .inference_utils import predict_and_serialize
| 9 | from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
| 10 | from .rich_utils import enforce_tags, print_config_tree
|    | - from .span_utils import distance
| 11 | + from .span_utils import distance, distance_slices
| 12 | from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
src/utils/inference_utils.py
ADDED
@@ -0,0 +1,74 @@
import timeit
from collections.abc import Iterable, Sequence
from typing import Any, Dict, Optional, Union

from pytorch_ie import Document, Pipeline

from src.serializer.interface import DocumentSerializer

from .logging_utils import get_pylogger

log = get_pylogger(__name__)


def document_batch_iter(
    dataset: Iterable[Document], batch_size: int
) -> Iterable[Sequence[Document]]:
    if isinstance(dataset, Sequence):
        for i in range(0, len(dataset), batch_size):
            yield dataset[i : i + batch_size]
    elif isinstance(dataset, Iterable):
        docs = []
        for doc in dataset:
            docs.append(doc)
            if len(docs) == batch_size:
                yield docs
                docs = []
        if docs:
            yield docs
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")


def predict_and_serialize(
    pipeline: Optional[Pipeline],
    serializer: Optional[DocumentSerializer],
    dataset: Iterable[Document],
    document_batch_size: Optional[int] = None,
) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    if pipeline is not None:
        log.info("Starting inference!")
        prediction_time = 0.0
    else:
        log.warning("No prediction pipeline is defined, skip inference!")
        prediction_time = None
    docs_batch: Union[Iterable[Document], Sequence[Document]]

    batch_iter: Union[Sequence[Iterable[Document]], Iterable[Sequence[Document]]]
    if document_batch_size is None:
        batch_iter = [dataset]
    else:
        batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size)
    for docs_batch in batch_iter:
        if pipeline is not None:
            t_start = timeit.default_timer()
            docs_batch = pipeline(docs_batch, inplace=False)
            prediction_time += timeit.default_timer() - t_start  # type: ignore

        # serialize the documents
        if serializer is not None:
            # the serializer should not return the serialized documents, but write them to disk
            # and instead return some metadata such as the path to the serialized documents
            serializer_result = serializer(docs_batch)
            if "serializer" in result and result["serializer"] != serializer_result:
                log.warning(
                    f"serializer result changed from {result['serializer']} to {serializer_result}"
                    " during prediction. Only the last result is returned."
                )
            result["serializer"] = serializer_result

    if prediction_time is not None:
        result["prediction_time"] = prediction_time

    return result
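Usage sketch (not part of the PR; plain strings stand in for pytorch_ie Document objects): when document_batch_size is set, predict_and_serialize consumes the dataset through document_batch_iter, which yields fixed-size chunks plus a final partial one:

    from src.utils.inference_utils import document_batch_iter

    docs = [f"doc-{i}" for i in range(5)]  # stand-ins for Document instances
    batches = list(document_batch_iter(dataset=docs, batch_size=2))
    # -> [['doc-0', 'doc-1'], ['doc-2', 'doc-3'], ['doc-4']]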
src/utils/span_utils.py
CHANGED
@@ -58,3 +58,17 @@ def distance(
| 58 |         raise ValueError(
| 59 |             f"unknown distance_type={distance_type}. use one of: center, inner, outer"
| 60 |         )
| 61 | +
| 62 | +
| 63 | + def distance_slices(
| 64 | +     slices: Tuple[Tuple[int, int], ...],
| 65 | +     other_slices: Tuple[Tuple[int, int], ...],
| 66 | +     distance_type: str,
| 67 | + ) -> float:
| 68 | +     starts, ends = zip(*slices)
| 69 | +     other_starts, other_ends = zip(*other_slices)
| 70 | +     return distance(
| 71 | +         start_end=(min(starts), max(ends)),
| 72 | +         other_start_end=(min(other_starts), max(other_ends)),
| 73 | +         distance_type=distance_type,
| 74 | +     )
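A brief usage sketch (made-up spans; assumes the updated exports in src.utils): distance_slices collapses each tuple of (start, end) slices to its overall envelope before delegating to distance, so discontinuous mentions can be compared directly:

    from src.utils import distance_slices

    mention_a = ((0, 5), (20, 25))  # envelope (0, 25)
    mention_b = ((40, 48),)         # envelope (40, 48)
    d = distance_slices(slices=mention_a, other_slices=mention_b, distance_type="inner")
    # equivalent to distance(start_end=(0, 25), other_start_end=(40, 48), distance_type="inner")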