from . import SentenceEvaluator
import logging
from sentence_transformers.util import paraphrase_mining
import os
import csv
from typing import List, Tuple, Dict
from collections import defaultdict

logger = logging.getLogger(__name__)


class ParaphraseMiningEvaluator(SentenceEvaluator):
    """
    Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and
    identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs
    with a set of gold labels and computes the F1 score.
    """

    def __init__(self, sentences_map: Dict[str, str], duplicates_list: List[Tuple[str, str]] = None,
                 duplicates_dict: Dict[str, Dict[str, bool]] = None, add_transitive_closure: bool = False,
                 query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000,
                 top_k: int = 100, show_progress_bar: bool = False, batch_size: int = 16, name: str = '',
                 write_csv: bool = True):
        """
        :param sentences_map: A dictionary that maps sentence-ids to sentences, i.e. sentences_map[id] => sentence.
        :param duplicates_list: A list of id pairs [(id1, id2), (id1, id5)] that identifies the duplicates / paraphrases in the sentences_map.
        :param duplicates_dict: A default dictionary mapping [id1][id2] to true if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True, then [id2][id1] => True.
        :param add_transitive_closure: If true, the transitive closure is added, i.e. if dup[a][b] and dup[b][c], then dup[a][c].
        :param query_chunk_size: To identify the paraphrases, the cosine similarity between all sentence pairs will be computed. As this might require a lot of memory, we perform a batched computation: query_chunk_size sentences are compared against up to corpus_chunk_size sentences. In the default setting, 5000 sentences are grouped together and compared against up to 100k other sentences.
        :param corpus_chunk_size: The corpus will be batched to reduce the memory requirement.
        :param max_pairs: We only extract up to max_pairs potential paraphrase candidates.
        :param top_k: For each query, we extract the top_k most similar pairs and add them to a sorted list, i.e., for one sentence we cannot find more than top_k paraphrases.
        :param show_progress_bar: Output a progress bar.
        :param batch_size: Batch size for computing sentence embeddings.
        :param name: Name of the experiment.
        :param write_csv: Write results to a CSV file.
        """
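        # Keep ids and sentences in two aligned lists: paraphrase_mining returns
        # indices into self.sentences, which __call__ maps back to ids via self.ids.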
        self.sentences = []
        self.ids = []

        for id, sentence in sentences_map.items():
            self.sentences.append(sentence)
            self.ids.append(id)

        self.name = name
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.query_chunk_size = query_chunk_size
        self.corpus_chunk_size = corpus_chunk_size
        self.max_pairs = max_pairs
        self.top_k = top_k

        self.duplicates = duplicates_dict if duplicates_dict is not None else defaultdict(lambda: defaultdict(bool))
        if duplicates_list is not None:
            for id1, id2 in duplicates_list:
                if id1 in sentences_map and id2 in sentences_map:
                    self.duplicates[id1][id2] = True
                    self.duplicates[id2][id1] = True

        # Add transitive closure
        if add_transitive_closure:
            self.duplicates = self.add_transitive_closure(self.duplicates)

        # Count the distinct gold duplicate pairs (each unordered pair counted once)
        positive_key_pairs = set()
        for key1 in self.duplicates:
            for key2 in self.duplicates[key1]:
                if key1 in sentences_map and key2 in sentences_map and (self.duplicates[key1][key2] or self.duplicates[key2][key1]):
                    positive_key_pairs.add(tuple(sorted([key1, key2])))

        self.total_num_duplicates = len(positive_key_pairs)

        if name:
            name = "_" + name

        self.csv_file: str = "paraphrase_mining_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
        self.write_csv = write_csv

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if epoch != -1:
            out_txt = f" after epoch {epoch}:" if steps == -1 else f" in epoch {epoch} after {steps} steps:"
        else:
            out_txt = ":"

        logger.info("Paraphrase Mining Evaluation on " + self.name + " dataset" + out_txt)

        # Compute embeddings for the sentences and mine candidate paraphrase pairs
        pairs_list = paraphrase_mining(model, self.sentences, self.show_progress_bar, self.batch_size,
                                       self.query_chunk_size, self.corpus_chunk_size, self.max_pairs, self.top_k)

        logger.info("Number of candidate pairs: " + str(len(pairs_list)))

        # Compute F1 score and Average Precision
        n_extract = n_correct = 0
        threshold = 0
        best_f1 = best_recall = best_precision = 0
        average_precision = 0

        for idx in range(len(pairs_list)):
            score, i, j = pairs_list[idx]
            id1 = self.ids[i]
            id2 = self.ids[j]

            # Compute the optimal threshold and F1-score
            n_extract += 1
            if self.duplicates[id1][id2] or self.duplicates[id2][id1]:
                n_correct += 1
                precision = n_correct / n_extract
                recall = n_correct / self.total_num_duplicates
                f1 = 2 * precision * recall / (precision + recall)
                average_precision += precision
                if f1 > best_f1:
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
                    # Place the threshold halfway between this pair's score and the next candidate's score
                    threshold = (pairs_list[idx][0] + pairs_list[min(idx + 1, len(pairs_list) - 1)][0]) / 2

        average_precision = average_precision / self.total_num_duplicates

        logger.info("Average Precision: {:.2f}".format(average_precision * 100))
        logger.info("Optimal threshold: {:.4f}".format(threshold))
        logger.info("Precision: {:.2f}".format(best_precision * 100))
        logger.info("Recall: {:.2f}".format(best_recall * 100))
        logger.info("F1: {:.2f}\n".format(best_f1 * 100))

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                with open(csv_path, newline='', mode="w", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(self.csv_headers)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])
            else:
                with open(csv_path, newline='', mode="a", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])

        return average_precision
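    # Worked example for add_transitive_closure below (illustrative ids): if
    # dup["a"]["b"] and dup["b"]["c"] hold, the closure also sets dup["a"]["c"],
    # so every pair of ids within a connected component is marked as a duplicate.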
    @staticmethod
    def add_transitive_closure(graph):
        nodes_visited = set()
        for a in list(graph.keys()):
            if a not in nodes_visited:
                connected_subgraph_nodes = set()
                connected_subgraph_nodes.add(a)

                # Add all nodes in the connected subgraph (breadth-first traversal)
                neighbor_nodes_queue = list(graph[a])
                while len(neighbor_nodes_queue) > 0:
                    node = neighbor_nodes_queue.pop(0)
                    if node not in connected_subgraph_nodes:
                        connected_subgraph_nodes.add(node)
                        neighbor_nodes_queue.extend(graph[node])

                # Ensure transitivity between all nodes in the subgraph
                connected_subgraph_nodes = list(connected_subgraph_nodes)
                for i in range(len(connected_subgraph_nodes) - 1):
                    for j in range(i + 1, len(connected_subgraph_nodes)):
                        graph[connected_subgraph_nodes[i]][connected_subgraph_nodes[j]] = True
                        graph[connected_subgraph_nodes[j]][connected_subgraph_nodes[i]] = True

                        nodes_visited.add(connected_subgraph_nodes[i])
                        nodes_visited.add(connected_subgraph_nodes[j])

        return graph
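
# A minimal usage sketch, not part of the evaluator itself: it wires the class up
# with a SentenceTransformer model. The model name, ids, and sentences below are
# illustrative placeholders, and because this module uses a relative import it is
# assumed to run inside the sentence_transformers package.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")

    sentences_map = {
        "s1": "How can I learn to program in Python?",
        "s2": "What is the best way to learn Python programming?",
        "s3": "What time is it right now in London?",
    }
    duplicates_list = [("s1", "s2")]  # gold paraphrase pairs, given as id tuples

    evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list=duplicates_list, name="demo")
    average_precision = evaluator(model, output_path=".")  # also writes the results CSV
    print("Average Precision: {:.4f}".format(average_precision))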