|
|
|
""" |
|
modular_graph_and_candidates.py |
|
================================ |
|
Create **one** rich view that combines |
|
1. The *dependency graph* between existing **modular_*.py** implementations in |
|
π€Β Transformers (blue/π‘) **and** |
|
2. The list of *missing* modular models (fullβred nodes) **plus** similarity |
|
edges (fullβred links) between highlyβoverlapping modelling files β the |
|
output of *find_modular_candidates.py* β so you can immediately spot good |
|
refactor opportunities. |
|
|
|
βββΒ UsageΒ βββ |
|
|
|
```bash |
|
python modular_graph_and_candidates.py /path/to/transformers \ |
|
--multimodal # keep only models whose modelling code mentions |
|
# "pixel_values" β₯Β 3 times |
|
--sim-threshold 0.5 # Jaccard cutoff (default 0.50) |
|
--out graph.html # output HTML file name |
|
``` |
|
|
|
Colour legend in the generated HTML: |
|
* π‘Β **base model**Β β has modular shards *imported* by others but no parent |
|
* π΅Β **derived modular model**Β β has a `modular_*.py` and inherits from β₯β―1 model |
|
* π΄Β **candidate**Β β no `modular_*.py` yet (and/or very similar to another) |
|
* red edges = highβJaccard similarity links (potential to factorise) |
|
""" |
|
from __future__ import annotations

import argparse
import ast
import json
import re
import tokenize
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
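# Note: sentence-transformers (which pulls in torch) is only exercised by
# `--sim-method embedding`, but it is imported unconditionally here.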
|
|
|
|
|
|
|
|
|
SIM_DEFAULT = 0.78
PIXEL_MIN_HITS = 3  # --multimodal keeps models mentioning "pixel_values" at least this many times
HTML_DEFAULT = "d3_modular_graph.html"
|
|
|
|
|
|
|
|
|
|
|
def _strip_source(code: str) -> str:
    """Remove docstrings, comments and import lines to keep only the core code."""
    code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)   # triple-quoted strings
    code = re.sub(r"#.*", "", code)                        # line comments
    return "\n".join(ln for ln in code.splitlines()
                     if not re.match(r"\s*(from|import)\s+", ln))


def _tokenise(code: str) -> Set[str]:
    """Collect the set of identifier (NAME) tokens appearing in *code*."""
    toks: Set[str] = set()
    for tok in tokenize.generate_tokens(iter(code.splitlines(keepends=True)).__next__):
        if tok.type == tokenize.NAME:
            toks.add(tok.string)
    return toks
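# Illustrative example (hypothetical snippet, not a real modeling file):
# feeding "import os\nx = foo(y)  # call" through _strip_source then _tokenise
# yields {"x", "foo", "y"} - the import line and the comment are stripped, and
# only identifier tokens remain.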
|
|
|
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token-bags of every `modeling_*.py` plus a pixel-value counter."""
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                text = py.read_text(encoding="utf-8")
                pixel_hits[mdl_dir.name] += text.count("pixel_values")
                bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
            except Exception as e:
                print(f"⚠️ Skipped {py}: {e}")
    return bags, pixel_hits
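# Example shape (hypothetical values): bags["llama"] would hold one token set per
# modeling_*.py under models/llama/, and pixel_hits["llava"] == 42 would mean the
# string "pixel_values" occurs 42 times across llava's modeling files.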
|
|
|
def _jaccard(a: Set[str], b: Set[str]) -> float:
    return 0.0 if (not a or not b) else len(a & b) / len(a | b)


def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): score} for pairs with Jaccard >= *thr*."""
    # Compare only the largest modeling file of each model to keep this O(n^2) pass cheap.
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
    for m1, m2 in combinations(sorted(largest.keys()), 2):
        s = _jaccard(largest[m1], largest[m2])
        if s >= thr:
            out[(m1, m2)] = s
    return out
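# Worked example: _jaccard({"a", "b", "c"}, {"b", "c", "d"}) = |{b, c}| / |{a, b, c, d}|
# = 2/4 = 0.5, so with thr=0.5 that pair would just make the cut.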
|
|
|
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Like `similarity_clusters`, but scores pairs by code-embedding cosine similarity."""
    model = SentenceTransformer("nomic-ai/nomic-embed-code")
    model.max_seq_length = 4096
    texts = {}

    for name in tqdm(missing, desc="Reading modeling files"):
        code = ""
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                code += _strip_source(py.read_text(encoding="utf-8")) + "\n"
            except Exception:
                continue
        texts[name] = code.strip() or " "  # placeholder so every model still gets an embedding

    names = list(texts)
    all_embeddings = []

    print("Encoding embeddings...")
    batch_size = 8
    for i in tqdm(range(0, len(names), batch_size), desc="Batches", leave=False):
        batch = [texts[n] for n in names[i:i + batch_size]]
        emb = model.encode(batch, convert_to_numpy=True, show_progress_bar=False)
        all_embeddings.append(emb)

    embeddings = np.vstack(all_embeddings)

    print("Computing pairwise similarities...")
    # Raw dot products are only a cosine similarity if the embeddings are unit-normalised,
    # so use util.cos_sim, which normalises internally.
    sims = util.cos_sim(embeddings, embeddings).numpy()

    out = {}
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            s = sims[i, j]
            if s >= thr:
                out[(names[i], names[j])] = float(s)
    return out
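# Note: Jaccard overlap and embedding cosine live on different scales, so a threshold
# tuned for one method (e.g. SIM_DEFAULT = 0.78) will usually need re-tuning for the other.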
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modular_files(models_root: Path) -> List[Path]:
    """All `modular_*.py` files below *models_root* (one per already-converted model)."""
    return list(models_root.rglob("modular_*.py"))
|
|
|
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str, str]]]:
    """Return {derived_model: [{source, imported_class}, ...]}

    Only `modeling_*` imports are kept; anything coming from configuration/processing/
    image* utils is ignored so the visual graph focuses strictly on modelling code.
    Edges to sources whose name is not a model directory are excluded.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for fp in modular_files:
        derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module

            if ("modeling_" not in mod or
                    "configuration_" in mod or
                    "processing_" in mod or
                    "image_processing" in mod or
                    "modeling_attn_mask_utils" in mod):
                continue
            # e.g. "transformers.models.llama.modeling_llama" -> pick out "llama"
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)
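# Illustrative example (hypothetical file, not part of the repo): if
# models/gemma/modular_gemma.py contained
#     from ..llama.modeling_llama import LlamaMLP, LlamaAttention
# the result would include
#     {"gemma": [{"source": "llama", "imported_class": "LlamaMLP"},
#                {"source": "llama", "imported_class": "LlamaAttention"}]}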
|
|
|
|
|
|
|
|
|
def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the {nodes, links} dict that D3 needs."""
    models_root = transformers_dir / "src/transformers/models"
    bags, pix_hits = build_token_bags(models_root)

    mod_files = modular_files(models_root)
    deps = dependency_graph(mod_files, models_root)

    models_with_modular = {p.parent.name for p in mod_files}
    missing = [m for m in bags if m not in models_with_modular]
    if multimodal:
        missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]

    if sim_method == "jaccard":
        sims = similarity_clusters({m: bags[m] for m in missing}, threshold)
    else:
        sims = embedding_similarity_clusters(models_root, missing, threshold)

    nodes: Set[str] = set()
    links: List[dict] = []

    # One aggregated edge per (source, derived) pair, labelled with the import count
    # (rather than one duplicate edge per imported class).
    for drv, lst in deps.items():
        src_counts = Counter(d["source"] for d in lst)
        for src, cnt in src_counts.items():
            links.append({
                "source": src,
                "target": drv,
                "label": f"{cnt} imports",
                "cand": False,
            })
            nodes.update({src, drv})

    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
        nodes.update({a, b})

    nodes.update(missing)

    deg = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values(), default=1)

    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    missing_only = [m for m in missing if m not in sources and m not in targets]

    nodelist = []
    for n in sorted(nodes):
        if n in missing_only:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        nodelist.append({"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)})

    return {"nodes": nodelist, "links": links}
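# Shape of the returned dict (a hand-written sample for illustration, not real data):
# {"nodes": [{"id": "llama",   "cls": "base",    "sz": 3.0},
#            {"id": "mistral", "cls": "derived", "sz": 1.5}],
#  "links": [{"source": "llama", "target": "mistral", "label": "4 imports", "cand": False}]}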
|
|
|
|
|
def generate_html(graph: dict) -> str:
    """Return the full HTML string with inlined CSS/JS + graph JSON."""
    js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
    return HTML.replace("__CSS__", CSS).replace("__JS__", js)
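# Minimal programmatic use (a sketch; the paths are placeholders):
#   graph = build_graph_json(Path("~/transformers").expanduser(), threshold=0.8)
#   Path("graph.html").write_text(generate_html(graph), encoding="utf-8")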
|
|
|
|
|
|
|
|
|
|
|
|
|
CSS = """ |
|
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap'); |
|
:root { --base: 60px; } |
|
body { margin:0; font-family:'Inter',Arial,sans-serif; background:transparent; overflow:hidden; } |
|
svg { width:100vw; height:100vh; } |
|
.link { stroke:#999; stroke-opacity:.6; } |
|
.link.cand { stroke:#e63946; stroke-width:2.5; } |
|
.node-label { fill:#333; pointer-events:none; text-anchor:middle; font-weight:600; } |
|
.link-label { fill:#555; font-size:10px; pointer-events:none; text-anchor:middle; } |
|
.node.base path { fill:#ffbe0b; } |
|
.node.derived circle { fill:#1f77b4; } |
|
.node.cand circle, .node.cand path { fill:#e63946; } |
|
#legend { position:fixed; top:18px; left:18px; background:rgba(255,255,255,.92); padding:18px 28px; |
|
border-radius:10px; border:1.5px solid #bbb; font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08); } |
|
""" |
|
|
|
JS = """ |
|
|
|
function updateVisibility() { |
|
const show = document.getElementById('toggleRed').checked; |
|
svg.selectAll('.link.cand').style('display', show ? null : 'none'); |
|
svg.selectAll('.node.cand').style('display', show ? null : 'none'); |
|
svg.selectAll('.link-label') |
|
.filter(d => d.cand) |
|
.style('display', show ? null : 'none'); |
|
} |
|
|
|
document.getElementById('toggleRed').addEventListener('change', updateVisibility); |
|
|
|
|
|
const graph = __GRAPH_DATA__; |
|
const W = innerWidth, H = innerHeight; |
|
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform))); |
|
const g = svg.append('g'); |
|
|
|
const link = g.selectAll('line') |
|
.data(graph.links) |
|
.join('line') |
|
.attr('class', d => d.cand ? 'link cand' : 'link'); |
|
|
|
const linkLbl = g.selectAll('text.link-label') |
|
.data(graph.links) |
|
.join('text') |
|
.attr('class', 'link-label') |
|
.text(d => d.label); |
|
|
|
const node = g.selectAll('g.node') |
|
.data(graph.nodes) |
|
.join('g') |
|
.attr('class', d => `node ${d.cls}`) |
|
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd)); |
|
|
|
node.filter(d => d.cls==='base').append('image') |
|
.attr('xlink:href', 'hf-logo.svg').attr('x', -30).attr('y', -30).attr('width', 60).attr('height', 60); |
|
node.filter(d => d.cls!=='base').append('circle').attr('r', d => 20*d.sz); |
|
node.append('text').attr('class','node-label').attr('dy','-2.4em').text(d => d.id); |
|
|
|
const sim = d3.forceSimulation(graph.nodes) |
|
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520)) // tighter links |
|
.force('charge', d3.forceManyBody().strength(-600)) // weaker repulsion |
|
.force('center', d3.forceCenter(W / 2, H / 2)) |
|
.force('collide', d3.forceCollide(d => d.cls === 'base' ? 50 : 50)); // smaller bubble spacing |
|
|
|
|
|
sim.on('tick', () => { |
|
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y) |
|
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y); |
|
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2) |
|
.attr('y', d=> (d.source.y+d.target.y)/2); |
|
node.attr('transform', d=>`translate(${d.x},${d.y})`); |
|
}); |
|
|
|
function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; } |
|
function dragged(e,d){ d.fx=e.x; d.fy=e.y; } |
|
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; } |
|
""" |
|
|
|
HTML = """ |
|
<!DOCTYPE html> |
|
<html lang='en'><head><meta charset='UTF-8'> |
|
<title>Transformers modular graph</title> |
|
<style>__CSS__</style></head><body> |
|
<div id='legend'> |
|
π‘ base<br>π΅ modular<br>π΄ candidate<br>red edgeΒ = high embedding similarity<br><br> |
|
<label><input type="checkbox" id="toggleRed" checked> Show candidates edges and nodes</label> |
|
</div> |
|
<svg id='dependency'></svg> |
|
<script src='https://d3js.org/d3.v7.min.js'></script> |
|
<script>__JS__</script></body></html> |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def write_html(graph_data: dict, path: Path):
    """Render *graph_data* to a standalone HTML file at *path*."""
    path.write_text(generate_html(graph_data), encoding="utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    ap.add_argument("--multimodal", action="store_true", help="filter to models with ≥ 3 'pixel_values'")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT)
    ap.add_argument("--out", default=HTML_DEFAULT)
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()

    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())


if __name__ == "__main__":
    main()
|
|