Ashley Andrea Squarcio commited on
Commit
2fc692a
·
1 Parent(s): 486389e

Initial import: code, dependencies, chunk_index.pkl (LFS tracked)

Browse files
Files changed (7) hide show
  1. app.py +87 -0
  2. chunk_index.pkl +3 -0
  3. dspy_wrapper.py +71 -0
  4. main.py +34 -0
  5. neo4j_config.py +111 -0
  6. requirements.txt +156 -0
  7. retriever.py +222 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from main import rag_pipeline
2
+ import gradio as gr
3
+ import html
4
+
5
+
6
+ municipalities = [
7
+ "Comun General de Fascia",
8
+ "Comune di Capo D'Orlando",
9
+ "Comune di Casatenovo",
10
+ "Comune di Fonte Nuova",
11
+ "Comune di Gubbio",
12
+ "Comune di Torre Santa Susanna",
13
+ "Comune di Santa Maria Capua Vetere"
14
+ ]
15
+
16
+ def answer_fn(query: str, municipality: str):
17
+ # "All" or empty count as no filter
18
+ filters = {}
19
+ if municipality and municipality != "All":
20
+ filters["municipality"] = municipality
21
+
22
+ output = rag_pipeline(query=query, municipality=municipality)
23
+
24
+ final_answer = output["final_answer"]
25
+ cot = output["chain_of_thought"]
26
+
27
+ chunks = output["retrieved_chunks"]
28
+ html_blocks = []
29
+ for i, c in enumerate(chunks, 1):
30
+ text = html.escape(c["chunk_text"])
31
+ meta = {
32
+ "Document ID": c.get("document_id", "N/A"),
33
+ "Municipality": c.get("municipality", "N/A"),
34
+ "Section": c.get("section", "N/A"),
35
+ "Page": c.get("page", "N/A"),
36
+ "Score": f"{c.get('final_score', 0):.4f}"
37
+ }
38
+
39
+ # transforms metadata dictionary into a series of <li> (list) items
40
+ meta_lines = "".join(f"<li><b>{k}</b>: {v}</li>" for k, v in meta.items())
41
+
42
+ # displays chunk number as header, metadata and text
43
+ block = f"""
44
+ <div style="margin-bottom:1em;">
45
+ <h4>Chunk {i}</h4>
46
+ <ul>{meta_lines}</ul>
47
+ <p style="white-space: pre-wrap;">{text}</p>
48
+ </div>
49
+ <hr/>
50
+ """
51
+ html_blocks.append(block)
52
+ chunks_html = "\n".join(html_blocks) or "<i>No chunks retrieved.</i>"
53
+
54
+ return final_answer, cot, chunks_html
55
+
56
+
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown("## DSPy RAG Demo")
59
+
60
+ with gr.Row():
61
+ query_input = gr.Textbox(label="Question", placeholder="Type your query here…")
62
+ muni_input = gr.Dropdown(
63
+ choices=["All"] + municipalities,
64
+ value="All",
65
+ label="Municipality (optional)"
66
+ )
67
+
68
+ run_btn = gr.Button("Get Answer")
69
+
70
+ ans_out = gr.Textbox(label="Final Answer", lines=3) # answer container
71
+
72
+ # CoT container inside its accordion
73
+ with gr.Accordion("Chain of Thought Reasoning", open=False):
74
+ cot_txt = gr.Textbox(label="", interactive=False, lines=6)
75
+
76
+ # chunks HTML container inside its accordion
77
+ with gr.Accordion("Retrieved Chunks with Metadata", open=False):
78
+ chunks_html = gr.HTML("<i>No data yet.</i>")
79
+
80
+ # Wire the button click to the function (with outputs matching the order of returned values)
81
+ run_btn.click(
82
+ fn=answer_fn,
83
+ inputs=[query_input, muni_input],
84
+ outputs=[ans_out, cot_txt, chunks_html]
85
+ )
86
+
87
+ demo.launch(share=True)
chunk_index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5962e522e6557c69c68ba22fc4d70487fc34b83cc6875be11b5e3564b3a38de1
3
+ size 17549445
dspy_wrapper.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dspy
2
+ from typing import List, Dict
3
+ import os
4
+
5
+
6
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
7
+ if not OPENAI_API_KEY:
8
+ raise RuntimeError("Missing OPENAI_API_KEY env var")
9
+
10
+ gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
11
+ # using unimib credentials, switch to PeS if needed!
12
+ dspy.configure(lm=gpt_4o_mini)
13
+
14
+
15
+ # == Building Blocks ==
16
+ class DSPyHybridRetriever(dspy.Module):
17
+ def __init__(self, retriever):
18
+ super().__init__()
19
+ self.retriever = retriever
20
+
21
+ def forward(self, query: str, municipality: str = "", top_k: int = 5):
22
+ results = self.retriever.retrieve(query, top_k=top_k, municipality_filter=municipality) # remember to change to rerank
23
+ return {"retrieved_chunks": results}
24
+
25
+ class RetrieveChunks(dspy.Signature):
26
+ """Given a user query and optional municipality, retrieve relevant text chunks."""
27
+ query = dspy.InputField(desc="User's question")
28
+ municipality = dspy.InputField(desc="Optional municipality filter")
29
+ retrieved_chunks = dspy.OutputField(
30
+ desc=(
31
+ "List of retrieved chunks, each as a dict with keys: "
32
+ "`chunk`, `document_id`, `section`, `level`, `page`, "
33
+ "`dense_score`, `sparse_score`, `graph_score`, `final_score`"
34
+ ),
35
+ type=List[Dict] # each item is a dict carrying all those fields
36
+ )
37
+
38
+ class AnswerWithEvidence(dspy.Signature):
39
+ """Answer the query using reasoning and retrieved chunks as context."""
40
+ query = dspy.InputField(desc="User's question")
41
+ retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
42
+ answer = dspy.OutputField(desc="Final answer")
43
+ rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
44
+
45
+
46
+ # == RAG Pipeline ==
47
+ class RAGChain(dspy.Module):
48
+ def __init__(self, retriever, answerer):
49
+ super().__init__()
50
+ self.retriever = retriever
51
+ self.answerer = answerer
52
+
53
+ def forward(self, query: str, municipality: str = ""):
54
+ # retrieve full dicts
55
+ retrieved = self.retriever(query=query, municipality=municipality)
56
+ chunks = retrieved["retrieved_chunks"]
57
+
58
+ # feed only the raw text into the CoT module
59
+ answer_result = self.answerer(
60
+ query=query,
61
+ retrieved_chunks=[c["chunk_text"] for c in chunks]
62
+ )
63
+
64
+ # return both the metadata and the LLM answer
65
+ return {
66
+ "query": query,
67
+ "municipality": municipality,
68
+ "retrieved_chunks": chunks,
69
+ "chain_of_thought": answer_result.rationale,
70
+ "final_answer": answer_result.answer,
71
+ }
main.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from neo4j_config import URI, USER, PASSWORD, AUTH
2
+ from retriever import *
3
+ from dspy_wrapper import *
4
+ from neo4j import GraphDatabase
5
+ import os, pickle
6
+
7
+
8
+ # == Fast Load of Precomputed Index ==
9
+ HERE = os.path.dirname(__file__)
10
+ with open(os.path.join(HERE, "chunk_index.pkl"), "rb") as f:
11
+ all_chunks = pickle.load(f) # already contain embeddings and ids
12
+
13
+
14
+ # == Neo4j Setup ==
15
+ with GraphDatabase.driver(URI, auth=AUTH) as driver:
16
+ driver.verify_connectivity()
17
+
18
+
19
+ # == Retrieval ==
20
+ retriever = HybridRetriever(all_chunks)
21
+
22
+ reranker = GraphReranker(
23
+ retriever,
24
+ neo4j_uri=URI,
25
+ neo4j_user=USER,
26
+ neo4j_pass=PASSWORD,
27
+ beta=0.2,
28
+ max_hops=3
29
+ )
30
+
31
+ # == Pipeline Initialization ==
32
+ retriever_module = DSPyHybridRetriever(retriever)
33
+ cot_module = dspy.ChainOfThought(AnswerWithEvidence)
34
+ rag_pipeline = RAGChain(retriever=retriever_module, answerer=cot_module)
neo4j_config.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from neo4j import GraphDatabase
3
+
4
+
5
+ URI = "neo4j+s://1ea442ce.databases.neo4j.io"
6
+ USER = "neo4j"
7
+ PASSWORD = "diGxvEhJnqcp18rHDwPzGv1KaRvxprUvdD1h31unwa8"
8
+ AUTH = (USER, PASSWORD)
9
+
10
+ with GraphDatabase.driver(URI, auth=AUTH) as driver:
11
+ driver.verify_connectivity()
12
+
13
+
14
+ def normalize_int(value, default=0):
15
+ """
16
+ Safely convert value to int.
17
+ - If already int, return it.
18
+ - If str of digits, parse it.
19
+ - Otherwise return `default`.
20
+ """
21
+ if isinstance(value, int):
22
+ return value
23
+ if isinstance(value, str) and value.isdigit():
24
+ return int(value)
25
+ # optionally, extract digits from strings like "1.":
26
+ m = re.match(r"(\d+)", str(value))
27
+ if m:
28
+ return int(m.group(1))
29
+ return default
30
+
31
+
32
+ def add_municipality(tx, municipality):
33
+ tx.run("""
34
+ MERGE (m:Municipality {name: $municipality})
35
+ """, municipality=municipality)
36
+
37
+ def add_document(tx, doc_id, municipality):
38
+ tx.run("""
39
+ MATCH (m:Municipality {name: $municipality})
40
+ MERGE (d:Document {doc_id: $doc_id})
41
+ MERGE (m)-[:HAS_DOCUMENT]->(d)
42
+ """, municipality=municipality, doc_id=doc_id)
43
+
44
+ def add_chunk(tx, chunk):
45
+ tx.run("""
46
+ MATCH (d:Document {doc_id: $doc_id})
47
+ MERGE (c:Chunk {id: $id})
48
+ SET c.page = $page,
49
+ c.section = $section,
50
+ c.level = $level,
51
+ c.text = $text,
52
+ c.embedding = $embedding
53
+ MERGE (d)-[:HAS_CHUNK]->(c)
54
+ """, id=chunk["id"], doc_id=chunk["document_id"],
55
+ page=chunk["page"], section=chunk["section"],
56
+ level=chunk["level"], text=chunk["chunk_text"],
57
+ embedding=chunk["embedding"])
58
+
59
+ def link_parent(tx, parent_id, child_id):
60
+ tx.run("""
61
+ MATCH (p:Chunk {id: $parent_id}), (c:Chunk {id: $child_id})
62
+ MERGE (p)-[:HAS_SUBSECTION]->(c)
63
+ """, parent_id=parent_id, child_id=child_id)
64
+
65
+ def link_sibling(tx, sibling1_id, sibling2_id):
66
+ tx.run("""
67
+ MATCH (c1:Chunk {id: $sibling1_id}), (c2:Chunk {id: $sibling2_id})
68
+ MERGE (c1)-[:NEXT_TO]->(c2)
69
+ """, sibling1_id=sibling1_id, sibling2_id=sibling2_id)
70
+
71
+ # takes again quite some time to compute, we could re-download a pkl file with ids as well
72
+ def sync_chunk_ids(all_chunks, driver, prefix_len=50):
73
+ """
74
+ For each chunk in-memory, look up its real DB id by matching on:
75
+ - page
76
+ - section
77
+ - the first `prefix_len` chars of text
78
+
79
+ If already present, overwrites chunk["id"] with the DB value when found,
80
+ otherwise retrieves the id from the graph db and adds it to each chunk's dict.
81
+ """
82
+ with driver.session() as session:
83
+ for chunk in all_chunks:
84
+ # build prefix of the chunk text
85
+ prefix = chunk["chunk_text"][:prefix_len]
86
+ # normalize numeric props
87
+ page = normalize_int(chunk.get("page"))
88
+
89
+ cypher = """
90
+ MATCH (c:Chunk {
91
+ page: $page,
92
+ section: $section
93
+ })
94
+ WHERE c.text STARTS WITH $prefix
95
+ RETURN c.id AS real_id
96
+ LIMIT 1
97
+ """
98
+ params = {
99
+ "page": page,
100
+ "section": chunk["section"],
101
+ "prefix": prefix
102
+ }
103
+
104
+ rec = session.run(cypher, params).single()
105
+ if rec:
106
+ chunk["id"] = rec["real_id"]
107
+ else:
108
+ print(f"No DB match for chunk: page={page} "
109
+ f"section={chunk.get('section')!r} prefix={prefix!r}")
110
+
111
+ ## CHUNK INGESTION CODE NOT PRESENT HERE!! CHECK COLAB NB!
requirements.txt ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.18
4
+ aiosignal==1.3.2
5
+ alembic==1.15.2
6
+ altair==5.5.0
7
+ annotated-types==0.7.0
8
+ anyio==4.9.0
9
+ appnope==0.1.4
10
+ asttokens==3.0.0
11
+ asyncer==0.0.8
12
+ attrs==25.3.0
13
+ backoff==2.2.1
14
+ blinker==1.9.0
15
+ cachetools==5.5.2
16
+ certifi==2025.4.26
17
+ cffi==1.17.1
18
+ charset-normalizer==3.4.2
19
+ click==8.1.8
20
+ cloudpickle==3.1.1
21
+ colorlog==6.9.0
22
+ comm==0.2.2
23
+ cryptography==44.0.3
24
+ datasets==3.5.1
25
+ debugpy==1.8.14
26
+ decorator==5.2.1
27
+ dill==0.3.8
28
+ diskcache==5.6.3
29
+ distro==1.9.0
30
+ dspy==2.6.14
31
+ dspy-ai==2.6.23
32
+ executing==2.2.0
33
+ faiss-cpu==1.11.0
34
+ fastapi==0.115.12
35
+ ffmpy==0.5.0
36
+ filelock==3.18.0
37
+ frozenlist==1.6.0
38
+ fsspec==2025.3.0
39
+ gitdb==4.0.12
40
+ GitPython==3.1.44
41
+ gradio==5.30.0
42
+ gradio_client==1.10.1
43
+ groovy==0.1.2
44
+ h11==0.16.0
45
+ httpcore==1.0.9
46
+ httpx==0.28.1
47
+ huggingface-hub==0.30.2
48
+ idna==3.10
49
+ importlib_metadata==8.7.0
50
+ ipykernel==6.29.5
51
+ ipython>=8,<9
52
+ ipython_pygments_lexers==1.1.1
53
+ jedi==0.19.2
54
+ Jinja2==3.1.6
55
+ jiter==0.9.0
56
+ joblib==1.5.0
57
+ json_repair==0.44.1
58
+ jsonschema==4.23.0
59
+ jsonschema-specifications==2025.4.1
60
+ jupyter_client==8.6.3
61
+ jupyter_core==5.7.2
62
+ litellm==1.63.7
63
+ magicattr==0.1.6
64
+ Mako==1.3.10
65
+ markdown-it-py==3.0.0
66
+ MarkupSafe==3.0.2
67
+ matplotlib-inline==0.1.7
68
+ mdurl==0.1.2
69
+ mpmath==1.3.0
70
+ multidict==6.4.3
71
+ multiprocess==0.70.16
72
+ narwhals==1.38.0
73
+ neo4j==5.28.1
74
+ nest-asyncio==1.6.0
75
+ networkx==3.4.2
76
+ numpy==2.2.5
77
+ openai==1.61.0
78
+ optuna==4.3.0
79
+ orjson==3.10.18
80
+ packaging==24.2
81
+ pandas==2.2.3
82
+ parso==0.8.4
83
+ pdfminer.six==20250327
84
+ pdfplumber==0.11.6
85
+ pexpect==4.9.0
86
+ pillow==11.2.1
87
+ platformdirs==4.3.7
88
+ prompt_toolkit==3.0.51
89
+ propcache==0.3.1
90
+ protobuf==6.30.2
91
+ psutil==7.0.0
92
+ ptyprocess==0.7.0
93
+ pure_eval==0.2.3
94
+ pyarrow==20.0.0
95
+ pycparser==2.22
96
+ pydantic==2.11.4
97
+ pydantic_core==2.33.2
98
+ pydeck==0.9.1
99
+ pydub==0.25.1
100
+ Pygments==2.19.1
101
+ PyMuPDF==1.25.5
102
+ PyPDF2==3.0.1
103
+ pypdfium2==4.30.1
104
+ python-dateutil==2.9.0.post0
105
+ python-dotenv==1.1.0
106
+ python-multipart==0.0.20
107
+ pytz==2025.2
108
+ PyYAML==6.0.2
109
+ pyzmq==26.4.0
110
+ RapidFuzz==3.13.0
111
+ referencing==0.36.2
112
+ regex==2024.11.6
113
+ requests==2.32.3
114
+ rich==14.0.0
115
+ rpds-py==0.24.0
116
+ ruff==0.11.10
117
+ safehttpx==0.1.6
118
+ safetensors==0.5.3
119
+ scikit-learn==1.6.1
120
+ scipy==1.15.2
121
+ semantic-version==2.10.0
122
+ sentence-transformers==4.1.0
123
+ setuptools==80.3.1
124
+ shellingham==1.5.4
125
+ six==1.17.0
126
+ sklearn-preprocessing==0.1.0
127
+ smmap==5.0.2
128
+ sniffio==1.3.1
129
+ SQLAlchemy==2.0.40
130
+ stack-data==0.6.3
131
+ starlette==0.46.2
132
+ streamlit==1.45.0
133
+ sympy==1.14.0
134
+ tenacity==9.1.2
135
+ threadpoolctl==3.6.0
136
+ tiktoken==0.9.0
137
+ tokenizers==0.21.1
138
+ toml==0.10.2
139
+ tomlkit==0.13.2
140
+ torch==2.7.0
141
+ tornado==6.4.2
142
+ tqdm==4.67.1
143
+ traitlets==5.14.3
144
+ transformers==4.51.3
145
+ typer==0.15.4
146
+ typing-inspection==0.4.0
147
+ typing_extensions==4.13.2
148
+ tzdata==2025.2
149
+ ujson==5.10.0
150
+ urllib3==2.4.0
151
+ uvicorn==0.34.2
152
+ wcwidth==0.2.13
153
+ websockets==15.0.1
154
+ xxhash==3.5.0
155
+ yarl==1.20.0
156
+ zipp==3.21.0
retriever.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import re
3
+ import numpy as np
4
+ from typing import List, Dict
5
+ from sklearn.preprocessing import normalize
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ from rapidfuzz import fuzz
10
+ from neo4j import GraphDatabase
11
+
12
+
13
+ embedding_model = SentenceTransformer('nickprock/multi-sentence-BERTino')
14
+ embedding_dim = embedding_model.get_sentence_embedding_dimension()
15
+
16
+ # == Base Hybrid Retriever (Dense + Sparse) ==
17
+ class HybridRetriever:
18
+ def __init__(self, chunks: List[Dict]):
19
+ self.raw_chunks = chunks
20
+
21
+ # load precomputed embeddings (with bertino)
22
+ print("Loading dense embeddings from chunks...")
23
+ dense_embeddings = np.vstack([
24
+ chunk["embedding"] for chunk in chunks
25
+ ])
26
+ self.embeddings = normalize(dense_embeddings, axis=1, norm='l2') # l2 normalization for cosine similarity
27
+
28
+ print("Fitting FAISS index...")
29
+ self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
30
+ self.index.add(self.embeddings)
31
+
32
+ print("Fitting TF-IDF vectorizer...")
33
+ self.texts = [chunk["chunk_text"] for chunk in chunks]
34
+ self.vectorizer = TfidfVectorizer(stop_words=None)
35
+ self.sparse_matrix = self.vectorizer.fit_transform(self.texts)
36
+
37
+ # this is just a temporary hard-coded fix for correctly matching this municipality in the UI retrieval,
38
+ # since ' is still not recognized in municipality extraction
39
+ self._overrides = {
40
+ "capo d orlando": "capo d"
41
+ }
42
+
43
+ def fuzzy_match(self, municipality_ref, user_input, threshold=80):
44
+ """
45
+ Fuzzy-matches stored municipality against user input.
46
+ Uses rapidfuzz for accurate and fast matching.
47
+ """
48
+ # Normalize both strings
49
+ def normalize(text):
50
+ text = text.lower()
51
+ # remove common prefixes
52
+ for prefix in ("comune di ", "comuni di ", "città di ", "citta di "):
53
+ if text.startswith(prefix):
54
+ text = text[len(prefix):]
55
+ # strip out non-alphanumeric characters
56
+ text = re.sub(r"[^a-z0-9]+", " ", text)
57
+ return re.sub(r"\s+", " ", text).strip()
58
+
59
+ ref = normalize(municipality_ref)
60
+ inp = normalize(user_input)
61
+
62
+ # override check: if the normalized input matches a key, force match
63
+ if inp in self._overrides:
64
+ # normalize the override target and compare to this ref
65
+ override_target = normalize(self._overrides[inp])
66
+ if ref == override_target:
67
+ return True
68
+
69
+ # Compute fuzzy match score
70
+ score = fuzz.ratio(ref, inp)
71
+
72
+ return score >= threshold
73
+
74
+ def retrieve(self, query, top_k=5, municipality_filter=None, alpha=0.8, threshold=0.3):
75
+ print("Starting hybrid retrieval...")
76
+
77
+ query_embedding = embedding_model.encode([query], convert_to_numpy=True)
78
+ query_embedding = normalize(query_embedding, axis=1, norm='l2') # query also needs l2 normalization
79
+
80
+ query_tfidf = self.vectorizer.transform([query])
81
+ sparse_sim = cosine_similarity(query_tfidf, self.sparse_matrix).flatten()
82
+
83
+ D, I = self.index.search(query_embedding, top_k * 5)
84
+
85
+ results = []
86
+ seen = set()
87
+ for rank, idx in enumerate(I[0]):
88
+ if idx in seen:
89
+ continue
90
+ seen.add(idx)
91
+
92
+ # Enforce municipality constraint
93
+ if municipality_filter:
94
+ stored = self.raw_chunks[idx].get("municipality", "")
95
+ if not self.fuzzy_match(stored, municipality_filter):
96
+ continue
97
+
98
+ dense_score = D[0][rank]
99
+ sparse_score = sparse_sim[idx]
100
+ hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score
101
+
102
+ if hybrid_score < threshold:
103
+ continue
104
+
105
+ results.append({
106
+ **self.raw_chunks[idx], # appends all chunk properties to the output (including id and embedding!)
107
+ "dense_score": dense_score,
108
+ "sparse_score": sparse_score,
109
+ "hybrid_score": hybrid_score
110
+ })
111
+
112
+ if len(results) >= top_k:
113
+ break
114
+
115
+ # Fallback if nothing survived the threshold/filtering
116
+ if not results:
117
+ print(f"No results above threshold for '{municipality_filter}'. Falling back to top-{top_k}.")
118
+ results = []
119
+ for rank, idx in enumerate(I[0][:top_k]):
120
+ dense_score = D[0][rank]
121
+ sparse_score = sparse_sim[idx]
122
+ hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score
123
+ results.append({
124
+ **self.raw_chunks[idx],
125
+ "dense_score": dense_score,
126
+ "sparse_score": sparse_score,
127
+ "hybrid_score": hybrid_score
128
+ })
129
+
130
+ # Return sorted top-k
131
+ top_results = sorted(results, key=lambda x: x["hybrid_score"], reverse=True)[:top_k]
132
+ return top_results
133
+
134
+
135
+ # == Graph Reranker ==
136
+ class GraphReranker:
137
+ def __init__(
138
+ self,
139
+ base_retriever,
140
+ neo4j_uri: str,
141
+ neo4j_user: str,
142
+ neo4j_pass: str,
143
+ beta: float = 0.2,
144
+ max_hops: int = 3
145
+ ):
146
+ """
147
+ base_retriever: instance of HybridRetriever
148
+ beta: weight for the graph component
149
+ max_hops: how many hops we'll search in shortestPath()
150
+ """
151
+ self.base = base_retriever
152
+ self.beta = beta
153
+ self.max_hops = max_hops
154
+ self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_pass))
155
+
156
+ def graph_score(self, candidate_id: str, seed_ids: list[str]) -> float:
157
+ """
158
+ Uses built-in shortestPath() over HAS_SUBSECTION|NEXT_TO
159
+ with a literal max_hops in the relationship pattern.
160
+ """
161
+ rel_pat = f"[:HAS_SUBSECTION|NEXT_TO*..{self.max_hops}]"
162
+ query = f"""
163
+ MATCH p = shortestPath(
164
+ (seed:Chunk {{id: $seed_id}})-{rel_pat}-(cand:Chunk {{id: $cand_id}})
165
+ )
166
+ RETURN length(p) AS hops
167
+ """
168
+ min_hops = None
169
+ with self.driver.session() as sess:
170
+ for sid in seed_ids:
171
+ # Skip self, don’t count seed==candidate
172
+ if sid == candidate_id:
173
+ continue
174
+
175
+ rec = sess.run(
176
+ query,
177
+ {"seed_id": sid, "cand_id": candidate_id}
178
+ ).single()
179
+ if rec and rec["hops"] is not None:
180
+ h = rec["hops"]
181
+ if min_hops is None or h < min_hops:
182
+ min_hops = h
183
+
184
+ # If no path found to any OTHER seed, score is 0
185
+ return 0.0 if min_hops is None else 1.0 / (1 + min_hops)
186
+
187
+ def rerank(
188
+ self,
189
+ query: str,
190
+ top_k: int = 5,
191
+ municipality_filter: str | None = None,
192
+ alpha: float = 0.8,
193
+ threshold: float = 0.3
194
+ ) -> list[dict]:
195
+ """
196
+ 1. Pull a broader set of text-only candidates
197
+ 2. Compute graph_score for each
198
+ 3. Blend and return the top_k final results
199
+ """
200
+ raw_cands = self.base.retrieve(
201
+ query,
202
+ top_k=top_k * 5, # more candidates than top_k to give material to the graph
203
+ municipality_filter=municipality_filter,
204
+ alpha=alpha,
205
+ threshold=threshold
206
+ )
207
+
208
+ # extract seed IDs from the first top_k (best text‐only hits)
209
+ seed_ids = [c["id"] for c in raw_cands[:top_k]]
210
+
211
+ # compute graph_score for each candidate
212
+ enriched = []
213
+ for cand in raw_cands:
214
+ g = self.graph_score(cand["id"], seed_ids)
215
+ basic = cand.get("hybrid_score",
216
+ alpha * cand["dense_score"] + (1-alpha) * cand["sparse_score"])
217
+ final = basic + self.beta * g
218
+ enriched.append({ **cand,
219
+ "graph_score": g,
220
+ "final_score": final })
221
+
222
+ return sorted(enriched, key=lambda x: x["final_score"], reverse=True)[:top_k]