# AMLSim/scripts/visualize/plot_alert_pattern_subgraphs.py
# dingyiz's picture
# Upload folder using huggingface_hub
# 2795186 verified
# raw
# history blame
# 3.31 kB
import os
import sys
import csv
import json
import networkx as nx
from collections import defaultdict
import matplotlib
import matplotlib.pyplot as plt
import warnings
# Silence Matplotlib's deprecation warnings for this script.
# The warning class has lived in different places across Matplotlib releases
# (the old ``matplotlib.cbook.deprecation`` module no longer exists in recent
# versions), so resolve it defensively instead of hard-coding one path.
_mpl_deprecation_warning = getattr(matplotlib, "MatplotlibDeprecationWarning", None)
if _mpl_deprecation_warning is None:
    _mpl_deprecation_warning = getattr(
        getattr(matplotlib, "cbook", None),
        "MatplotlibDeprecationWarning",
        DeprecationWarning,  # last resort: the generic built-in category
    )
warnings.filterwarnings('ignore', category=_mpl_deprecation_warning)
def _column_indices(columns, data_types, section):
    """Return a ``{dataType: column index}`` mapping for the requested data types.

    :param columns: List of column definitions (dicts) from one schema section
    :param data_types: Data type names that must each be declared by a column
    :param section: Schema section name, used only in the error message
    :return: Dict mapping every requested data type to its column index
    :raises ValueError: If a requested data type is not declared by any column
    """
    indices = {}
    for i, col in enumerate(columns):
        data_type = col.get("dataType")
        if data_type in data_types:
            # A later column with the same dataType overrides an earlier one
            indices[data_type] = i
    missing = [t for t in data_types if t not in indices]
    if missing:
        raise ValueError(
            "Schema section '%s' has no column with dataType: %s"
            % (section, ", ".join(missing))
        )
    return indices


def load_alerts(_conf_json):
    """Load alert member accounts and alert transactions as a directed graph.

    The configuration JSON gives the output directory, simulation name and the
    alert member / alert transaction CSV file names, plus the input directory
    and schema file. The schema tells which CSV column holds each field.

    :param _conf_json: Path to the configuration JSON file
    :return: Tuple ``(graph, bank_accts)``. ``graph`` is a ``networkx.DiGraph``
        whose nodes are account IDs (with a ``bank_id`` attribute for alert
        members) and whose edges carry ``amount``, ``date`` and ``label``
        attributes. ``bank_accts`` maps each bank ID to the list of its member
        account IDs, in file order.
    :raises ValueError: If the schema lacks a column for a required data type
    """
    with open(_conf_json, "r") as rf:
        conf = json.load(rf)

    data_dir = os.path.join(conf["output"]["directory"], conf["general"]["simulation_name"])
    acct_csv = os.path.join(data_dir, conf["output"]["alert_members"])
    tx_csv = os.path.join(data_dir, conf["output"]["alert_transactions"])

    input_dir = conf["input"]["directory"]
    schema_json = os.path.join(input_dir, conf["input"]["schema"])
    with open(schema_json, "r") as rf:
        schema = json.load(rf)

    # Resolve every column position up front so that an incomplete schema is
    # reported clearly instead of failing later with "row[None]".
    member_cols = _column_indices(
        schema["alert_member"], ("account_id", "bank_id"), "alert_member")
    tx_cols = _column_indices(
        schema["alert_tx"], ("orig_id", "dest_id", "amount", "timestamp"), "alert_tx")
    acct_idx = member_cols["account_id"]
    bank_idx = member_cols["bank_id"]
    orig_idx = tx_cols["orig_id"]
    bene_idx = tx_cols["dest_id"]
    amt_idx = tx_cols["amount"]
    date_idx = tx_cols["timestamp"]

    _g = nx.DiGraph()
    _bank_accts = defaultdict(list)

    # newline="" lets the csv module handle line endings itself
    with open(acct_csv, "r", newline="") as rf:
        reader = csv.reader(rf)
        next(reader, None)  # Skip the header; tolerate a completely empty file
        for row in reader:
            if not row:  # csv.reader yields [] for blank lines
                continue
            acct_id = row[acct_idx]
            bank_id = row[bank_idx]
            _g.add_node(acct_id, bank_id=bank_id)
            _bank_accts[bank_id].append(acct_id)

    with open(tx_csv, "r", newline="") as rf:
        reader = csv.reader(rf)
        next(reader, None)  # Skip the header; tolerate a completely empty file
        for row in reader:
            if not row:
                continue
            orig_id = row[orig_idx]
            bene_id = row[bene_idx]
            amount = row[amt_idx]
            date = row[date_idx].split("T")[0]  # Keep only the part before "T" (the date)
            label = amount + "\n" + date
            _g.add_edge(orig_id, bene_id, amount=amount, date=date, label=label)

    return _g, _bank_accts
def plot_alerts(_g, _bank_accts, _output_png):
    """Draw the alert graph, colouring member accounts by bank, and save it.

    :param _g: Graph to draw; the ``label`` edge attribute is used as edge label
    :param _bank_accts: Mapping of bank ID to the list of account IDs (nodes of
        ``_g``) that belong to that bank; only these nodes are drawn
    :param _output_png: Path of the image file to write
    """
    cmap = plt.get_cmap("tab10")
    pos = nx.nx_agraph.graphviz_layout(_g)

    plt.figure(figsize=(12.0, 8.0))
    plt.axis('off')

    for i, (bank_id, members) in enumerate(_bank_accts.items()):
        # Wrap around the palette: an integer index past the last entry is
        # clipped, which would give every additional bank the same colour.
        color = cmap(i % cmap.N)
        # Pass the RGBA tuple wrapped in a list. A bare 4-tuple is ambiguous
        # when a bank has exactly four members: it would be read as four
        # scalar values to colour-map rather than as one colour.
        nx.draw_networkx_nodes(
            _g, pos, nodelist=members, node_size=300,
            node_color=[color], label=bank_id)
        nx.draw_networkx_labels(
            _g, pos, labels={n: n for n in members}, font_size=10)

    edge_labels = nx.get_edge_attributes(_g, "label")
    nx.draw_networkx_edges(_g, pos)
    nx.draw_networkx_edge_labels(_g, pos, edge_labels=edge_labels, font_size=6)

    plt.legend(numpoints=1)
    # Let the drawing fill the whole canvas (axes are hidden anyway)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.savefig(_output_png, dpi=120)
if __name__ == "__main__":
    # Command-line entry point: load the alert graph described by the
    # configuration JSON and render it to the requested image file.
    argv = sys.argv
    if len(argv) < 3:
        print("Usage: python3 %s [ConfJSON] [OutputPNG]" % argv[0])
        # Use sys.exit: the bare exit() helper is injected by the site module
        # and is not guaranteed to exist in every interpreter setup.
        sys.exit(1)

    conf_json = argv[1]
    output_png = argv[2]

    g, bank_accts = load_alerts(conf_json)
    plot_alerts(g, bank_accts, output_png)