Upload folder using huggingface_hub
- .DS_Store +0 -0
- .env.example +2 -1
- .gitignore +3 -1
- README.md +20 -1
- backend/colpali.py +25 -7
- frontend/app.py +46 -13
- frontend/layout.py +95 -14
- globals.css +22 -1
- main.py +106 -19
- output.css +77 -1
- prepare_feed_deploy.py +977 -0
- pyproject.toml +10 -1
- static/.DS_Store +0 -0
- uv.lock +0 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
.env.example CHANGED
@@ -1,3 +1,4 @@
 VESPA_APP_URL=https://abcde.z.vespa-app.cloud
 HF_TOKEN=hf_xxxxxxxxxx
-VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
+VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
+GEMINI_API_KEY=
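For context, a minimal sketch of how these variables get picked up, assuming `python-dotenv` is used the way `prepare_feed_deploy.py` uses it (the variable names come from the file above; everything else is illustrative):

```python
import os

from dotenv import load_dotenv

# Populate os.environ from .env; variables already set in the shell take precedence.
load_dotenv()

vespa_app_url = os.getenv("VESPA_APP_URL")
vespa_cloud_secret_token = os.getenv("VESPA_CLOUD_SECRET_TOKEN")
gemini_api_key = os.getenv("GEMINI_API_KEY")  # new key added in this commit
```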
.gitignore CHANGED
@@ -1,8 +1,10 @@
 .sesskey
 .venv/
 __pycache__/
+ipynb_checkpoints/
 .python-version
 .env
 template/
 *.json
-output/
+output/
+pdfs/
README.md CHANGED
@@ -27,7 +27,7 @@ preload_from_hub:
 
 # Visual Retrieval ColPali
 
-#
+# Prepare data and Vespa application
 
 First, install `uv`:
 
@@ -35,6 +35,25 @@ First, install `uv`:
 curl -LsSf https://astral.sh/uv/install.sh | sh
 ```
 
+Then, run:
+
+```bash
+uv sync --extra dev --extra feed
+```
+
+Convert `prepare_feed_deploy.py` to a notebook with:
+
+```bash
+jupytext --to notebook prepare_feed_deploy.py
+```
+
+Then launch a Jupyter instance; see https://docs.astral.sh/uv/guides/integration/jupyter/ for the recommended approach.
+
+Open and follow the `prepare_feed_deploy.ipynb` notebook to prepare the data and deploy the Vespa application.
+
+# Developing on the web app
+
+
 Then, in this directory, run:
 
 ```bash
backend/colpali.py CHANGED

@@ -170,13 +170,13 @@ def gen_similarity_maps(
     if vespa_sim_maps:
         print("Using provided similarity maps")
         # A sim map looks like this:
-        # "
+        # "quantized": [
         #   {
         #     "address": {
         #       "patch": "0",
         #       "querytoken": "0"
         #     },
-        #     "value":
+        #     "value": 12, # score in range [-128, 127]
         #   },
         # ... and so on.
         # Now turn these into a tensor of same shape as previous similarity map

@@ -189,7 +189,7 @@ def gen_similarity_maps(
             )
         )
         for idx, vespa_sim_map in enumerate(vespa_sim_maps):
-            for cell in vespa_sim_map["
+            for cell in vespa_sim_map["quantized"]["cells"]:
                 patch = int(cell["address"]["patch"])
                 # if dummy model then just use 1024 as the image_seq_length

@@ -359,7 +359,7 @@ async def query_vespa_default(
         start = time.perf_counter()
         response: VespaQueryResponse = await session.query(
             body={
-                "yql": "select id,title,url,
+                "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                 "ranking": "default",
                 "query": query,
                 "timeout": timeout,

@@ -392,7 +392,7 @@ async def query_vespa_bm25(
         start = time.perf_counter()
         response: VespaQueryResponse = await session.query(
             body={
-                "yql": "select id,title,url,
+                "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                 "ranking": "bm25",
                 "query": query,
                 "timeout": timeout,

@@ -472,7 +472,7 @@ async def query_vespa_nearest_neighbor(
             **query_tensors,
             "presentation.timing": True,
             # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
-            "yql": f"select id,title,snippet,text,url,
+            "yql": f"select id,title,snippet,text,url,blur_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
             "ranking.profile": "retrieval-and-rerank",
             "timeout": timeout,
             "hits": hits,

@@ -492,6 +492,24 @@ def is_special_token(token: str) -> bool:
             return True
     return False
 
+async def get_full_image_from_vespa(
+    app: Vespa,
+    id: str) -> str:
+    async with app.asyncio(connections=1, total_timeout=120) as session:
+        start = time.perf_counter()
+        response: VespaQueryResponse = await session.query(
+            body={
+                "yql": f"select full_image from pdf_page where id contains \"{id}\"",
+                "ranking": "unranked",
+                "presentation.timing": True,
+            },
+        )
+        assert response.is_successful(), response.json
+        stop = time.perf_counter()
+        print(
+            f"Getting image from Vespa took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
+        )
+        return response.json["root"]["children"][0]["fields"]["full_image"]
 
 async def get_result_from_query(
     app: Vespa,

@@ -538,7 +556,7 @@ def add_sim_maps_to_result(
     imgs: List[str] = []
     vespa_sim_maps: List[str] = []
     for single_result in result["root"]["children"]:
-        img = single_result["fields"]["
+        img = single_result["fields"]["blur_image"]
         if img:
             imgs.append(img)
         vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
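As a quick orientation for the new `get_full_image_from_vespa` helper above, here is a rough, standalone usage sketch. It assumes `get_vespa_app()` returns a configured `Vespa` client (as it does in `main.py`) and that the document id exists; the id value used here is purely hypothetical:

```python
import asyncio

from backend.colpali import get_full_image_from_vespa
from backend.vespa_app import get_vespa_app


async def main() -> None:
    app = get_vespa_app()
    # Hypothetical id; real ids are produced at feed time in prepare_feed_deploy.py.
    b64_image = await get_full_image_from_vespa(app, "some-document-id")
    print(f"Received {len(b64_image)} base64 characters of full-resolution image")


if __name__ == "__main__":
    asyncio.run(main())
```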
frontend/app.py CHANGED

@@ -131,9 +131,13 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
 
 def SampleQueries():
     sample_queries = [
-        "
-        "
-        "
+        "Proportion of female new hires 2021-2023?",
+        "Total amount of performance-based pay awarded in 2023?",
+        "What is the percentage distribution of employees with performance-based pay relative to the limit in 2023?",
+        "What is the breakdown of management costs by investment strategy in 2023?",
+        "2023 profit loss portfolio",
+        "net cash flow operating activities",
+        "fund currency basket returns",
     ]
 
     query_badges = []

@@ -193,21 +197,23 @@ def Search(request, search_results=[]):
     )
     return Div(
         Div(
-            SearchBox(query_value=query_value, ranking_value=ranking_value),
             Div(
-
-
+                SearchBox(query_value=query_value, ranking_value=ranking_value),
+                Div(
+                    LoadingMessage(),
+                    id="search-results",  # This will be replaced by the search results
+                ),
+                cls="grid",
             ),
             cls="grid",
         ),
-        cls="grid",
     )
 
 
-def LoadingMessage():
+def LoadingMessage(display_text="Retrieving search results"):
     return Div(
         Lucide(icon="loader-circle", cls="size-5 mr-1.5 animate-spin"),
-        Span(
+        Span(display_text, cls="text-base text-center"),
         cls="p-10 text-muted-foreground flex items-center justify-center",
         id="loading-indicator",
     )

@@ -250,7 +256,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
     result_items = []
     for idx, result in enumerate(results):
         fields = result["fields"]  # Extract the 'fields' part of each result
-
+        blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
 
         # Filter sim_map fields that are words with 4 or more characters
         sim_map_fields = {

@@ -286,7 +292,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                 "Reset",
                 variant="outline",
                 size="sm",
-                data_image_src=
+                data_image_src=blur_image_base64,
                 cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
             )
 
@@ -312,7 +318,11 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                 Div(
                     Div(
                         Img(
-                            src=
+                            src=blur_image_base64,
+                            hx_get=f"/full_image?id={fields['id']}",
+                            style="filter: blur(5px);",
+                            hx_trigger="load",
+                            hx_swap="outerHTML",
                             alt=fields["title"],
                             cls="result-image w-full h-full object-contain",
                         ),

@@ -350,12 +360,35 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                 ),
                 cls="bg-background px-3 py-5 hidden md:block",
             ),
-            cls="grid grid-cols-1 md:grid-cols-2 col-span-2",
+            cls="grid grid-cols-1 md:grid-cols-2 col-span-2 border-t",
         )
     )
+
     return Div(
         *result_items,
         image_swapping,
         id="search-results",
         cls="grid grid-cols-2 gap-px bg-border",
     )
+
+
+def ChatResult(query_id: str, query: str):
+    return Div(
+        Div("Chat", cls="text-xl font-semibold p-3"),
+        Div(
+            Div(
+                Div(
+                    LoadingMessage(display_text="Waiting for response..."),
+                    cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
+                    hx_ext="sse",
+                    sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
+                    sse_swap="message",
+                    sse_close="close",
+                    hx_swap="innerHTML",
+                ),
+            ),
+            id="chat-messages",
+            cls="overflow-auto min-h-0 grid items-end px-3",
+        ),
+        cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
+    )
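The key pattern in the `Img` change above is a blurred base64 placeholder that htmx swaps for the full-quality image once the element loads. A minimal sketch of that element in isolation, using the same attributes as the diff (the function name and field values are placeholders):

```python
from fasthtml.components import Img


def blurred_placeholder(doc_id: str, blur_b64: str, title: str):
    # Render the low-resolution blurred image immediately, then let htmx fetch
    # /full_image?id=... on load and replace the whole element with the response.
    return Img(
        src=f"data:image/jpeg;base64,{blur_b64}",
        hx_get=f"/full_image?id={doc_id}",
        hx_trigger="load",
        hx_swap="outerHTML",
        style="filter: blur(5px);",
        alt=title,
        cls="result-image w-full h-full object-contain",
    )
```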
frontend/layout.py CHANGED

@@ -1,15 +1,96 @@
-from fasthtml.components import Div, Img, Nav, Title
-from fasthtml.xtend import A
+from fasthtml.components import Body, Div, Header, Img, Nav, Title
+from fasthtml.xtend import A, Script
 from lucide_fasthtml import Lucide
 from shad4fast import Button, Separator
 
+script = Script(
+    """
+    document.addEventListener("DOMContentLoaded", function () {
+      const main = document.querySelector('main');
+      const aside = document.querySelector('aside');
+      const body = document.body;
+
+      if (main && aside && main.nextElementSibling === aside) {
+        // Main + Aside layout
+        body.classList.add('grid-cols-[minmax(0,_4fr)_minmax(0,_1fr)]');
+        aside.classList.remove('hidden');
+      } else if (main) {
+        // Only Main layout (full width)
+        body.classList.add('grid-cols-[1fr]');
+      }
+    });
+    """
+)
+
+overlay_scrollbars = Script(
+    """
+    (function () {
+      const { OverlayScrollbars } = OverlayScrollbarsGlobal;
+
+      function getPreferredTheme() {
+        return localStorage.theme === 'dark' || (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: dark)').matches)
+          ? 'dark'
+          : 'light';
+      }
+
+      function applyOverlayScrollbars(element, scrollbarTheme) {
+        // Destroy existing OverlayScrollbars instance if it exists
+        const instance = OverlayScrollbars(element);
+        if (instance) {
+          instance.destroy();
+        }
+
+        // Reinitialize OverlayScrollbars with the new theme
+        OverlayScrollbars(element, {
+          scrollbars: {
+            theme: scrollbarTheme,
+            visibility: 'auto',
+            autoHide: 'leave',
+            autoHideDelay: 800
+          }
+        });
+      }
+
+      function updateScrollbarTheme() {
+        const isDarkMode = getPreferredTheme() === 'dark';
+        const scrollbarTheme = isDarkMode ? 'os-theme-light' : 'os-theme-dark'; // Light theme in dark mode, dark theme in light mode
+
+        const mainElement = document.querySelector('main');
+        const chatMessagesElement = document.querySelector('#chat-messages'); // Select the chat message container by ID
+
+        if (mainElement) {
+          applyOverlayScrollbars(mainElement, scrollbarTheme);
+        }
+
+        if (chatMessagesElement) {
+          applyOverlayScrollbars(chatMessagesElement, scrollbarTheme);
+        }
+      }
+
+      // Apply the correct theme immediately when the page loads
+      updateScrollbarTheme();
+
+      // Observe changes in the 'dark' class on the <html> element
+      const observer = new MutationObserver(updateScrollbarTheme);
+      observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] });
+    })();
+    """
+)
+
 
 def Logo():
     return Div(
-        Img(
-
-
-
+        Img(
+            src="https://assets.vespa.ai/logos/vespa-logo-black.svg",
+            alt="Vespa Logo",
+            cls="h-full dark:hidden",
+        ),
+        Img(
+            src="https://assets.vespa.ai/logos/vespa-logo-white.svg",
+            alt="Vespa Logo Dark Mode",
+            cls="h-full hidden dark:block",
+        ),
+        cls="h-[27px]",
     )
 

@@ -38,23 +119,23 @@ def Links():
         ),
         Separator(orientation="vertical"),
         ThemeToggle(),
-        cls=
+        cls="flex items-center space-x-3",
     )
 
 
 def Layout(*c, **kwargs):
     return (
-        Title(
+        Title("Visual Retrieval ColPali"),
         Body(
             Header(
                 A(Logo(), href="/"),
                 Links(),
-                cls=
-            ),
-            Main(
-                *c, **kwargs,
-                cls='flex-1 h-full'
+                cls="min-h-[55px] h-[55px] w-full flex items-center justify-between px-4",
             ),
-
+            *c,
+            **kwargs,
+            cls="grid grid-rows-[55px_1fr] min-h-0",
         ),
+        script,
+        overlay_scrollbars,
    )
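The `DOMContentLoaded` script above picks the body grid based on whether `Layout` receives only a `Main` child or a `Main` plus an `Aside`. A rough sketch of the two call shapes, loosely mirroring the routes in `main.py` (the placeholder `Div` contents and the explicit imports from `fasthtml.common` are assumptions):

```python
from fasthtml.common import Aside, Div, Main

from frontend.app import Home
from frontend.layout import Layout

# Full-width page: the script adds grid-cols-[1fr] to <body>.
home_page = Layout(Main(Home()))

# Main + Aside page: the script applies the minmax(0,_4fr)/minmax(0,_1fr) split
# and removes the `hidden` class from the <aside>.
search_page = Layout(
    Main(Div("search results go here"), cls="border-t"),
    Aside(Div("chat panel goes here"), cls="border-t border-l"),
)
```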
globals.css CHANGED

@@ -183,4 +183,25 @@
   width: 100%;
   height: 100%;
   z-index: 10;
-}
+}
+
+header {
+  grid-column: 1/-1;
+}
+
+main {
+  overflow: auto;
+}
+
+aside {
+  overflow: auto;
+}
+
+.scroll-container {
+  padding-right: 10px;
+}
+
+.question-message {
+  background-color: #61D790;
+  color: #2E2F27;
+}
main.py CHANGED

@@ -1,22 +1,25 @@
 import asyncio
+import hashlib
+import time
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
 from fasthtml.common import *
 from shad4fast import *
 from vespa.application import Vespa
-import time
 
+from backend.cache import LRUCache
 from backend.colpali import (
-    get_result_from_query,
-    get_query_embeddings_and_token_map,
     add_sim_maps_to_result,
+    get_query_embeddings_and_token_map,
+    get_result_from_query,
     is_special_token,
+    get_full_image_from_vespa,
 )
-from backend.vespa_app import get_vespa_app
-from backend.cache import LRUCache
 from backend.modelmanager import ModelManager
+from backend.vespa_app import get_vespa_app
 from frontend.app import (
+    ChatResult,
     Home,
     Search,
     SearchBox,

@@ -25,7 +28,10 @@ from frontend.app import (
     SimMapButtonReady,
 )
 from frontend.layout import Layout
-import
+import google.generativeai as genai
+from PIL import Image
+import io
+import base64
 
 highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
 highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")

@@ -35,15 +41,27 @@ highlight_js = HighlightJS(
     light="github",
 )
 
+overlayscrollbars_link = Link(
+    rel="stylesheet",
+    href="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/styles/overlayscrollbars.min.css",
+    type="text/css",
+)
+overlayscrollbars_js = Script(
+    src="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/browser/overlayscrollbars.browser.es5.min.js"
+)
+sselink = Script(src="https://unpkg.com/[email protected]/sse.js")
 
 app, rt = fast_app(
-    htmlkw={"cls": "h-full"},
+    htmlkw={"cls": "grid h-full"},
     pico=False,
     hdrs=(
         ShadHead(tw_cdn=False, theme_handle=True),
         highlight_js,
         highlight_js_theme_link,
         highlight_js_theme,
+        overlayscrollbars_link,
+        overlayscrollbars_js,
+        sselink,
     ),
 )
 vespa_app: Vespa = get_vespa_app()

@@ -53,6 +71,16 @@ task_cache = LRUCache(
     max_size=1000
 )  # Map from query_id to boolean value - False if not all results are ready.
 thread_pool = ThreadPoolExecutor()
+# Gemini config
+
+genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
+GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
+If the user query is not an obvious question, reply with 'No question detected.'. Your response should be HTML formatted.
+This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
+"""
+gemini_model = genai.GenerativeModel(
+    "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
+)
 
 
 @app.on_event("startup")

@@ -72,7 +100,7 @@ def serve_static(filepath: str):
 
 @rt("/")
 def get():
-    return Layout(Home())
+    return Layout(Main(Home()))
 
 
 @rt("/search")

@@ -86,16 +114,18 @@ def get(request):
     if not query_value:
         # Show SearchBox and a message for missing query
         return Layout(
-
-            SearchBox(query_value=query_value, ranking_value=ranking_value),
+            Main(
                 Div(
-
-
-
+                    SearchBox(query_value=query_value, ranking_value=ranking_value),
+                    Div(
+                        P(
+                            "No query provided. Please enter a query.",
+                            cls="text-center text-muted-foreground",
+                        ),
+                        cls="p-10",
                     ),
-            cls="
-            )
-            cls="grid",
+                    cls="grid",
+                )
             )
         )
     # Generate a unique query_id based on the query and ranking value

@@ -107,7 +137,12 @@ def get(request):
     # search_results = get_results_children(result)
     # return Layout(Search(request, search_results))
     # Show the loading message if a query is provided
-    return Layout(
+    return Layout(
+        Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
+        Aside(
+            ChatResult(query_id=query_id, query=query_value), cls="border-t border-l"
+        ),
+    )  # Show SearchBox and Loading message initially
 
 
 @rt("/fetch_results")

@@ -215,15 +250,67 @@ async def get_sim_map(query_id: str, idx: int, token: str):
     sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
     if sim_map_b64 is None:
         return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
-    sim_map_img_src = f"data:image/
+    sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
     return SimMapButtonReady(
         query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
     )
 
 
+@app.get("/full_image")
+async def full_image(id: str):
+    """
+    Endpoint to get the full quality image for a given result id.
+    """
+    image_data = await get_full_image_from_vespa(vespa_app, id)
+
+    # Decode the base64 image data
+    # image_data = base64.b64decode(image_data)
+    image_data = "data:image/jpeg;base64," + image_data
+
+    return Img(
+        src=image_data,
+        alt="something",
+        cls="result-image w-full h-full object-contain",
+    )
+
+
+async def message_generator(query_id: str, query: str):
+    result = None
+    while result is None:
+        result = result_cache.get(query_id)
+        await asyncio.sleep(0.5)
+    search_results = get_results_children(result)
+    images = [result["fields"]["blur_image"] for result in search_results]
+    # from b64 to PIL image
+    images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
+
+    # If newlines are present in the response, the connection will be closed.
+    def replace_newline_with_br(text):
+        return text.replace("\n", "<br>")
+
+    response_text = ""
+    async for chunk in await gemini_model.generate_content_async(
+        images + ["\n\n Query: ", query], stream=True
+    ):
+        if chunk.text:
+            response_text += chunk.text
+            response_text = replace_newline_with_br(response_text)
+            yield f"event: message\ndata: {response_text}\n\n"
+            await asyncio.sleep(0.5)
+    yield "event: close\ndata: \n\n"
+
+
+@app.get("/get-message")
+async def get_message(query_id: str, query: str):
+    return StreamingResponse(
+        message_generator(query_id=query_id, query=query),
+        media_type="text/event-stream",
+    )
+
+
 @rt("/app")
 def get():
-    return Layout(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4"))
+    return Layout(Main(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4")))
 
 
 if __name__ == "__main__":
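To make the streaming contract of `/get-message` above concrete: it yields named server-sent events (repeated `message` frames carrying the accumulating HTML answer, then a final `close` frame) that the `sse_swap="message"` / `sse_close="close"` attributes in `ChatResult` consume. A rough way to watch the raw stream during development, assuming the app is running locally on port 5001 and a matching `query_id` is already in the cache (host, port, and parameter values here are assumptions):

```python
import httpx

# Hypothetical local URL and parameters; query_id must match a cached query.
url = "http://localhost:5001/get-message"
params = {"query_id": "some-query-id", "query": "net cash flow operating activities"}

with httpx.stream("GET", url, params=params, timeout=None) as response:
    for line in response.iter_lines():
        # Expect frames like "event: message", "data: <html fragment>", a blank
        # separator line, and finally "event: close".
        print(line)
```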
output.css CHANGED

@@ -927,6 +927,10 @@ body {
   max-height: 100vh;
 }
 
+.min-h-0 {
+  min-height: 0px;
+}
+
 .min-h-\[55px\] {
   min-height: 55px;
 }

@@ -1096,6 +1100,22 @@ body {
   grid-template-columns: repeat(2, minmax(0, 1fr));
 }
 
+.grid-cols-\[1fr\] {
+  grid-template-columns: 1fr;
+}
+
+.grid-cols-\[minmax\(0\2c _4fr\)_minmax\(0\2c _1fr\)\] {
+  grid-template-columns: minmax(0, 4fr) minmax(0, 1fr);
+}
+
+.grid-rows-\[55px_1fr\] {
+  grid-template-rows: 55px 1fr;
+}
+
+.grid-rows-\[auto_1fr_auto\] {
+  grid-template-rows: auto 1fr auto;
+}
+
 .flex-col {
   flex-direction: column;
 }

@@ -1112,10 +1132,18 @@ body {
   align-content: flex-start;
 }
 
+.items-end {
+  align-items: flex-end;
+}
+
 .items-center {
   align-items: center;
 }
 
+.justify-end {
+  justify-content: flex-end;
+}
+
 .justify-center {
   justify-content: center;
 }

@@ -1136,6 +1164,10 @@ body {
   gap: 0.5rem;
 }
 
+.gap-3 {
+  gap: 0.75rem;
+}
+
 .gap-4 {
   gap: 1rem;
 }

@@ -1200,6 +1232,10 @@ body {
   margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
 }
 
+.self-end {
+  align-self: flex-end;
+}
+
 .self-stretch {
   align-self: stretch;
 }

@@ -1252,6 +1288,11 @@ body {
   border-width: 2px;
 }
 
+.border-x {
+  border-left-width: 1px;
+  border-right-width: 1px;
+}
+
 .border-b {
   border-bottom-width: 1px;
 }

@@ -1493,6 +1534,10 @@ body {
   padding-top: 1rem;
 }
 
+.pr-3 {
+  padding-right: 0.75rem;
+}
+
 .text-left {
   text-align: left;
 }

@@ -1577,6 +1622,11 @@ body {
   letter-spacing: 0.025em;
 }
 
+.text-black {
+  --tw-text-opacity: 1;
+  color: rgb(0 0 0 / var(--tw-text-opacity));
+}
+
 .text-card-foreground {
   color: hsl(var(--card-foreground));
 }

@@ -1993,6 +2043,27 @@ body {
   z-index: 10;
 }
 
+header {
+  grid-column: 1/-1;
+}
+
+main {
+  overflow: auto;
+}
+
+aside {
+  overflow: auto;
+}
+
+.scroll-container {
+  padding-right: 10px;
+}
+
+.question-message {
+  background-color: #61D790;
+  color: #2E2F27;
+}
+
 :root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
   overflow: hidden;
 }

@@ -2537,6 +2608,11 @@ body {
   --tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
 }
 
+.dark\:text-white:where(.dark, .dark *) {
+  --tw-text-opacity: 1;
+  color: rgb(255 255 255 / var(--tw-text-opacity));
+}
+
 .dark\:hover\:border-white:hover:where(.dark, .dark *) {
   --tw-border-opacity: 1;
   border-color: rgb(255 255 255 / var(--tw-border-opacity));

@@ -2610,4 +2686,4 @@ body {
 
 .\[\&_tr\]\:border-b tr {
   border-bottom-width: 1px;
-}
+}
prepare_feed_deploy.py
ADDED
|
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # Visual PDF Retrieval - demo application
|
| 3 |
+
#
|
| 4 |
+
# In this notebook, we will prepare the Vespa backend application for our visual retrieval demo.
|
| 5 |
+
# We will use ColPali as the model to extract patch vectors from images of pdf pages.
|
| 6 |
+
# At query time, we use MaxSim to retrieve and/or (based on the configuration) rank the page results.
|
| 7 |
+
#
|
| 8 |
+
# To see the application in action, visit TODO:
|
| 9 |
+
#
|
| 10 |
+
# The web application is written in FastHTML, meaning the complete application is written in python.
|
| 11 |
+
#
|
| 12 |
+
# The steps we will take in this notebook are:
|
| 13 |
+
#
|
| 14 |
+
# 0. Setup and configuration
|
| 15 |
+
# 1. Download the data
|
| 16 |
+
# 2. Prepare the data
|
| 17 |
+
# 3. Generate queries for evaluation and typeahead search suggestions
|
| 18 |
+
# 4. Deploy the Vespa application
|
| 19 |
+
# 5. Create the Vespa application
|
| 20 |
+
# 6. Feed the data to the Vespa application
|
| 21 |
+
#
|
| 22 |
+
# All the steps that are needed to provision the Vespa application, including feeding the data, can be done from this notebook.
|
| 23 |
+
# We have tried to make it easy for others to run this notebook, to create your own PDF Enterprise Search application using Vespa.
|
| 24 |
+
#
|
| 25 |
+
|
| 26 |
+
# %% [markdown]
|
| 27 |
+
# ## 0. Setup and Configuration
|
| 28 |
+
#
|
| 29 |
+
|
| 30 |
+
# %%
|
| 31 |
+
import os
|
| 32 |
+
import asyncio
|
| 33 |
+
import json
|
| 34 |
+
from typing import Tuple
|
| 35 |
+
import hashlib
|
| 36 |
+
import numpy as np
|
| 37 |
+
|
| 38 |
+
# Vespa
|
| 39 |
+
from vespa.package import (
|
| 40 |
+
ApplicationPackage,
|
| 41 |
+
Field,
|
| 42 |
+
Schema,
|
| 43 |
+
Document,
|
| 44 |
+
HNSW,
|
| 45 |
+
RankProfile,
|
| 46 |
+
Function,
|
| 47 |
+
FieldSet,
|
| 48 |
+
SecondPhaseRanking,
|
| 49 |
+
Summary,
|
| 50 |
+
DocumentSummary,
|
| 51 |
+
)
|
| 52 |
+
from vespa.deployment import VespaCloud
|
| 53 |
+
from vespa.application import Vespa
|
| 54 |
+
from vespa.io import VespaResponse
|
| 55 |
+
|
| 56 |
+
# Google Generative AI
|
| 57 |
+
import google.generativeai as genai
|
| 58 |
+
|
| 59 |
+
# Torch and other ML libraries
|
| 60 |
+
import torch
|
| 61 |
+
from torch.utils.data import DataLoader
|
| 62 |
+
from tqdm import tqdm
|
| 63 |
+
from pdf2image import convert_from_path
|
| 64 |
+
from pypdf import PdfReader
|
| 65 |
+
|
| 66 |
+
# ColPali model and processor
|
| 67 |
+
from colpali_engine.models import ColPali, ColPaliProcessor
|
| 68 |
+
from colpali_engine.utils.torch_utils import get_torch_device
|
| 69 |
+
from vidore_benchmark.utils.image_utils import scale_image, get_base64_image
|
| 70 |
+
|
| 71 |
+
# Other utilities
|
| 72 |
+
from bs4 import BeautifulSoup
|
| 73 |
+
import httpx
|
| 74 |
+
from urllib.parse import urljoin, urlparse
|
| 75 |
+
|
| 76 |
+
# Load environment variables
|
| 77 |
+
from dotenv import load_dotenv
|
| 78 |
+
|
| 79 |
+
load_dotenv()
|
| 80 |
+
|
| 81 |
+
# Avoid warning from huggingface tokenizers
|
| 82 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 83 |
+
|
| 84 |
+
# %% [markdown]
|
| 85 |
+
# ### Create a free trial in Vespa Cloud
|
| 86 |
+
#
|
| 87 |
+
# Create a tenant from [here](https://vespa.ai/free-trial/).
|
| 88 |
+
# The trial includes $300 credit.
|
| 89 |
+
# Take note of your tenant name.
|
| 90 |
+
#
|
| 91 |
+
|
| 92 |
+
# %%
|
| 93 |
+
VESPA_TENANT_NAME = "vespa-team"
|
| 94 |
+
|
| 95 |
+
# %% [markdown]
|
| 96 |
+
# Here, set your desired application name. (Will be created in later steps)
|
| 97 |
+
# Note that you can not have hyphen `-` or underscore `_` in the application name.
|
| 98 |
+
#
|
| 99 |
+
|
| 100 |
+
# %%
|
| 101 |
+
VESPA_APPLICATION_NAME = "colpalidemo2"
|
| 102 |
+
VESPA_SCHEMA_NAME = "pdf_page"
|
| 103 |
+
|
| 104 |
+
# %% [markdown]
|
| 105 |
+
# Next, you need to create some tokens for feeding data, and querying the application.
|
| 106 |
+
# We recommend separate tokens for feeding and querying, (the former with write permission, and the latter with read permission).
|
| 107 |
+
# The tokens can be created from the [Vespa Cloud console](https://console.vespa-cloud.com/) in the 'Account' -> 'Tokens' section.
|
| 108 |
+
#
|
| 109 |
+
|
| 110 |
+
# %%
|
| 111 |
+
VESPA_TOKEN_ID_WRITE = "colpalidemo_write"
|
| 112 |
+
VESPA_TOKEN_ID_READ = "colpalidemo_read"
|
| 113 |
+
|
| 114 |
+
# %% [markdown]
|
| 115 |
+
# We also need to set the value of the write token to be able to feed data to the Vespa application.
|
| 116 |
+
#
|
| 117 |
+
|
| 118 |
+
# %%
|
| 119 |
+
VESPA_CLOUD_SECRET_TOKEN = os.getenv("VESPA_CLOUD_SECRET_TOKEN") or input(
|
| 120 |
+
"Enter Vespa cloud secret token: "
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# %% [markdown]
|
| 124 |
+
# We will also use the Gemini API to create sample queries for our images.
|
| 125 |
+
# You can also use other VLM's to create these queries.
|
| 126 |
+
# Create a Gemini API key from [here](https://aistudio.google.com/app/apikey).
|
| 127 |
+
#
|
| 128 |
+
|
| 129 |
+
# %%
|
| 130 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or input(
|
| 131 |
+
"Enter Google Generative AI API key: "
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# %%
|
| 135 |
+
MODEL_NAME = "vidore/colpali-v1.2"
|
| 136 |
+
|
| 137 |
+
# Configure Google Generative AI
|
| 138 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
| 139 |
+
|
| 140 |
+
# Set device for Torch
|
| 141 |
+
device = get_torch_device("auto")
|
| 142 |
+
print(f"Using device: {device}")
|
| 143 |
+
|
| 144 |
+
# Load the ColPali model and processor
|
| 145 |
+
model = ColPali.from_pretrained(
|
| 146 |
+
MODEL_NAME,
|
| 147 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 148 |
+
device_map=device,
|
| 149 |
+
).eval()
|
| 150 |
+
|
| 151 |
+
processor = ColPaliProcessor.from_pretrained(MODEL_NAME)
|
| 152 |
+
|
| 153 |
+
# %% [markdown]
|
| 154 |
+
# ## 1. Download PDFs
|
| 155 |
+
#
|
| 156 |
+
# We are going to use public reports from the Norwegian Government Pension Fund Global (also known as the Oil Fund).
|
| 157 |
+
# The fund puts transparency at the forefront and publishes reports on its investments, holdings, and returns, as well as its strategy and governance.
|
| 158 |
+
#
|
| 159 |
+
# These reports are the ones we are going to use for this showcase.
|
| 160 |
+
# Here are some sample images:
|
| 161 |
+
#
|
| 162 |
+
# 
|
| 163 |
+
# 
|
| 164 |
+
#
|
| 165 |
+
|
| 166 |
+
# %% [markdown]
|
| 167 |
+
# As we can see, a lot of the information is in the form of tables, charts and numbers.
|
| 168 |
+
# These are not easily extractable using pdf-readers or OCR tools.
|
| 169 |
+
#
|
| 170 |
+
|
| 171 |
+
# %%
|
| 172 |
+
import requests
|
| 173 |
+
|
| 174 |
+
url = "https://www.nbim.no/en/publications/reports/"
|
| 175 |
+
response = requests.get(url)
|
| 176 |
+
response.raise_for_status()
|
| 177 |
+
html_content = response.text
|
| 178 |
+
|
| 179 |
+
# Parse with BeautifulSoup
|
| 180 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
| 181 |
+
|
| 182 |
+
links = []
|
| 183 |
+
|
| 184 |
+
# Find all <a> elements with the specific classes
|
| 185 |
+
for a_tag in soup.find_all("a", href=True):
|
| 186 |
+
classes = a_tag.get("class", [])
|
| 187 |
+
if "button" in classes and "button--download-secondary" in classes:
|
| 188 |
+
href = a_tag["href"]
|
| 189 |
+
full_url = urljoin(url, href)
|
| 190 |
+
links.append(full_url)
|
| 191 |
+
|
| 192 |
+
links
|
| 193 |
+
|
| 194 |
+
# %%
|
| 195 |
+
# Limit the number of PDFs to download
|
| 196 |
+
NUM_PDFS = 2 # Set to None to download all PDFs
|
| 197 |
+
links = links[:NUM_PDFS] if NUM_PDFS else links
|
| 198 |
+
links
|
| 199 |
+
|
| 200 |
+
# %%
|
| 201 |
+
from nest_asyncio import apply
|
| 202 |
+
from typing import List
|
| 203 |
+
|
| 204 |
+
apply()
|
| 205 |
+
|
| 206 |
+
max_attempts = 3
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
async def download_pdf(session, url, filename):
|
| 210 |
+
attempt = 0
|
| 211 |
+
while attempt < max_attempts:
|
| 212 |
+
try:
|
| 213 |
+
response = await session.get(url)
|
| 214 |
+
response.raise_for_status()
|
| 215 |
+
|
| 216 |
+
# Use Content-Disposition header to get the filename if available
|
| 217 |
+
content_disposition = response.headers.get("Content-Disposition")
|
| 218 |
+
if content_disposition:
|
| 219 |
+
import re
|
| 220 |
+
|
| 221 |
+
fname = re.findall('filename="(.+)"', content_disposition)
|
| 222 |
+
if fname:
|
| 223 |
+
filename = fname[0]
|
| 224 |
+
|
| 225 |
+
# Ensure the filename is safe to use on the filesystem
|
| 226 |
+
safe_filename = filename.replace("/", "_").replace("\\", "_")
|
| 227 |
+
if not safe_filename or safe_filename == "_":
|
| 228 |
+
print(f"Invalid filename: {filename}")
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
filepath = os.path.join("pdfs", safe_filename)
|
| 232 |
+
with open(filepath, "wb") as f:
|
| 233 |
+
f.write(response.content)
|
| 234 |
+
print(f"Downloaded {safe_filename}")
|
| 235 |
+
return filepath
|
| 236 |
+
except Exception as e:
|
| 237 |
+
print(f"Error downloading {filename}: {e}")
|
| 238 |
+
print(f"Retrying ({attempt})...")
|
| 239 |
+
await asyncio.sleep(1) # Wait a bit before retrying
|
| 240 |
+
attempt += 1
|
| 241 |
+
return None
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
async def download_pdfs(links: List[str]) -> List[dict]:
|
| 245 |
+
"""Download PDFs from a list of URLs. Add the filename to the dictionary."""
|
| 246 |
+
async with httpx.AsyncClient() as client:
|
| 247 |
+
tasks = []
|
| 248 |
+
|
| 249 |
+
for idx, link in enumerate(links):
|
| 250 |
+
# Try to get the filename from the URL
|
| 251 |
+
path = urlparse(link).path
|
| 252 |
+
filename = os.path.basename(path)
|
| 253 |
+
|
| 254 |
+
# If filename is empty,skip
|
| 255 |
+
if not filename:
|
| 256 |
+
continue
|
| 257 |
+
tasks.append(download_pdf(client, link, filename))
|
| 258 |
+
|
| 259 |
+
# Run the tasks concurrently
|
| 260 |
+
paths = await asyncio.gather(*tasks)
|
| 261 |
+
pdf_files = [
|
| 262 |
+
{"url": link, "path": path} for link, path in zip(links, paths) if path
|
| 263 |
+
]
|
| 264 |
+
return pdf_files
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Create the pdfs directory if it doesn't exist
|
| 268 |
+
os.makedirs("pdfs", exist_ok=True)
|
| 269 |
+
# Now run the download_pdfs function with the URL
|
| 270 |
+
pdfs = asyncio.run(download_pdfs(links))
|
| 271 |
+
|
| 272 |
+
# %%
|
| 273 |
+
pdfs
|
| 274 |
+
|
| 275 |
+
# %% [markdown]
|
| 276 |
+
# ## 2. Convert PDFs to Images
|
| 277 |
+
#
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# %%
|
| 281 |
+
def get_pdf_images(pdf_path):
|
| 282 |
+
reader = PdfReader(pdf_path)
|
| 283 |
+
page_texts = []
|
| 284 |
+
for page_number in range(len(reader.pages)):
|
| 285 |
+
page = reader.pages[page_number]
|
| 286 |
+
text = page.extract_text()
|
| 287 |
+
page_texts.append(text)
|
| 288 |
+
images = convert_from_path(pdf_path)
|
| 289 |
+
# Convert to PIL images
|
| 290 |
+
assert len(images) == len(page_texts)
|
| 291 |
+
return images, page_texts
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
pdf_folder = "pdfs"
|
| 295 |
+
pdf_pages = []
|
| 296 |
+
for pdf in tqdm(pdfs):
|
| 297 |
+
pdf_file = pdf["path"]
|
| 298 |
+
title = os.path.splitext(os.path.basename(pdf_file))[0]
|
| 299 |
+
images, texts = get_pdf_images(pdf_file)
|
| 300 |
+
for page_no, (image, text) in enumerate(zip(images, texts)):
|
| 301 |
+
pdf_pages.append(
|
| 302 |
+
{
|
| 303 |
+
"title": title,
|
| 304 |
+
"url": pdf["url"],
|
| 305 |
+
"path": pdf_file,
|
| 306 |
+
"image": image,
|
| 307 |
+
"text": text,
|
| 308 |
+
"page_no": page_no,
|
| 309 |
+
}
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# %%
|
| 313 |
+
len(pdf_pages)
|
| 314 |
+
|
| 315 |
+
# %%
|
| 316 |
+
from collections import Counter
|
| 317 |
+
|
| 318 |
+
# Print the length of the text fields - mean, max and min
|
| 319 |
+
text_lengths = [len(page["text"]) for page in pdf_pages]
|
| 320 |
+
print(f"Mean text length: {np.mean(text_lengths)}")
|
| 321 |
+
print(f"Max text length: {np.max(text_lengths)}")
|
| 322 |
+
print(f"Min text length: {np.min(text_lengths)}")
|
| 323 |
+
print(f"Median text length: {np.median(text_lengths)}")
|
| 324 |
+
print(f"Number of text with length == 0: {Counter(text_lengths)[0]}")
|
| 325 |
+
|
| 326 |
+
# %% [markdown]
|
| 327 |
+
# ## 3. Generate Queries
|
| 328 |
+
#
|
| 329 |
+
# In this step, we want to generate queries for each page image.
|
| 330 |
+
# These will be useful for 2 reasons:
|
| 331 |
+
#
|
| 332 |
+
# 1. We can use these queries as typeahead suggestions in the search bar.
|
| 333 |
+
# 2. We can use the queries to generate an evaluation dataset. See [Improving Retrieval with LLM-as-a-judge](https://blog.vespa.ai/improving-retrieval-with-llm-as-a-judge/) for a deeper dive into this topic.
|
| 334 |
+
#
|
| 335 |
+
# The prompt for generating queries is taken from [this](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html#an-update-retrieval-focused-prompt) wonderful blog post by Daniel van Strien.
|
| 336 |
+
#
|
| 337 |
+
# We will use the Gemini API to generate these queries, with `gemini-1.5-flash-8b` as the model.
|
| 338 |
+
#
|
| 339 |
+
|
| 340 |
+
# %%
|
| 341 |
+
from pydantic import BaseModel
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class GeneratedQueries(BaseModel):
|
| 345 |
+
broad_topical_question: str
|
| 346 |
+
broad_topical_query: str
|
| 347 |
+
specific_detail_question: str
|
| 348 |
+
specific_detail_query: str
|
| 349 |
+
visual_element_question: str
|
| 350 |
+
visual_element_query: str
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_retrieval_prompt() -> Tuple[str, GeneratedQueries]:
|
| 354 |
+
prompt = """You are an investor, stock analyst and financial expert. You will be presented an image of a document page from a report published by the Norwegian Government Pension Fund Global (GPFG). The report may be annual or quarterly reports, or policy reports, on topics such as responsible investment, risk etc.
|
| 357 |
+
Your task is to generate retrieval queries and questions that you would use to retrieve this document (or ask based on this document) in a large corpus.
|
| 358 |
+
Please generate 3 different types of retrieval queries and questions.
|
| 359 |
+
A retrieval query is a keyword based query, made up of 2-5 words, that you would type into a search engine to find this document.
|
| 360 |
+
A question is a natural language question that you would ask, for which the document contains the answer.
|
| 361 |
+
The queries should be of the following types:
|
| 362 |
+
1. A broad topical query: This should cover the main subject of the document.
|
| 363 |
+
2. A specific detail query: This should cover a specific detail or aspect of the document.
|
| 364 |
+
3. A visual element query: This should cover a visual element of the document, such as a chart, graph, or image.
|
| 365 |
+
|
| 366 |
+
Important guidelines:
|
| 367 |
+
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
|
| 368 |
+
- Use a fact-based natural language style for the questions.
|
| 369 |
+
- Frame the queries as if someone is searching for this document in a large corpus.
|
| 370 |
+
- Make the queries diverse and representative of different search strategies.
|
| 371 |
+
|
| 372 |
+
Format your response as a JSON object with the structure of the following example:
|
| 373 |
+
{
|
| 374 |
+
"broad_topical_question": "What was the Responsible Investment Policy in 2019?",
|
| 375 |
+
"broad_topical_query": "responsible investment policy 2019",
|
| 376 |
+
"specific_detail_question": "What is the percentage of investments in renewable energy?",
|
| 377 |
+
"specific_detail_query": "renewable energy investments percentage",
|
| 378 |
+
"visual_element_question": "What is the trend of total holding value over time?",
|
| 379 |
+
"visual_element_query": "total holding value trend"
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
If there are no relevant visual elements, provide an empty string for the visual element question and query.
|
| 383 |
+
Here is the document image to analyze:
|
| 384 |
+
Generate the queries based on this image and provide the response in the specified JSON format.
|
| 385 |
+
Only return JSON. Don't return any extra explanation text. """
|
| 386 |
+
|
| 387 |
+
return prompt, GeneratedQueries
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
prompt_text, pydantic_model = get_retrieval_prompt()
|
| 391 |
+
|
| 392 |
+
# %%
|
| 393 |
+
gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def generate_queries(image, prompt_text, pydantic_model):
|
| 397 |
+
try:
|
| 398 |
+
response = gemini_model.generate_content(
|
| 399 |
+
[image, "\n\n", prompt_text],
|
| 400 |
+
generation_config=genai.GenerationConfig(
|
| 401 |
+
response_mime_type="application/json",
|
| 402 |
+
response_schema=pydantic_model,
|
| 403 |
+
),
|
| 404 |
+
)
|
| 405 |
+
queries = json.loads(response.text)
|
| 406 |
+
except Exception as _e:
|
| 407 |
+
queries = {
|
| 408 |
+
"broad_topical_question": "",
|
| 409 |
+
"broad_topical_query": "",
|
| 410 |
+
"specific_detail_question": "",
|
| 411 |
+
"specific_detail_query": "",
|
| 412 |
+
"visual_element_question": "",
|
| 413 |
+
"visual_element_query": "",
|
| 414 |
+
}
|
| 415 |
+
return queries
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# %%
|
| 419 |
+
for pdf in tqdm(pdf_pages):
|
| 420 |
+
image = pdf.get("image")
|
| 421 |
+
pdf["queries"] = generate_queries(image, prompt_text, pydantic_model)
|
| 422 |
+
|
| 423 |
+
# %%
|
| 424 |
+
pdf_pages[46]["image"]
|
| 425 |
+
|
| 426 |
+
# %%
|
| 427 |
+
pdf_pages[46]["queries"]
|
| 428 |
+
|
| 429 |
+
# %%
|
| 430 |
+
# Generate queries async - kept for now, as we will probably need it when applying this to the full dataset
|
| 431 |
+
# import asyncio
|
| 432 |
+
# from tenacity import retry, stop_after_attempt, wait_exponential
|
| 433 |
+
# import google.generativeai as genai
|
| 434 |
+
# from tqdm.asyncio import tqdm_asyncio
|
| 435 |
+
|
| 436 |
+
# max_in_flight = 200 # Maximum number of concurrent requests
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# async def generate_queries_for_image_async(model, image, semaphore):
|
| 440 |
+
# @retry(stop=stop_after_attempt(3), wait=wait_exponential(), reraise=True)
|
| 441 |
+
# async def _generate():
|
| 442 |
+
# async with semaphore:
|
| 443 |
+
# result = await model.generate_content_async(
|
| 444 |
+
# [image, "\n\n", prompt_text],
|
| 445 |
+
# generation_config=genai.GenerationConfig(
|
| 446 |
+
# response_mime_type="application/json",
|
| 447 |
+
# response_schema=pydantic_model,
|
| 448 |
+
# ),
|
| 449 |
+
# )
|
| 450 |
+
# return json.loads(result.text)
|
| 451 |
+
|
| 452 |
+
# try:
|
| 453 |
+
# return await _generate()
|
| 454 |
+
# except Exception as e:
|
| 455 |
+
# print(f"Error generating queries for image: {e}")
|
| 456 |
+
# return None # Return None or handle as needed
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# async def enrich_pdfs():
|
| 460 |
+
# gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
|
| 461 |
+
# semaphore = asyncio.Semaphore(max_in_flight)
|
| 462 |
+
# tasks = []
|
| 463 |
+
# for pdf in pdf_pages:
|
| 464 |
+
# pdf["queries"] = []
|
| 465 |
+
# image = pdf.get("image")
|
| 466 |
+
# if image:
|
| 467 |
+
# task = generate_queries_for_image_async(gemini_model, image, semaphore)
|
| 468 |
+
# tasks.append((pdf, task))
|
| 469 |
+
|
| 470 |
+
# # Run the tasks concurrently using asyncio.gather()
|
| 471 |
+
# for pdf, task in tqdm_asyncio(tasks):
|
| 472 |
+
# result = await task
|
| 473 |
+
# if result:
|
| 474 |
+
# pdf["queries"] = result
|
| 475 |
+
# return pdf_pages
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# pdf_pages = asyncio.run(enrich_pdfs())
|
| 479 |
+
|
| 480 |
+
# %%
|
| 481 |
+
# Write title, url, page_no, text, and queries (but not the image) to JSON
|
| 482 |
+
with open("output/pdf_pages.json", "w") as f:
|
| 483 |
+
to_write = [{k: v for k, v in pdf.items() if k != "image"} for pdf in pdf_pages]
|
| 484 |
+
json.dump(to_write, f, indent=2)
|
| 485 |
+
|
| 486 |
+
# with open("pdfs/pdf_pages.json", "r") as f:
|
| 487 |
+
# saved_pdf_pages = json.load(f)
|
| 488 |
+
# for pdf, saved_pdf in zip(pdf_pages, saved_pdf_pages):
|
| 489 |
+
# pdf.update(saved_pdf)
|
| 490 |
+
|
| 491 |
+
# %% [markdown]
|
| 492 |
+
# ## 4. Generate embeddings
|
| 493 |
+
#
|
| 494 |
+
# Now that we have the queries, we can use the ColPali model to generate embeddings for each page image.
|
| 495 |
+
#
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# %%
|
| 499 |
+
def generate_embeddings(images, model, processor, batch_size=2) -> np.ndarray:
|
| 500 |
+
"""
|
| 501 |
+
Generate embeddings for a list of images.
|
| 502 |
+
Move to CPU only once per batch.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
images (List[PIL.Image]): List of PIL images.
|
| 506 |
+
model (nn.Module): The model to generate embeddings.
|
| 507 |
+
processor: The processor to preprocess images.
|
| 508 |
+
batch_size (int, optional): Batch size for processing. Defaults to 2.
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
np.ndarray: Embeddings for the images, shape
|
| 512 |
+
(len(images), processor.max_patch_length (1030 for ColPali), model.config.hidden_size (Patch embedding dimension - 128 for ColPali)).
|
| 513 |
+
"""
|
| 514 |
+
embeddings_list = []
|
| 515 |
+
|
| 516 |
+
def collate_fn(batch):
|
| 517 |
+
# Batch is a list of images
|
| 518 |
+
return processor.process_images(batch) # Should return a dict of tensors
|
| 519 |
+
|
| 520 |
+
dataloader = DataLoader(
|
| 521 |
+
images,
batch_size=batch_size,
|
| 522 |
+
shuffle=False,
|
| 523 |
+
collate_fn=collate_fn,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
for batch_doc in tqdm(dataloader, desc="Generating embeddings"):
|
| 527 |
+
with torch.no_grad():
|
| 528 |
+
# Move batch to the device
|
| 529 |
+
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
| 530 |
+
embeddings_batch = model(**batch_doc)
|
| 531 |
+
embeddings_list.append(torch.unbind(embeddings_batch.to("cpu"), dim=0))
|
| 532 |
+
# Concatenate all embeddings and create a numpy array
|
| 533 |
+
all_embeddings = np.concatenate(embeddings_list, axis=0)
|
| 534 |
+
return all_embeddings
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# %%
|
| 538 |
+
# Generate embeddings for all images
|
| 539 |
+
images = [pdf["image"] for pdf in pdf_pages]
|
| 540 |
+
embeddings = generate_embeddings(images, model, processor)
|
| 541 |
+
|
| 542 |
+
# %%
|
| 543 |
+
embeddings.shape
|
| 544 |
+
|
| 545 |
+
# %% [markdown]
|
| 546 |
+
# ## 5. Prepare Data in Vespa Format
|
| 547 |
+
#
|
| 548 |
+
# Now that we have all the data we need, all that remains is to make sure it is in the right format for Vespa.
|
| 549 |
+
#
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# %%
|
| 553 |
+
def float_to_binary_embedding(float_query_embedding: dict) -> dict:
|
| 554 |
+
"""Utility function to convert float query embeddings to binary query embeddings."""
|
| 555 |
+
binary_query_embeddings = {}
|
| 556 |
+
for k, v in float_query_embedding.items():
|
| 557 |
+
binary_vector = (
|
| 558 |
+
np.packbits(np.where(np.array(v) > 0, 1, 0)).astype(np.int8).tolist()
|
| 559 |
+
)
|
| 560 |
+
binary_query_embeddings[k] = binary_vector
|
| 561 |
+
return binary_query_embeddings
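# %% [markdown]
# A quick sanity check of the binarization: each 128-dimensional float patch vector is reduced
# to its sign bits and packed into 16 int8 values (128 bits / 8 bits per byte), which matches
# the `tensor<int8>(patch{}, v[16])` embedding field defined in the schema below. The vector
# here is made up purely for illustration.

# %%
toy_embedding = {0: np.random.randn(128).tolist()}
toy_binary = float_to_binary_embedding(toy_embedding)
print(len(toy_binary[0]))  # 16 packed int8 values per 128-dimensional patch vector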
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# %%
|
| 565 |
+
vespa_feed = []
|
| 566 |
+
for pdf, embedding in zip(pdf_pages, embeddings):
|
| 567 |
+
url = pdf["url"]
|
| 568 |
+
title = pdf["title"]
|
| 569 |
+
image = pdf["image"]
|
| 570 |
+
text = pdf.get("text", "")
|
| 571 |
+
page_no = pdf["page_no"]
|
| 572 |
+
query_dict = pdf["queries"]
|
| 573 |
+
questions = [v for k, v in query_dict.items() if "question" in k and v]
|
| 574 |
+
queries = [v for k, v in query_dict.items() if "query" in k and v]
|
| 575 |
+
base_64_image = get_base64_image(
|
| 576 |
+
scale_image(image, 32), add_url_prefix=False
|
| 577 |
+
) # Scaled down image to return fast on search (~1kb)
|
| 578 |
+
base_64_full_image = get_base64_image(image, add_url_prefix=False)
|
| 579 |
+
embedding_dict = {k: v for k, v in enumerate(embedding)}
|
| 580 |
+
binary_embedding = float_to_binary_embedding(embedding_dict)
|
| 581 |
+
# id_hash should be md5 hash of url and page_number
|
| 582 |
+
id_hash = hashlib.md5(f"{url}_{page_no}".encode()).hexdigest()
|
| 583 |
+
page = {
|
| 584 |
+
"id": id_hash,
|
| 585 |
+
"fields": {
|
| 586 |
+
"id": id_hash,
|
| 587 |
+
"url": url,
|
| 588 |
+
"title": title,
|
| 589 |
+
"page_number": page_no,
|
| 590 |
+
"blur_image": base_64_image,
|
| 591 |
+
"full_image": base_64_full_image,
|
| 592 |
+
"text": text,
|
| 593 |
+
"embedding": binary_embedding,
|
| 594 |
+
"queries": queries,
|
| 595 |
+
"questions": questions,
|
| 596 |
+
},
|
| 597 |
+
}
|
| 598 |
+
vespa_feed.append(page)
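# %% [markdown]
# The base64-encoded images dominate the size of each feed document, so it can be useful to
# check roughly how large one document is before feeding (the numbers are indicative only):

# %%
sample_doc = vespa_feed[0]
print(f"blur_image: ~{len(sample_doc['fields']['blur_image']) / 1024:.1f} KB")
print(f"full_image: ~{len(sample_doc['fields']['full_image']) / 1024:.1f} KB")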
|
| 599 |
+
|
| 600 |
+
# %%
|
| 601 |
+
# The Vespa feed data, including the embeddings and the generated queries, is now ready to be saved.
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
# Save vespa_feed to vespa_feed.json
|
| 605 |
+
os.makedirs("output", exist_ok=True)
|
| 606 |
+
with open("output/vespa_feed.json", "w") as f:
|
| 607 |
+
vespa_feed_to_save = []
|
| 608 |
+
for page in vespa_feed:
|
| 609 |
+
document_id = page["id"]
|
| 610 |
+
put_id = f"id:{VESPA_APPLICATION_NAME}:{VESPA_SCHEMA_NAME}::{document_id}"
|
| 611 |
+
vespa_feed_to_save.append({"put": put_id, "fields": page["fields"]})
|
| 612 |
+
json.dump(vespa_feed_to_save, f)
|
| 613 |
+
|
| 614 |
+
# %%
|
| 615 |
+
# import json
|
| 616 |
+
|
| 617 |
+
# with open("output/vespa_feed.json", "r") as f:
|
| 618 |
+
# vespa_feed = json.load(f)
|
| 619 |
+
|
| 620 |
+
# %%
|
| 621 |
+
len(vespa_feed)
|
| 622 |
+
|
| 623 |
+
# %% [markdown]
|
| 624 |
+
# ## 6. Prepare Vespa Application
|
| 625 |
+
#
|
| 626 |
+
|
| 627 |
+
# %%
|
| 628 |
+
# Define the Vespa schema
|
| 629 |
+
colpali_schema = Schema(
|
| 630 |
+
name=VESPA_SCHEMA_NAME,
|
| 631 |
+
document=Document(
|
| 632 |
+
fields=[
|
| 633 |
+
Field(
|
| 634 |
+
name="id",
|
| 635 |
+
type="string",
|
| 636 |
+
indexing=["summary", "index"],
|
| 637 |
+
match=["word"],
|
| 638 |
+
),
|
| 639 |
+
Field(name="url", type="string", indexing=["summary", "index"]),
|
| 640 |
+
Field(
|
| 641 |
+
name="title",
|
| 642 |
+
type="string",
|
| 643 |
+
indexing=["summary", "index"],
|
| 644 |
+
match=["text"],
|
| 645 |
+
index="enable-bm25",
|
| 646 |
+
),
|
| 647 |
+
Field(name="page_number", type="int", indexing=["summary", "attribute"]),
|
| 648 |
+
Field(name="blur_image", type="raw", indexing=["summary"]),
|
| 649 |
+
Field(name="full_image", type="raw", indexing=["summary"]),
|
| 650 |
+
Field(
|
| 651 |
+
name="text",
|
| 652 |
+
type="string",
|
| 653 |
+
indexing=["summary", "index"],
|
| 654 |
+
match=["text"],
|
| 655 |
+
index="enable-bm25",
|
| 656 |
+
),
|
| 657 |
+
Field(
|
| 658 |
+
name="embedding",
|
| 659 |
+
type="tensor<int8>(patch{}, v[16])",
|
| 660 |
+
indexing=[
|
| 661 |
+
"attribute",
|
| 662 |
+
"index",
|
| 663 |
+
],
|
| 664 |
+
ann=HNSW(
|
| 665 |
+
distance_metric="hamming",
|
| 666 |
+
max_links_per_node=32,
|
| 667 |
+
neighbors_to_explore_at_insert=400,
|
| 668 |
+
),
|
| 669 |
+
),
|
| 670 |
+
Field(
|
| 671 |
+
name="questions",
|
| 672 |
+
type="array<string>",
|
| 673 |
+
indexing=["summary", "index", "attribute"],
|
| 674 |
+
index="enable-bm25",
|
| 675 |
+
stemming="best",
|
| 676 |
+
),
|
| 677 |
+
Field(
|
| 678 |
+
name="queries",
|
| 679 |
+
type="array<string>",
|
| 680 |
+
indexing=["summary", "index", "attribute"],
|
| 681 |
+
index="enable-bm25",
|
| 682 |
+
stemming="best",
|
| 683 |
+
),
|
| 684 |
+
# Add synthetic fields for the questions and queries
|
| 685 |
+
# Field(
|
| 686 |
+
# name="questions_exact",
|
| 687 |
+
# type="array<string>",
|
| 688 |
+
# indexing=["input questions", "index", "attribute"],
|
| 689 |
+
# match=["word"],
|
| 690 |
+
# is_document_field=False,
|
| 691 |
+
# ),
|
| 692 |
+
# Field(
|
| 693 |
+
# name="queries_exact",
|
| 694 |
+
# type="array<string>",
|
| 695 |
+
# indexing=["input queries", "index"],
|
| 696 |
+
# match=["word"],
|
| 697 |
+
# is_document_field=False,
|
| 698 |
+
# ),
|
| 699 |
+
]
|
| 700 |
+
),
|
| 701 |
+
fieldsets=[
|
| 702 |
+
FieldSet(
|
| 703 |
+
name="default",
|
| 704 |
+
fields=["title", "url", "blur_image", "page_number", "text"],
|
| 705 |
+
),
|
| 706 |
+
FieldSet(
|
| 707 |
+
name="image",
|
| 708 |
+
fields=["full_image"],
|
| 709 |
+
),
|
| 710 |
+
],
|
| 711 |
+
document_summaries=[
|
| 712 |
+
DocumentSummary(
|
| 713 |
+
name="default",
|
| 714 |
+
summary_fields=[
|
| 715 |
+
Summary(
|
| 716 |
+
name="text",
|
| 717 |
+
fields=[("bolding", "on")],
|
| 718 |
+
),
|
| 719 |
+
Summary(
|
| 720 |
+
name="snippet",
|
| 721 |
+
fields=[("source", "text"), "dynamic"],
|
| 722 |
+
),
|
| 723 |
+
],
|
| 724 |
+
from_disk=True,
|
| 725 |
+
),
|
| 726 |
+
],
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
# Define similarity functions used in all rank profiles
|
| 730 |
+
mapfunctions = [
|
| 731 |
+
Function(
|
| 732 |
+
name="similarities", # computes similarity scores between each query token and image patch
|
| 733 |
+
expression="""
|
| 734 |
+
sum(
|
| 735 |
+
query(qt) * unpack_bits(attribute(embedding)), v
|
| 736 |
+
)
|
| 737 |
+
""",
|
| 738 |
+
),
|
| 739 |
+
Function(
|
| 740 |
+
name="normalized", # normalizes the similarity scores to [-1, 1]
|
| 741 |
+
expression="""
|
| 742 |
+
(similarities - reduce(similarities, min)) / (reduce((similarities - reduce(similarities, min)), max)) * 2 - 1
|
| 743 |
+
""",
|
| 744 |
+
),
|
| 745 |
+
Function(
|
| 746 |
+
name="quantized", # quantizes the normalized similarity scores to signed 8-bit integers [-128, 127]
|
| 747 |
+
expression="""
|
| 748 |
+
cell_cast(normalized * 127.999, int8)
|
| 749 |
+
""",
|
| 750 |
+
),
|
| 751 |
+
]
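# %% [markdown]
# The `normalized` and `quantized` functions above are plain min-max scaling to [-1, 1]
# followed by a cast to int8. A small numpy mirror of the same arithmetic, on a made-up
# similarity matrix, shows the value ranges to expect from the `quantized` summary feature:

# %%
toy_similarities = np.random.randn(5, 1030)  # (query tokens, patches), made-up values
toy_shifted = toy_similarities - toy_similarities.min()
toy_normalized = toy_shifted / toy_shifted.max() * 2 - 1
toy_quantized = (toy_normalized * 127.999).astype(np.int8)  # approximates cell_cast(..., int8)
print(toy_normalized.min(), toy_normalized.max())  # -1.0 and 1.0
print(toy_quantized.min(), toy_quantized.max())  # within [-128, 127]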
|
| 752 |
+
|
| 753 |
+
# Define the 'bm25' rank profile
|
| 754 |
+
colpali_bm25_profile = RankProfile(
|
| 755 |
+
name="bm25",
|
| 756 |
+
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
|
| 757 |
+
first_phase="bm25(title) + bm25(text)",
|
| 758 |
+
functions=mapfunctions,
|
| 759 |
+
summary_features=["quantized"],
|
| 760 |
+
)
|
| 761 |
+
colpali_schema.add_rank_profile(colpali_bm25_profile)
|
| 762 |
+
|
| 763 |
+
# Update the 'default' rank profile
|
| 764 |
+
colpali_profile = RankProfile(
|
| 765 |
+
name="default",
|
| 766 |
+
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
|
| 767 |
+
first_phase="bm25_score",
|
| 768 |
+
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
|
| 769 |
+
functions=mapfunctions
|
| 770 |
+
+ [
|
| 771 |
+
Function(
|
| 772 |
+
name="max_sim",
|
| 773 |
+
expression="""
|
| 774 |
+
sum(
|
| 775 |
+
reduce(
|
| 776 |
+
sum(
|
| 777 |
+
query(qt) * unpack_bits(attribute(embedding)), v
|
| 778 |
+
),
|
| 779 |
+
max, patch
|
| 780 |
+
),
|
| 781 |
+
querytoken
|
| 782 |
+
)
|
| 783 |
+
""",
|
| 784 |
+
),
|
| 785 |
+
Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
|
| 786 |
+
],
|
| 787 |
+
summary_features=["quantized"],
|
| 788 |
+
)
|
| 789 |
+
colpali_schema.add_rank_profile(colpali_profile)
|
| 790 |
+
|
| 791 |
+
# Update the 'retrieval-and-rerank' rank profile
|
| 792 |
+
input_query_tensors = []
|
| 793 |
+
MAX_QUERY_TERMS = 64
|
| 794 |
+
for i in range(MAX_QUERY_TERMS):
|
| 795 |
+
input_query_tensors.append((f"query(rq{i})", "tensor<int8>(v[16])"))
|
| 796 |
+
|
| 797 |
+
input_query_tensors.extend(
|
| 798 |
+
[
|
| 799 |
+
("query(qt)", "tensor<float>(querytoken{}, v[128])"),
|
| 800 |
+
("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
|
| 801 |
+
]
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
colpali_retrieval_profile = RankProfile(
|
| 805 |
+
name="retrieval-and-rerank",
|
| 806 |
+
inputs=input_query_tensors,
|
| 807 |
+
first_phase="max_sim_binary",
|
| 808 |
+
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
|
| 809 |
+
functions=mapfunctions
|
| 810 |
+
+ [
|
| 811 |
+
Function(
|
| 812 |
+
name="max_sim",
|
| 813 |
+
expression="""
|
| 814 |
+
sum(
|
| 815 |
+
reduce(
|
| 816 |
+
sum(
|
| 817 |
+
query(qt) * unpack_bits(attribute(embedding)), v
|
| 818 |
+
),
|
| 819 |
+
max, patch
|
| 820 |
+
),
|
| 821 |
+
querytoken
|
| 822 |
+
)
|
| 823 |
+
""",
|
| 824 |
+
),
|
| 825 |
+
Function(
|
| 826 |
+
name="max_sim_binary",
|
| 827 |
+
expression="""
|
| 828 |
+
sum(
|
| 829 |
+
reduce(
|
| 830 |
+
1 / (1 + sum(
|
| 831 |
+
hamming(query(qtb), attribute(embedding)), v)
|
| 832 |
+
),
|
| 833 |
+
max, patch
|
| 834 |
+
),
|
| 835 |
+
querytoken
|
| 836 |
+
)
|
| 837 |
+
""",
|
| 838 |
+
),
|
| 839 |
+
],
|
| 840 |
+
summary_features=["quantized"],
|
| 841 |
+
)
|
| 842 |
+
colpali_schema.add_rank_profile(colpali_retrieval_profile)
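# %% [markdown]
# The `max_sim` expression in the rank profiles above is the standard ColPali/ColBERT
# late-interaction score: for every query token, take the best-matching image patch, then sum
# those maxima over the query tokens. A numpy sketch of the same computation on random float
# tensors (ignoring the bit-unpacking of the stored int8 embedding) looks like this:

# %%
query_tokens = np.random.randn(12, 128)  # (querytoken, v), made-up query embedding
patches = np.random.randn(1030, 128)  # (patch, v), made-up page embedding
similarity_matrix = query_tokens @ patches.T  # sum over v -> (querytoken, patch)
max_sim_score = similarity_matrix.max(axis=1).sum()  # max over patch, sum over querytoken
print(max_sim_score)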
|
| 843 |
+
|
| 844 |
+
# %%
|
| 845 |
+
from vespa.configuration.services import (
|
| 846 |
+
services,
|
| 847 |
+
container,
|
| 848 |
+
search,
|
| 849 |
+
document_api,
|
| 850 |
+
document_processing,
|
| 851 |
+
clients,
|
| 852 |
+
client,
|
| 853 |
+
config,
|
| 854 |
+
content,
|
| 855 |
+
redundancy,
|
| 856 |
+
documents,
|
| 857 |
+
node,
|
| 858 |
+
certificate,
|
| 859 |
+
token,
|
| 860 |
+
document,
|
| 861 |
+
nodes,
|
| 862 |
+
)
|
| 863 |
+
from vespa.configuration.vt import vt
|
| 864 |
+
from vespa.package import ServicesConfiguration
|
| 865 |
+
|
| 866 |
+
service_config = ServicesConfiguration(
|
| 867 |
+
application_name=VESPA_APPLICATION_NAME,
|
| 868 |
+
services_config=services(
|
| 869 |
+
container(
|
| 870 |
+
search(),
|
| 871 |
+
document_api(),
|
| 872 |
+
document_processing(),
|
| 873 |
+
clients(
|
| 874 |
+
client(
|
| 875 |
+
certificate(file="security/clients.pem"),
|
| 876 |
+
id="mtls",
|
| 877 |
+
permissions="read,write",
|
| 878 |
+
),
|
| 879 |
+
client(
|
| 880 |
+
token(id=f"{VESPA_TOKEN_ID_WRITE}"),
|
| 881 |
+
id="token_write",
|
| 882 |
+
permissions="read,write",
|
| 883 |
+
),
|
| 884 |
+
client(
|
| 885 |
+
token(id=f"{VESPA_TOKEN_ID_READ}"),
|
| 886 |
+
id="token_read",
|
| 887 |
+
permissions="read",
|
| 888 |
+
),
|
| 889 |
+
),
|
| 890 |
+
config(
|
| 891 |
+
vt("tag")(
|
| 892 |
+
vt("bold")(
|
| 893 |
+
vt("open", "<strong>"),
|
| 894 |
+
vt("close", "</strong>"),
|
| 895 |
+
),
|
| 896 |
+
vt("separator", "..."),
|
| 897 |
+
),
|
| 898 |
+
name="container.qr-searchers",
|
| 899 |
+
),
|
| 900 |
+
id=f"{VESPA_APPLICATION_NAME}_container",
|
| 901 |
+
version="1.0",
|
| 902 |
+
),
|
| 903 |
+
content(
|
| 904 |
+
redundancy("1"),
|
| 905 |
+
documents(document(type="pdf_page", mode="index")),
|
| 906 |
+
nodes(node(distribution_key="0", hostalias="node1")),
|
| 907 |
+
config(
|
| 908 |
+
vt("max_matches", "2", replace_underscores=False),
|
| 909 |
+
vt("length", "1000"),
|
| 910 |
+
vt("surround_max", "500", replace_underscores=False),
|
| 911 |
+
vt("min_length", "300", replace_underscores=False),
|
| 912 |
+
name="vespa.config.search.summary.juniperrc",
|
| 913 |
+
),
|
| 914 |
+
id=f"{VESPA_APPLICATION_NAME}_content",
|
| 915 |
+
version="1.0",
|
| 916 |
+
),
|
| 917 |
+
version="1.0",
|
| 918 |
+
),
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
# %%
|
| 922 |
+
# Create the Vespa application package
|
| 923 |
+
vespa_application_package = ApplicationPackage(
|
| 924 |
+
name=VESPA_APPLICATION_NAME,
|
| 925 |
+
schema=[colpali_schema],
|
| 926 |
+
services_config=service_config,
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
# %% [markdown]
|
| 930 |
+
# ## 7. Deploy Vespa Application
|
| 931 |
+
#
|
| 932 |
+
|
| 933 |
+
# %%
|
| 934 |
+
VESPA_TEAM_API_KEY = os.getenv("VESPA_TEAM_API_KEY") or input(
|
| 935 |
+
"Enter Vespa team API key: "
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
# %%
|
| 939 |
+
vespa_cloud = VespaCloud(
|
| 940 |
+
tenant=VESPA_TENANT_NAME,
|
| 941 |
+
application=VESPA_APPLICATION_NAME,
|
| 942 |
+
key_content=VESPA_TEAM_API_KEY,
|
| 943 |
+
application_package=vespa_application_package,
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
# Deploy the application
|
| 947 |
+
vespa_cloud.deploy()
|
| 948 |
+
|
| 949 |
+
# Output the endpoint URL
|
| 950 |
+
endpoint_url = vespa_cloud.get_token_endpoint()
|
| 951 |
+
print(f"Application deployed. Token endpoint URL: {endpoint_url}")
|
| 952 |
+
|
| 953 |
+
# %% [markdown]
|
| 954 |
+
# Make sure to take note of the token endpoint_url.
|
| 955 |
+
# You need to put this in your `.env` file - `VESPA_APP_URL=https://abcd.vespa-app.cloud` - to access the Vespa application from your web application.
|
| 956 |
+
#
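# %% [markdown]
# A minimal sketch of how the web application might pick up these values, assuming they have
# been written to `.env` as described above and that `python-dotenv` is installed (kept
# commented out here, since this notebook already holds the endpoint in `endpoint_url`):

# %%
# from dotenv import load_dotenv
#
# load_dotenv()
# web_app = Vespa(
#     url=os.getenv("VESPA_APP_URL"),
#     vespa_cloud_secret_token=os.getenv("VESPA_CLOUD_SECRET_TOKEN"),
# )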
|
| 957 |
+
|
| 958 |
+
# %% [markdown]
|
| 959 |
+
# ## 8. Feed Data to Vespa
|
| 960 |
+
#
|
| 961 |
+
|
| 962 |
+
# %%
|
| 963 |
+
# Instantiate Vespa connection using token
|
| 964 |
+
app = Vespa(url=endpoint_url, vespa_cloud_secret_token=VESPA_CLOUD_SECRET_TOKEN)
|
| 965 |
+
app.get_application_status()
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
# %%
|
| 969 |
+
def callback(response: VespaResponse, id: str):
|
| 970 |
+
if not response.is_successful():
|
| 971 |
+
print(
|
| 972 |
+
f"Failed to feed document {id} with status code {response.status_code}: Reason {response.get_json()}"
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
# Feed data into Vespa asynchronously
|
| 977 |
+
app.feed_async_iterable(vespa_feed, schema=VESPA_SCHEMA_NAME, callback=callback)
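# %% [markdown]
# As an optional sanity check after feeding, a single document can be fetched back by id.
# This is a sketch and assumes the feed above has completed; uncomment to run it.

# %%
# sample_id = vespa_feed[0]["id"]
# response = app.get_data(schema=VESPA_SCHEMA_NAME, data_id=sample_id)
# print(response.is_successful())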
|
pyproject.toml
CHANGED
|
@@ -8,7 +8,7 @@ license = { text = "Apache-2.0" }
|
|
| 8 |
dependencies = [
|
| 9 |
"python-fasthtml",
|
| 10 |
"huggingface-hub",
|
| 11 |
-
"pyvespa
|
| 12 |
"vespacli",
|
| 13 |
"torch",
|
| 14 |
"vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
|
|
@@ -18,6 +18,7 @@ dependencies = [
|
|
| 18 |
"setuptools",
|
| 19 |
"python-dotenv",
|
| 20 |
"shad4fast>=1.2.1",
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
# dev-dependencies
|
|
@@ -27,3 +28,11 @@ dev = [
|
|
| 27 |
"python-dotenv",
|
| 28 |
"huggingface_hub[cli]"
|
| 29 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
dependencies = [
|
| 9 |
"python-fasthtml",
|
| 10 |
"huggingface-hub",
|
| 11 |
+
"pyvespa>=0.50.0",
|
| 12 |
"vespacli",
|
| 13 |
"torch",
|
| 14 |
"vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
|
|
|
|
| 18 |
"setuptools",
|
| 19 |
"python-dotenv",
|
| 20 |
"shad4fast>=1.2.1",
|
| 21 |
+
"google-generativeai>=0.7.2"
|
| 22 |
]
|
| 23 |
|
| 24 |
# dev-dependencies
|
|
|
|
| 28 |
"python-dotenv",
|
| 29 |
"huggingface_hub[cli]"
|
| 30 |
]
|
| 31 |
+
feed = [
|
| 32 |
+
"ipykernel",
|
| 33 |
+
"jupytext",
|
| 34 |
+
"pydantic",
|
| 35 |
+
"beautifulsoup4",
|
| 36 |
+
"pdf2image",
|
| 37 |
+
"google-generativeai"
|
| 38 |
+
]
|
static/.DS_Store
ADDED
|
Binary file (6.15 kB).
|
|
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|