Spaces:
Running
Running
deploy at 2024-08-24 17:35:22.783475
Browse files- main copy.py +861 -0
- main.py +11 -43
main copy.py
ADDED
|
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fasthtml_hf import setup_hf_backup
|
| 2 |
+
from fasthtml.common import (
|
| 3 |
+
picolink,
|
| 4 |
+
serve,
|
| 5 |
+
Div,
|
| 6 |
+
Title,
|
| 7 |
+
Main,
|
| 8 |
+
Input,
|
| 9 |
+
Button,
|
| 10 |
+
A,
|
| 11 |
+
Section,
|
| 12 |
+
H2,
|
| 13 |
+
Ul,
|
| 14 |
+
Li,
|
| 15 |
+
P,
|
| 16 |
+
Img,
|
| 17 |
+
Details,
|
| 18 |
+
MarkdownJS,
|
| 19 |
+
HighlightJS,
|
| 20 |
+
Summary,
|
| 21 |
+
Script,
|
| 22 |
+
I,
|
| 23 |
+
Form,
|
| 24 |
+
RedirectResponse,
|
| 25 |
+
dataclass,
|
| 26 |
+
Favicon,
|
| 27 |
+
database,
|
| 28 |
+
get_key,
|
| 29 |
+
Table,
|
| 30 |
+
Thead,
|
| 31 |
+
Tr,
|
| 32 |
+
Th,
|
| 33 |
+
Tbody,
|
| 34 |
+
Td,
|
| 35 |
+
FileResponse,
|
| 36 |
+
fast_app,
|
| 37 |
+
Beforeware,
|
| 38 |
+
Hidden,
|
| 39 |
+
Request,
|
| 40 |
+
H3,
|
| 41 |
+
Style,
|
| 42 |
+
)
|
| 43 |
+
from fasthtml.components import Nav, Article, Header, Mark
|
| 44 |
+
from fasthtml.pico import Search, Grid, Fieldset, Label
|
| 45 |
+
from starlette.middleware import Middleware
|
| 46 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 47 |
+
from starlette.middleware.sessions import SessionMiddleware
|
| 48 |
+
from vespa.application import Vespa
|
| 49 |
+
import json
|
| 50 |
+
import os
|
| 51 |
+
import re
|
| 52 |
+
import time
|
| 53 |
+
from hmac import compare_digest
|
| 54 |
+
from io import StringIO
|
| 55 |
+
import csv
|
| 56 |
+
import tempfile
|
| 57 |
+
from enum import Enum
|
| 58 |
+
from typing import Tuple as T
|
| 59 |
+
from urllib.parse import quote
|
| 60 |
+
import uuid
|
| 61 |
+
|
| 62 |
+
# Toggle for local development: when True, hot reload is enabled (see fast_app)
# and settings are read from a local .env file.
DEV_MODE = False

if DEV_MODE:
    print("Running in DEV_MODE - Hot reload enabled")
    print("Loading environment variables from .env")
    # Imported lazily so python-dotenv is only needed for local development.
    from dotenv import load_dotenv

    load_dotenv()
else:
    print("DEV_MODE disabled - environment variables loaded from system")

vespa_app_url = os.getenv("VESPA_APP_URL", None)
if vespa_app_url is None:
    print("Please set the VESPA_APP_URL environment variable")
    # `exit()` is a helper injected by the `site` module and is not guaranteed
    # to exist (e.g. under `python -S`); raising SystemExit always works.
    raise SystemExit(1)

# NOTE(review): both values fall back to "admin" when the environment variables
# are unset - confirm real credentials are always configured in deployment.
ADMIN_NAME = os.getenv("ADMIN_NAME", "admin")
ADMIN_PWD = os.getenv("ADMIN_PWD", "admin")

# Client for the Vespa application that serves all search requests.
vespa_app: Vespa = Vespa(
    url=vespa_app_url,
    vespa_cloud_secret_token=os.getenv("VESPA_CLOUD_SECRET_TOKEN"),
)
status = vespa_app.get_application_status()
if status is None:
    print("Could not connect to Vespa application")
else:
    print("Connected to Vespa application!")
|
| 90 |
+
|
| 91 |
+
# Font Awesome kit script; supplies the icon classes used in the navbar.
fa = Script(src="https://kit.fontawesome.com/664eb1a115.js", crossorigin="anonymous")

# The same icon URL is passed for both Favicon arguments.
_FAVICON_URL = "https://search.vespa.ai/favicon.ico"
favicon = Favicon(_FAVICON_URL, _FAVICON_URL)

# SQLite database holding one row per search query issued through the app.
DB_FILE = "db/vespa.db"
db = database(DB_FILE)
queries = db.t.queries
if queries not in db.t:
    # You can pass a dict, or kwargs, to most MiniDataAPI methods.
    queries.create(
        dict(qid=int, query=str, ranking=str, sess_id=str, timestamp=int), pk="qid"
    )
    # Add autoincrement to the qid column
    # NOTE(review): `qid` already exists after create(), and the result of
    # db.query() is never consumed here - confirm this statement is needed.
    db.query("ALTER TABLE queries ADD COLUMN qid INTEGER PRIMARY KEY AUTOINCREMENT")
Query = queries.dataclass()


def _query_datetime(self):
    # Format the stored epoch-seconds timestamp in the server's local time.
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.timestamp))


# Attach an instance method to the Query dataclass that converts the timestamp
# field to a human readable format.
Query.get_datetime = _query_datetime

# Status code 303 is a redirect that can change POST to GET,
# so it's appropriate for a login page.
login_redir = RedirectResponse("/login", status_code=303)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def user_auth_before(req, sess):
    """Beforeware hook: copy the session's auth value into the request scope.

    Args:
        req: The incoming request; its ``scope["auth"]`` entry is overwritten.
        sess: The session mapping for the request.

    Returns:
        ``None`` when the session carries a truthy ``auth`` value, otherwise
        the shared ``login_redir`` redirect response.
    """
    # The `auth` key in the request scope is automatically provided
    # to any handler which requests it, and can not be injected
    # by the user using query params, cookies, etc, so it should
    # be secure to use.
    print(f"Session Data before route: {sess}")
    auth = sess.get("auth", None)
    req.scope["auth"] = auth
    print(f"Auth: {auth}")
    if auth:
        return None
    return login_redir
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# CSS that hides the htmx loading indicator until a request is in flight.
spinner_css = Style("""
.htmx-indicator {
display: none; /* Hide spinner by default */
}

.htmx-indicator.htmx-request {
display: block;
}
""")

# Elements passed to fast_app as `hdrs`.
headers = (
    picolink,
    MarkdownJS(),
    HighlightJS(langs=["json", "python"]),
    favicon,
    fa,
    spinner_css,
)

# Read file contents once before starting the server.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
with open("README.md", encoding="utf-8") as f:
    README = f.read()
with open("main.py", encoding="utf-8") as f:
    SOURCE = f.read()

# Sesskey
sess_key_path = "session/.sesskey"
# Make sure session directory exists
os.makedirs("session", exist_ok=True)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Middleware
|
| 162 |
+
class XFrameOptionsMiddleware(BaseHTTPMiddleware):
    """Add framing headers so the app may be embedded by huggingface.co."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler, then attach the framing headers.

        Args:
            request: The incoming request.
            call_next: Callable that produces the downstream response.

        Returns:
            The downstream response with the extra headers set.
        """
        response = await call_next(request)
        # Kept for older clients; `ALLOW-FROM` is obsolete and ignored by
        # current browsers, so on its own it restricts nothing.
        response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
        # Modern equivalent. setdefault avoids clobbering a policy that a
        # handler may already have set on this response.
        response.headers.setdefault(
            "Content-Security-Policy",
            "frame-ancestors 'self' https://huggingface.co",
        )
        return response
|
| 167 |
+
|
| 168 |
+
class SessionLoggingMiddleware(BaseHTTPMiddleware):
    """Debug middleware that prints the session before and after each request."""

    async def dispatch(self, request, call_next):
        """Print the session, run the downstream handler, print the session again."""
        print(f"Before request: Session data: {request.session}")
        downstream_response = await call_next(request)
        print(f"After request: Session data: {request.session}")
        return downstream_response
|
| 174 |
+
|
| 175 |
+
class DebugSessionMiddleware(SessionMiddleware):
    """SessionMiddleware variant that prints the ASGI scope around each call."""

    async def __call__(self, scope, receive, send):
        """Print the scope, delegate to SessionMiddleware, print the scope again."""
        print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
        delegate = super().__call__
        await delegate(scope, receive, send)
        print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
| 180 |
+
|
| 181 |
+
from starlette.middleware.cors import CORSMiddleware

# Cookie-backed session whose data expires after one hour (max_age is in seconds).
# NOTE(review): fast_app below also receives key_fname/same_site, which suggests
# it configures session handling itself - confirm this explicit entry is needed.
_session_middleware = Middleware(
    SessionMiddleware,
    secret_key=get_key(fname=sess_key_path),
    max_age=3600,
)

# Middleware stack handed to fast_app.
# NOTE(review): CORS is open to every origin - confirm this is intended.
middlewares = [
    _session_middleware,
    Middleware(CORSMiddleware, allow_origins=['*']),
    Middleware(XFrameOptionsMiddleware),
    Middleware(SessionLoggingMiddleware),
]

# Path patterns that bypass user_auth_before: static assets and every public page.
_auth_exempt_paths = [
    r"/favicon\.ico",
    r"/static/.*",
    r".*\.css",
    r".*\.js",
    "/",
    "/login",
    "/search",
    "/document/.*",
    "/expand/.*",
    "/source",
    "/about",
]
bware = Beforeware(user_auth_before, skip=_auth_exempt_paths)

# NOTE(review): same_site="None" cookies are normally only accepted by browsers
# when also marked Secure - verify the session cookie survives inside the iframe.
app, rt = fast_app(
    before=bware,
    live=DEV_MODE,
    hdrs=headers,
    middleware=middlewares,
    key_fname=sess_key_path,
    same_site="None",
)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Session signing key, loaded from the same file that is handed to the session
# middleware and to fast_app.
sesskey = get_key(fname=sess_key_path)
# Never print the key itself: anyone who can read the logs could use it to
# forge signed session cookies. Only report where it was loaded from.
print(f"Session key loaded from {sess_key_path}")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# enum class for rank profiles
|
| 227 |
+
class RankProfile(str, Enum):
    """Ranking choices offered by the search form.

    Members are also ``str`` instances, so they compare equal to the plain
    string values submitted by the ranking radio buttons.
    """

    # Value of the "BM25" radio button.
    bm25 = "bm25"
    # Value of the "Semantic" radio button.
    semantic = "semantic"
    # Value of the "Reciprocal Rank fusion" radio button.
    fusion = "fusion"
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def get_navbar(admin: bool):
    """Build the navigation bar shown at the top of every page.

    Args:
        admin: Truthy when an admin is logged in. It switches the shield icon
            target from ``/login`` to ``/admin`` and makes the logout icon
            visible.

    Returns:
        A ``Nav`` element holding the logo, the app title and the icon links.
    """
    print(f"In get_navbar: {admin}")
    spaced = "margin: 10px;"

    def icon_link(icon_cls, **attrs):
        # Anchor wrapping one Font Awesome icon; attributes keep the caller's order.
        return A(I(cls=icon_cls), **attrs)

    logo = Li(
        A(
            Img(src="https://vespa.ai/assets/vespa-ai-logo-heather.svg"),
            href="https://cloud.vespa.ai",
            target="_blank",
            style=spaced,
        ),
    )

    icon_links = [
        # A question mark icon with link to an about page
        icon_link(
            "fa fa-question-circle fa-2x",
            href="/about",
            style=spaced,
            title="About this app",
        ),
        icon_link(
            "fab fa-slack fa-2x",
            href="https://slack.vespa.ai/",
            style=spaced,
            target="_blank",
            title="Join Vespa Slack channel",
        ),
        icon_link(
            "fab fa-github fa-2x",
            href="https://github.com/vespa-engine/sample-apps/tree/master/examples/fasthtml-demo",
            style=spaced,
            target="_blank",
            title="View source code on GitHub",
        ),
        icon_link(
            "fa fa-code fa-2x",
            href="/source",
            style=spaced,
            title="View source code",
        ),
        # Shield icon: leads to the login form, or straight to the admin page
        # once an admin is logged in.
        icon_link(
            "fa fa-shield fa-2x",
            href="/admin" if admin else "/login",
            style=spaced,
            title="Admin login",
        ),
        # Logout icon, hidden unless an admin is logged in.
        icon_link(
            "fa fa-sign-out fa-2x",
            href="/logout",
            style=spaced if admin else "display: none;",
            title="Logout",
        ),
    ]

    return Nav(
        Ul(logo),
        Ul(H2("Vespa-fastHTML demo")),
        Ul(*icon_links),
        # 10px margin to right of navbar
        style="margin-right: 10px;",
    )
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def spinner_div(hidden: bool = False):
    """Return the wrapper ``Div`` containing the ``#spinner`` loading indicator.

    Args:
        hidden: When true the wrapper is styled ``display: none;``; otherwise
            it is centred with a top margin.
    """
    if hidden:
        wrapper_style = "display: none;"
    else:
        wrapper_style = "text-align: center; margin-top: 40px;"
    indicator = A(
        id="spinner",
        aria_busy="true",
        cls="htmx-indicator",
        style="font-size: 2em;",
    )
    return Div(indicator, style=wrapper_style)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@app.route("/")
def get(sess):
    """Render the landing page: search bar, ranking selector and example queries.

    Args:
        sess: The session mapping; only its ``auth`` entry is read.

    Returns:
        A tuple of the page title, the navbar and the main page content.
    """
    # Can not get auth directly, as it is skipped in beforeware
    auth = sess.get("auth", False)
    # Named `example_queries` so it does not shadow the module-level `queries` table.
    example_queries = [
        "Breast Cancer Cells Feed on Cholesterol",
        "Treating Asthma With Plants vs. Pills",
        "Testing Turmeric on Smokers",
        "The Role of Pesticides in Parkinson's Disease",
    ]
    # Buttons with predefined search queries
    example_buttons = [
        Button(
            example,
            # The query is URL-encoded so spaces and punctuation survive the request.
            hx_get=f"/search?userquery={quote(example)}",
            hx_include="input[name=ranking]:checked",
            hx_target="#results",
            hx_indicator="#spinner",
            # json.dumps emits a valid JS string literal; a hand-written
            # '...' literal breaks on queries containing an apostrophe.
            hx_on_click=f"document.getElementById('userquery').value={json.dumps(example)}",
            style="margin: 10px; padding: 5px;",
            cls="secondary outline",
            id=f"example-{idx}",
        )
        for idx, example in enumerate(example_queries)
    ]
    return (
        Title("Vespa demo"),
        get_navbar(auth),
        Main(
            # Search bar
            Search(
                Input(
                    type="search",
                    placeholder="Ask/search for medical information?",
                    id="userquery",
                ),
                # Get search results on button click with search-input as query parameter
                Button(
                    "Search",
                    hx_get="/search",
                    # include userquery and id of selected ranking radio button
                    hx_include="#userquery, input[name=ranking]:checked",
                    hx_target="#results",
                    hx_indicator="#spinner",
                ),
                style="margin: 10% 10px 0 0;",
            ),
            Fieldset(
                Input(type="radio", id="bm25", name="ranking", value="bm25"),
                Label("BM25", htmlfor="bm25"),
                Input(type="radio", id="semantic", name="ranking", value="semantic"),
                Label("Semantic", htmlfor="semantic"),
                Input(
                    type="radio",
                    id="fusion",
                    name="ranking",
                    value="fusion",
                    checked="",
                ),
                Label("Reciprocal Rank fusion", htmlfor="fusion"),
                style="margin: 10px; text-align: center;",
                id="ranking",
            ),
            H3("Example queries"),
            Grid(
                *example_buttons,
                # Make the grid buttons have same height and distribute evenly and center align
                style="grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));",
            ),
            # Target section for search results; starts out holding only the spinner.
            Section(
                spinner_div(),
                id="results",
                hx_swap="innerHTML",
                style="margin: 20px;",
            ),
            style="margin: 0 auto; width: 70%;",
            id="main",
        ),
    )
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
@dataclass
class Login:
    """Credentials submitted by the admin login form."""

    # Value of the form's "name" input.
    name: str
    # Value of the form's "pwd" (password) input.
    pwd: str
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@app.get("/login")
def get_login_form(sess, error: bool = False):
    """Render the admin login page.

    Args:
        sess: The session mapping; only its ``auth`` entry is read.
        error: When true, a red "Incorrect password" message is shown above
            the form.

    Returns:
        A tuple of the page title, the navbar and the main content.
    """
    auth = sess.get("auth", False)
    login_form = Form(
        Input(id="name", placeholder="Name"),
        Input(id="pwd", type="password", placeholder="Password"),
        Button("login"),
        action="/login",
        method="post",
    )
    if error:
        err_msg = P("Incorrect password", style="color: red;")
    else:
        err_msg = ""
    return (
        Title("Admin login"),
        get_navbar(auth),
        Main(
            err_msg,
            login_form,
            style="width: 50%; margin: 10% auto;",
        ),
    )
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
@app.post("/login")
def post(login: Login, sess):
    """Validate the submitted admin password and mark the session as authenticated.

    Args:
        login: The submitted form fields.
        sess: The session mapping; ``sess["auth"]`` is set on success.

    Returns:
        A 303 redirect to ``/admin`` on success, or to ``/login?error=True``
        when the password does not match.
    """
    # compare_digest compares in constant time to avoid leaking timing information.
    # NOTE(review): login.name is not compared against ADMIN_NAME - confirm intended.
    if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
        # Incorrect password - add error message
        return RedirectResponse("/login?error=True", status_code=303)
    # Without this flag user_auth_before redirects every /admin request straight
    # back to /login, so a correct password could never log anyone in.
    sess["auth"] = True
    print(f"Session after setting auth: {sess}")
    return RedirectResponse("/admin", status_code=303)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
@app.get("/logout")
def logout(sess):
    """Clear the session's auth flag and redirect to the landing page."""
    sess["auth"] = False
    home_redirect = RedirectResponse("/")
    return home_redirect
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def replace_hi_with_strong(text):
    """Split ``text`` on ``<hi>``/``</hi>`` tags and wrap highlighted parts in ``Mark``.

    Args:
        text: String that may contain ``<hi>...</hi>`` highlight markers.

    Returns:
        A list in source order: plain strings for text outside the markers and
        ``Mark`` elements for text inside them. The markers themselves are
        dropped; empty segments produced by the split are kept.
    """
    pieces = []
    highlighted = False
    # The capturing group makes re.split return the tags as separate tokens.
    for token in re.split(r"(<hi>|</hi>)", text):
        if token in ("<hi>", "</hi>"):
            # A tag only flips the state; it is never emitted.
            highlighted = token == "<hi>"
            continue
        pieces.append(Mark(token) if highlighted else token)
    return pieces
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def log_query_to_db(query, ranking, sess):
    """Record a search in the queries table and in the visitor's session history.

    Args:
        query: The search text entered by the user.
        ranking: The ranking profile used for the search.
        sess: The session mapping; ``user_id`` is created if missing and
            ``queries`` is updated.

    Returns:
        The dict appended to the session history, with keys ``query``,
        ``ranking`` and ``timestamp``.
    """
    # Give every visitor a random id the first time they search.
    if "user_id" not in sess:
        sess["user_id"] = str(uuid.uuid4())

    # One timestamp so the database row and the session entry agree.
    now = int(time.time())

    # Store the per-visitor id, never the application's session signing key:
    # the key is a secret, and it is identical for every visitor.
    queries.insert(
        Query(query=query, ranking=ranking, sess_id=sess["user_id"], timestamp=now)
    )

    query_data = {
        "query": query,
        "ranking": ranking,
        "timestamp": now,
    }
    history = sess.get("queries", [])
    history.append(query_data)
    # Limit the number of queries stored in the session to prevent it from
    # growing too large; keep only the last 100 queries.
    # NOTE(review): if the session is cookie-backed, 100 entries may still be
    # too large for a cookie - verify.
    sess["queries"] = history[-100:]

    return query_data
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def parse_results(records):
    """Turn search records into one ``Article`` element per record.

    Args:
        records: Iterable of dicts with ``id``, ``title`` and ``body`` keys.

    Returns:
        A list of ``Article`` elements. Each shows the title as a link to the
        document, the first 300 characters of the body, a "Show more" button,
        and a hidden field carrying the full body.
    """

    def render(result):
        doc_id = result["id"]
        full_body = result["body"]
        title_link = A(
            result["title"],
            hx_get=f"/document/{doc_id}",
            hx_target="#results",
        )
        # Display first 300 characters of body
        preview = P(*replace_hi_with_strong(full_body[:300] + "..."))
        # Button with "Show more", stretched to fill the width of its parent
        show_more = Button(
            "Show more",
            hx_post=f"/expand/{doc_id}?expand=true",
            hx_target=f"#{doc_id}",
            hx_include=f"#{doc_id}-full",
            cls="outline secondary",
            style="width: 100%;",
        )
        return Article(
            Header(H2(title_link)),
            Div(
                preview,
                Div(show_more, style="text-align: center;"),
                id=doc_id,
            ),
            # The full body travels with the page so /expand can read it back.
            Hidden(full_body, id=f"{doc_id}-full"),
        )

    return [render(record) for record in records]
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
@app.post("/expand/{docid}")
async def expand(request: Request, docid: str, expand: bool):
    """Re-render one result body in expanded or collapsed form.

    Args:
        request: The incoming request; its form data is read.
        docid: Id of the result; the full body is read from the form field
            named ``{docid}-full``.
        expand: True to show the whole body, False to show the first 300
            characters followed by "...".

    Returns:
        A one-element tuple holding the replacement ``Div`` for ``#{docid}``.
    """
    print(f"Expanding {docid}")
    form_data = await request.form()
    # A request without the hidden field yields None; fall back to an empty
    # body instead of failing on the slice below.
    result = form_data.get(f"{docid}-full") or ""
    if not expand:
        result = result[:300] + "..."
    # The button always offers the opposite of the state being rendered.
    toggle_label = "Show less" if expand else "Show more"
    next_state = "false" if expand else "true"
    return (
        Div(
            P(*replace_hi_with_strong(result)),
            Div(
                Button(
                    toggle_label,
                    hx_post=f"/expand/{docid}?expand={next_state}",
                    hx_target=f"#{docid}",
                    hx_include=f"#{docid}-full",
                    cls="outline secondary",
                    # Style to fill whole width of parent div
                    style="width: 100%;",
                ),
                style="text-align: center;",
            ),
            id=docid,
        ),
    )
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
# Returns tuple of (yql, body(dict)) based on the ranking profile
|
| 567 |
+
def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
    """Build the YQL statement and extra request body for a ranking profile.

    Args:
        ranking: A ``RankProfile`` member or its plain string value.
        userquery: The user's search text, embedded for the vector profiles.

    Returns:
        A ``(yql, body)`` tuple; ``body`` is empty for BM25 and carries the
        ``input.query(q)`` embed expression for the other profiles.

    Raises:
        ValueError: If ``ranking`` is not a known rank profile.
    """
    # Normalising up front rejects unknown values with a clear ValueError
    # instead of falling through every branch and failing on an unbound name.
    profile = RankProfile(ranking)
    if profile is RankProfile.bm25:
        return "select * from sources * where userQuery() limit 10", {}

    # NOTE(review): userquery is interpolated verbatim into embed(...) -
    # confirm that special characters in the query are handled by Vespa.
    body = {"input.query(q)": f"embed({userquery})"}
    if profile is RankProfile.semantic:
        yql = "select * from sources * where ({targetHits:10}nearestNeighbor(embedding,q)) limit 10"
    else:
        yql = "select * from sources * where rank({targetHits:1000}nearestNeighbor(embedding,q), userQuery()) limit 10"
    return yql, body
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@app.get("/search")
async def search(userquery: str, ranking: str, sess):
    """Log the query, run it against Vespa and render the result list.

    Args:
        userquery: The search text.
        ranking: Name of the ranking profile to use.
        sess: The session mapping, passed on to the query log.

    Returns:
        A ``Div`` with the spinner, a collapsible dump of the raw JSON
        response, and one article per hit.
    """
    print(sess)
    log_query_to_db(userquery, ranking, sess)
    yql, body = get_yql(ranking, userquery)
    async with vespa_app.asyncio() as session:
        resp = await session.query(
            yql=yql,
            query=userquery,
            hits=10,
            ranking=str(ranking),
            body=body,
        )
    # Keep only the fields the result cards need.
    fields = ["id", "title", "body"]
    records = [{field: hit["fields"][field] for field in fields} for hit in resp.hits]
    results = parse_results(records)
    json_dump = json.dumps(resp.get_json(), indent=4)
    return Div(
        spinner_div(),
        # Accordion (with Details)
        Details(
            Summary("Full JSON response"),
            Div(
                f"""```json\n{json_dump}\n```""",
                cls="marked",
            ),
        ),
        H2(
            "Search Results",
        ),
        Div(
            *results,
            id="all-searchresults",
        ),
    )
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
@app.get("/download_csv")
def download_csv(auth):
    """Export all logged queries as a ``queries.csv`` attachment.

    Args:
        auth: Auth value from the request scope; not read here.

    Returns:
        A ``FileResponse`` streaming a temporary CSV file that is deleted
        once the response has been sent.
    """
    # Local import: starlette is already a dependency of this module.
    from starlette.background import BackgroundTask

    rows = db.query("SELECT * FROM queries")

    # Write straight to the temporary file; newline="" lets the csv module
    # control line endings itself.
    with tempfile.NamedTemporaryFile(
        "w", newline="", encoding="utf-8", suffix=".csv", delete=False
    ) as temp_file:
        csv_writer = csv.writer(temp_file)
        csv_writer.writerow(["Query", "Session ID", "Timestamp"])
        csv_writer.writerows(
            (row["query"], row["sess_id"], row["timestamp"]) for row in rows
        )

    return FileResponse(
        temp_file.name,
        filename="queries.csv",
        media_type="text/csv",
        content_disposition_type="attachment",
        # delete=False keeps the file alive for the response; remove it
        # afterwards so exports do not pile up in the temp directory.
        background=BackgroundTask(os.remove, temp_file.name),
    )
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def _page_link(label, target_page, enabled=True, current=False):
    """Return one pagination anchor; disabled links are greyed out and unclickable."""
    if current:
        style = "margin: 5px; font-weight: bold;"
    elif enabled:
        style = "margin: 5px;"
    else:
        style = "margin: 5px; color: grey; pointer-events: none;"
    return A(label, href=f"/admin?page={target_page}", style=style)


@app.get("/admin")
def get_admin(auth, page: int = 1):
    """Render the admin page: a paginated table of logged queries, newest first.

    Args:
        auth: Auth value from the request scope, passed on to the navbar.
        page: 1-based page number; out-of-range values are clamped to the
            nearest valid page.

    Returns:
        A tuple of the page title, the navbar and the main content.
    """
    limit = 15
    total_queries = list(db.query("SELECT COUNT(*) AS count FROM queries"))[0]["count"]
    # At least one page, so an empty table reads "Page 1 of 1" and not "1 of 0".
    total_pages = max(1, (total_queries + limit - 1) // limit)
    # Clamp so a hand-edited ?page= can produce neither a negative offset nor
    # a page past the end.
    page = min(max(page, 1), total_pages)
    offset = (page - 1) * limit
    # The ordering belongs on the listing; on a COUNT(*) query it has no effect.
    # limit and offset are ints computed above, so interpolating them is safe.
    rows = db.query(
        f"SELECT * FROM queries ORDER BY timestamp DESC LIMIT {limit} OFFSET {offset}"
    )
    page_queries = [Query(**row) for row in rows]

    # Show at most `page_window` page numbers, anchored on the last visible page.
    page_window = 5
    end_page = min(total_pages, max(1, page - page_window // 2) + page_window - 1)
    start_page = max(1, end_page - page_window + 1)

    has_previous = page > 1
    has_next = page < total_pages

    # Pagination controls with "First", "Previous", "Next", and "Last"
    pagination_controls = Div(
        _page_link("First", 1, enabled=has_previous),
        _page_link("Previous", page - 1, enabled=has_previous),
        *[
            _page_link(f"{i}", i, current=(i == page))
            for i in range(start_page, end_page + 1)
        ],
        _page_link("Next", page + 1, enabled=has_next),
        _page_link("Last", total_pages, enabled=has_next),
        style="text-align: center; margin: 20px;",
    )

    # Total pages indication
    total_pages_indicator = Div(
        f"Page {page} of {total_pages}",
        style="text-align: center; margin: 10px;",
    )

    return (
        Title("Admin"),
        get_navbar(auth),
        Main(
            Div(
                A(
                    I(cls="fa fa-arrow-left"),
                    "Back",
                    href="/",
                    title="Back to main page",
                    style="margin: 10px;",
                ),
                style="margin: 10px;",
            ),
            H2("Queries"),
            # Table of the queries on the current page
            Table(
                Thead(
                    Tr(
                        Th("Query"),
                        Th("Session ID"),
                        Th("Datetime"),
                    )
                ),
                Tbody(
                    *[
                        Tr(
                            Td(entry.query),
                            Td(entry.sess_id),
                            Td(entry.get_datetime()),
                        )
                        for entry in page_queries
                    ],
                ),
                cls="striped",
            ),
            total_pages_indicator,
            pagination_controls,
            Div(
                A(
                    I(cls="fa fa-download fa-2x"),
                    " Download CSV",
                    href="/download_csv",
                    style="margin: 10px; float: right;",
                    title="Download queries as CSV",
                ),
                style="text-align: right; margin: 20px;",
            ),
            style="width: 80%; margin: 40px auto;",
        ),
    )
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
@app.get("/source")
def get_source(auth, sess):
    """Render the page that shows the contents of ``main.py`` read at startup.

    Args:
        auth: Auth value from the request scope, passed on to the navbar.
        sess: The session mapping; not read here.

    Returns:
        A tuple of the page title, the navbar and the main content.
    """
    # Back icon to go back to main page in top left corner
    back_link = A(
        I(cls="fa fa-arrow-left"),
        "Back",
        href="/",
        title="Back to main page",
        style="margin: 10px;",
    )
    # Markdown text with the source in a python code fence; the "marked" class
    # hands it to the MarkdownJS header for rendering.
    source_view = Div(
        f"""### `main.py`\n### This is the complete source code for this app \n```python\n{SOURCE}\n```""",
        cls="marked",
        style="margin: 10px;",
    )
    return (
        Title("Source code"),
        get_navbar(auth),
        Main(
            Div(
                back_link,
                source_view,
                style="width: 80%; margin: 40px auto;",
            ),
        ),
    )
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
@app.get("/about")
def get_about(auth, sess):
    """Serve the "About" page built from the project README.

    Everything in ``README`` that precedes the "# FastHTML Vespa frontend"
    heading is removed; if that heading is absent the regex does not match
    and the README is used unchanged. The remaining text is placed in a
    ``marked`` container below a "Back" link. ``sess`` is accepted but not
    read here.
    """
    # Lazy match from the start of the text up to (not including) the heading.
    leading_part = r"^.*?(?=# FastHTML Vespa frontend)"
    stripped_readme = re.sub(leading_part, "", README, flags=re.DOTALL)

    page_title = Title("About this app")
    navbar = get_navbar(auth)

    # Link back to the main page, placed above the README text.
    back_link = A(
        I(cls="fa fa-arrow-left"),
        "Back",
        href="/",
        title="Back to main page",
        style="margin: 10px;",
    )
    readme_box = Div(
        stripped_readme,
        cls="marked",
        style="margin: 10px;",
    )
    container = Div(
        back_link,
        readme_box,
        style="width: 80%; margin: 40px auto;",
    )
    return (page_title, navbar, Main(container))
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
@app.get("/document/{docid}")
def get_document(docid: str, sess):
    """Render one document fetched from Vespa, with a link back to the search.

    Args:
        docid: Id of the document, looked up in schema ``doc`` under
            namespace ``tutorial``.
        sess: Session mapping. Its ``queries`` list, when present, is read to
            rebuild the URL of the most recent search.

    Returns:
        A ``Main`` element containing the back link, the document title and
        the document body.
    """
    # Function-scope import keeps this handler independent of the module's
    # import block.
    from urllib.parse import quote, urlencode

    resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
    doc = resp.json

    # Most recent search recorded in the session. ``or [{}]`` covers both a
    # missing key and an empty list; indexing an empty list with [-1] would
    # raise IndexError.
    history = sess.get("queries") or [{}]
    last_search = history[-1]

    # Percent-encode the values: a raw query containing "&", "#" or "+" would
    # otherwise corrupt the URL. The ranking is forwarded when the session
    # entry carries one, since the search route takes it as a parameter.
    search_params = {"userquery": last_search.get("query", "")}
    ranking = last_search.get("ranking")
    if ranking:
        search_params["ranking"] = ranking
    back_url = "/search?" + urlencode(search_params, quote_via=quote)

    return Main(
        Div(
            A(
                I(cls="fa fa-arrow-left"),
                "Back to search results",
                hx_get=back_url,
                hx_target="#results",
                style="margin: 10px;",
            ),
            H2(doc["fields"]["title"], style="margin: 10px;"),
            P(doc["fields"]["body"], cls="marked"),
        ),
    )
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
# Outside dev mode, try to set up the hf backup for the app. Any failure is
# printed and otherwise ignored, so it does not abort module start-up.
if not DEV_MODE:
    try:
        setup_hf_backup(app)
    except Exception as backup_error:
        print(f"Error setting up hf backup: {backup_error}")
|
| 861 |
+
serve()
|
main.py
CHANGED
|
@@ -57,7 +57,6 @@ import tempfile
|
|
| 57 |
from enum import Enum
|
| 58 |
from typing import Tuple as T
|
| 59 |
from urllib.parse import quote
|
| 60 |
-
import uuid
|
| 61 |
|
| 62 |
DEV_MODE = False
|
| 63 |
|
|
@@ -165,32 +164,14 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
|
|
| 165 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
| 166 |
return response
|
| 167 |
|
| 168 |
-
class SessionLoggingMiddleware(BaseHTTPMiddleware):
|
| 169 |
-
async def dispatch(self, request, call_next):
|
| 170 |
-
print(f"Before request: Session data: {request.session}")
|
| 171 |
-
response = await call_next(request)
|
| 172 |
-
print(f"After request: Session data: {request.session}")
|
| 173 |
-
return response
|
| 174 |
-
|
| 175 |
-
class DebugSessionMiddleware(SessionMiddleware):
|
| 176 |
-
async def __call__(self, scope, receive, send):
|
| 177 |
-
print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
|
| 178 |
-
await super().__call__(scope, receive, send)
|
| 179 |
-
print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
| 180 |
-
|
| 181 |
-
from starlette.middleware.cors import CORSMiddleware
|
| 182 |
|
| 183 |
middlewares = [
|
| 184 |
Middleware(
|
| 185 |
SessionMiddleware,
|
| 186 |
secret_key=get_key(fname=sess_key_path),
|
| 187 |
max_age=3600,
|
| 188 |
-
#same_site='lax',
|
| 189 |
),
|
| 190 |
-
Middleware(CORSMiddleware, allow_origins=['*']),
|
| 191 |
Middleware(XFrameOptionsMiddleware),
|
| 192 |
-
Middleware(SessionLoggingMiddleware),
|
| 193 |
-
#Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
|
| 194 |
]
|
| 195 |
bware = Beforeware(
|
| 196 |
user_auth_before,
|
|
@@ -314,6 +295,7 @@ def get(sess):
|
|
| 314 |
queries = [
|
| 315 |
"Breast Cancer Cells Feed on Cholesterol",
|
| 316 |
"Treating Asthma With Plants vs. Pills",
|
|
|
|
| 317 |
"Testing Turmeric on Smokers",
|
| 318 |
"The Role of Pesticides in Parkinson's Disease",
|
| 319 |
]
|
|
@@ -442,10 +424,9 @@ def post(login: Login, sess):
|
|
| 442 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
| 443 |
# Incorrect password - add error message
|
| 444 |
return RedirectResponse("/login?error=True", status_code=303)
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
return response
|
| 449 |
|
| 450 |
|
| 451 |
@app.get("/logout")
|
|
@@ -471,26 +452,9 @@ def replace_hi_with_strong(text):
|
|
| 471 |
|
| 472 |
|
| 473 |
def log_query_to_db(query, ranking, sess):
|
| 474 |
-
queries.insert(
|
| 475 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
| 476 |
)
|
| 477 |
-
if 'user_id' not in sess:
|
| 478 |
-
sess['user_id'] = str(uuid.uuid4())
|
| 479 |
-
|
| 480 |
-
if 'queries' not in sess:
|
| 481 |
-
sess['queries'] = []
|
| 482 |
-
|
| 483 |
-
query_data = {
|
| 484 |
-
'query': query,
|
| 485 |
-
'ranking': ranking,
|
| 486 |
-
'timestamp': int(time.time())
|
| 487 |
-
}
|
| 488 |
-
sess['queries'].append(query_data)
|
| 489 |
-
|
| 490 |
-
# Limit the number of queries stored in the session to prevent it from growing too large
|
| 491 |
-
sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
|
| 492 |
-
|
| 493 |
-
return query_data
|
| 494 |
|
| 495 |
|
| 496 |
def parse_results(records):
|
|
@@ -580,7 +544,12 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
|
|
| 580 |
@app.get("/search")
|
| 581 |
async def search(userquery: str, ranking: str, sess):
|
| 582 |
print(sess)
|
|
|
|
|
|
|
| 583 |
quoted = quote(userquery) + "&ranking=" + ranking
|
|
|
|
|
|
|
|
|
|
| 584 |
log_query_to_db(userquery, ranking, sess)
|
| 585 |
yql, body = get_yql(ranking, userquery)
|
| 586 |
async with vespa_app.asyncio() as session:
|
|
@@ -837,13 +806,12 @@ def get_document(docid: str, sess):
|
|
| 837 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
| 838 |
doc = resp.json
|
| 839 |
# Link with Back to search results at top of page
|
| 840 |
-
last_query = sess.get('queries', [{}])[-1].get('query', '')
|
| 841 |
return Main(
|
| 842 |
Div(
|
| 843 |
A(
|
| 844 |
I(cls="fa fa-arrow-left"),
|
| 845 |
"Back to search results",
|
| 846 |
-
hx_get=f"/search?userquery={
|
| 847 |
hx_target="#results",
|
| 848 |
style="margin: 10px;",
|
| 849 |
),
|
|
|
|
| 57 |
from enum import Enum
|
| 58 |
from typing import Tuple as T
|
| 59 |
from urllib.parse import quote
|
|
|
|
| 60 |
|
| 61 |
DEV_MODE = False
|
| 62 |
|
|
|
|
| 164 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
| 165 |
return response
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
middlewares = [
|
| 169 |
Middleware(
|
| 170 |
SessionMiddleware,
|
| 171 |
secret_key=get_key(fname=sess_key_path),
|
| 172 |
max_age=3600,
|
|
|
|
| 173 |
),
|
|
|
|
| 174 |
Middleware(XFrameOptionsMiddleware),
|
|
|
|
|
|
|
| 175 |
]
|
| 176 |
bware = Beforeware(
|
| 177 |
user_auth_before,
|
|
|
|
| 295 |
queries = [
|
| 296 |
"Breast Cancer Cells Feed on Cholesterol",
|
| 297 |
"Treating Asthma With Plants vs. Pills",
|
| 298 |
+
"Alkylphenol Endocrine Disruptors",
|
| 299 |
"Testing Turmeric on Smokers",
|
| 300 |
"The Role of Pesticides in Parkinson's Disease",
|
| 301 |
]
|
|
|
|
| 424 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
| 425 |
# Incorrect password - add error message
|
| 426 |
return RedirectResponse("/login?error=True", status_code=303)
|
| 427 |
+
sess["auth"] = True
|
| 428 |
+
print(f"Sess after login: {sess}")
|
| 429 |
+
return RedirectResponse("/admin", status_code=303)
|
|
|
|
| 430 |
|
| 431 |
|
| 432 |
@app.get("/logout")
|
|
|
|
| 452 |
|
| 453 |
|
| 454 |
def log_query_to_db(query, ranking, sess):
|
| 455 |
+
return queries.insert(
|
| 456 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
| 457 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
|
| 460 |
def parse_results(records):
|
|
|
|
| 544 |
@app.get("/search")
|
| 545 |
async def search(userquery: str, ranking: str, sess):
|
| 546 |
print(sess)
|
| 547 |
+
if "queries" not in sess:
|
| 548 |
+
sess["queries"] = []
|
| 549 |
quoted = quote(userquery) + "&ranking=" + ranking
|
| 550 |
+
sess["queries"].append(quoted)
|
| 551 |
+
print(f"Searching for: {userquery}")
|
| 552 |
+
print(f"Ranking: {ranking}")
|
| 553 |
log_query_to_db(userquery, ranking, sess)
|
| 554 |
yql, body = get_yql(ranking, userquery)
|
| 555 |
async with vespa_app.asyncio() as session:
|
|
|
|
| 806 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
| 807 |
doc = resp.json
|
| 808 |
# Link with Back to search results at top of page
|
|
|
|
| 809 |
return Main(
|
| 810 |
Div(
|
| 811 |
A(
|
| 812 |
I(cls="fa fa-arrow-left"),
|
| 813 |
"Back to search results",
|
| 814 |
+
hx_get=f"/search?userquery={sess['queries'][-1]}",
|
| 815 |
hx_target="#results",
|
| 816 |
style="margin: 10px;",
|
| 817 |
),
|