import re
import string

import networkx as nx
import numpy as np
from cdlib import algorithms


def normalize_text(s):
    """Lowercase the text, then strip punctuation, English articles, and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return regex.sub(" ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
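
# Illustrative check (hypothetical input, not from the original test data):
# normalize_text("The  Quick, Brown Fox!") -> "quick brown fox"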


def compute_exact_match(prediction, truth):
    """Return 1 if prediction and truth are identical after normalization, else 0."""
    return int(normalize_text(prediction) == normalize_text(truth))
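
# Example (hypothetical): normalization discards case and articles, so
# compute_exact_match("The Answer", "answer") -> 1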


def compute_f1(prediction, truth):
    """Token-level F1 over unique normalized tokens (a set-based variant of the
    usual SQuAD-style F1, so repeated tokens are only counted once)."""
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either side is empty, F1 is 1 only when both are empty.
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    common_tokens = set(pred_tokens) & set(truth_tokens)
    if len(common_tokens) == 0:
        return 0

    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    return 2 * (prec * rec) / (prec + rec)
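
# Worked example (hypothetical): "the cat sat" vs. "cat sat down" share the
# tokens {cat, sat} after normalization, so precision = 2/2, recall = 2/3,
# and F1 = 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.8.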


def is_date_or_num(answer):
    """Heuristic: an answer counts as a date or number if any token is numeric
    or is one of the Vietnamese date words "ngày"/"tháng"/"năm" (day/month/year)."""
    for w in answer.lower().split():
        w = w.strip()
        if w.isnumeric() or w in ["ngày", "tháng", "năm"]:
            return True
    return False
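
# Example (hypothetical): is_date_or_num("ngày 15 tháng 3") -> True (date words
# plus a numeric token), while is_date_or_num("Hà Nội") -> False.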


def find_best_cluster(answers, best_answer, thr=0.79):
    """Cluster near-duplicate answers and return a representative of the largest
    cluster; fall back to best_answer when clustering is not possible."""
    if len(answers) == 0:
        return best_answer
    elif len(answers) == 1:
        return answers[0]

    # Pairwise similarity matrix: dists[i, j] = 1 when answers i and j are
    # considered duplicates of each other.
    dists = np.zeros((len(answers), len(answers)))
    for i in range(len(answers) - 1):
        for j in range(i + 1, len(answers)):
            a1 = answers[i].lower().strip()
            a2 = answers[j].lower().strip()
            if is_date_or_num(a1) or is_date_or_num(a2):
                # Dates and numbers only match exactly, or by containment when
                # the shorter answer mentions "tháng" (month).
                if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
                    dists[i, j] = 1
                    dists[j, i] = 1
            elif a1 == a2 or (a1 in a2) or (a2 in a1) or compute_f1(a1, a2) >= thr:
                dists[i, j] = 1
                dists[j, i] = 1

    try:
        # Link every duplicate pair with an edge and cluster the resulting graph
        # with Louvain community detection.
        dups = np.where(dists >= 1)
        edges = []
        for i, j in zip(dups[0], dups[1]):
            if i != j:
                edges.append((i, j))
        G = nx.Graph()
        for i, answer in enumerate(answers):
            G.add_node(i, content=answer)
        G.add_edges_from(edges)
        partition = algorithms.louvain(G)

        # Keep only the largest communities.
        max_len_comm = max(len(x) for x in partition.communities)
        best_comms = []
        for comm in partition.communities:
            if len(comm) == max_len_comm:
                best_comms.append([answers[i] for i in comm])

        # Prefer best_answer if it already sits in a largest community;
        # otherwise return the median-length answer of the first one.
        for comm in best_comms:
            if best_answer in comm:
                return best_answer
        mid = len(best_comms[0]) // 2
        return sorted(best_comms[0], key=len)[mid]
    except Exception as e:
        print(e, "Disconnected graph")
        return best_answer
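
# Minimal usage sketch (hypothetical candidate answers; assumes networkx and
# cdlib are installed). The three spellings of "Hà Nội" should land in one
# Louvain community, so its representative is returned over the outlier
# "Sài Gòn"; on any clustering failure the function falls back to best_answer.
if __name__ == "__main__":
    candidates = ["Hà Nội", "hà nội", "thành phố Hà Nội", "Sài Gòn"]
    print(find_best_cluster(candidates, best_answer="Hà Nội"))  # -> "Hà Nội"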