Abhipsha Das committed: initial spaces deploy

Files changed:
- README.md +12 -7
- app.py +14 -0
- config.py +445 -0
- requirements.txt +8 -0
- scripts/__init__.py +6 -0
- scripts/__pycache__/__init__.cpython-311.pyc +0 -0
- scripts/__pycache__/create_db.cpython-311.pyc +0 -0
- scripts/__pycache__/run_db_interface.cpython-311.pyc +0 -0
- scripts/__pycache__/run_db_interface_improved.cpython-311.pyc +0 -0
- scripts/create_db.py +246 -0
- scripts/run_db_interface.py +704 -0
- scripts/run_db_interface_basic.py +361 -0
- scripts/run_db_interface_js.py +0 -0
README.md
CHANGED

@@ -1,14 +1,19 @@
 ---
-title: Surveyor
-emoji:
-colorFrom:
+title: Surveyor
+emoji: π
+colorFrom: blue
 colorTo: green
 sdk: gradio
-sdk_version:
+sdk_version: 4.40.0
 app_file: app.py
 pinned: false
-license: openrail
-short_description: Interface for exploring scientific concepts with KGs
 ---
 
-
+# Surveyor
+
+An interactive interface for querying and visualizing scientific paper databases with concept co-occurrence graphs.
+
+## Features
+- Interactive concept co-occurrence graphs
+- SQL query interface with pre-built queries
+- Support for multiple scientific domains
+- Graph filtering and highlighting
app.py
ADDED

@@ -0,0 +1,14 @@

import os
import sys

# Add the project root directory to Python path
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from scripts.run_db_interface import create_demo

demo = create_demo()

if __name__ == "__main__":
    demo.launch()
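This entrypoint matches the `app_file: app.py` setting in the README frontmatter, so Spaces launches it automatically. Locally, `pip install -r requirements.txt` followed by `python app.py` should start the same Gradio server, assuming a populated database is present under `data/databases/`.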
config.py
ADDED

@@ -0,0 +1,445 @@

DEFAULT_MODEL_ID = "Meta-Llama-3-70B-Instruct"
DEFAULT_INTERFACE_MODEL_ID = "NumbersStation/nsql-llama-2-7B"
DEFAULT_KIND = "json"
DEFAULT_TEMPERATURE = 0.6
DEFAULT_TOP_P = 0.95
DEFAULT_FEW_SHOT_NUM = 3
DEFAULT_FEW_SHOT_SELECTION = "random"
DEFAULT_SAVE_INTERVAL = 3
DEFAULT_RES_DIR = "data/results"
DEFAULT_LOG_DIR = "logs"
DEFAULT_TABLES_DIR = "data/databases"

COOCCURRENCE_QUERY = """
WITH concept_pairs AS (
    SELECT p1.concept AS concept1, p2.concept AS concept2, p1.paper_id, p1.tag_type
    FROM predictions p1
    JOIN predictions p2 ON p1.paper_id = p2.paper_id AND p1.concept < p2.concept
    WHERE p1.tag_type = p2.tag_type
)
SELECT concept1, concept2, tag_type, COUNT(DISTINCT paper_id) AS co_occurrences
FROM concept_pairs
GROUP BY concept1, concept2, tag_type
HAVING co_occurrences > 5
ORDER BY co_occurrences DESC;
"""

canned_queries = [
    (
        "Modalities in Physics and Astronomy papers",
        """
        SELECT DISTINCT LOWER(concept) AS concept
        FROM predictions
        JOIN (
            SELECT paper_id, url
            FROM papers
            WHERE primary_category LIKE '%physics.space-ph%'
               OR primary_category LIKE '%astro-ph.%'
        ) AS paper_ids
        ON predictions.paper_id = paper_ids.paper_id
        WHERE predictions.tag_type = 'modality'
        """,
    ),
    (
        "Datasets in Evolutionary Biology that use PDEs",
        """
        WITH pde_predictions AS (
            SELECT paper_id, concept AS pde_concept, tag_type AS pde_tag_type
            FROM predictions
            WHERE tag_type IN ('method', 'model')
              AND (
                LOWER(concept) LIKE '%pde%'
                OR LOWER(concept) LIKE '%partial differential equation%'
              )
        )
        SELECT DISTINCT
            papers.paper_id,
            papers.url,
            LOWER(p_dataset.concept) AS dataset,
            pde_predictions.pde_concept AS pde_related_concept,
            pde_predictions.pde_tag_type AS pde_related_type
        FROM papers
        JOIN pde_predictions ON papers.paper_id = pde_predictions.paper_id
        LEFT JOIN predictions p_dataset ON papers.paper_id = p_dataset.paper_id
        WHERE papers.primary_category LIKE '%q-bio.PE%'
          AND (p_dataset.tag_type = 'dataset' OR p_dataset.tag_type IS NULL)
        ORDER BY papers.paper_id, dataset, pde_related_concept;
        """,
    ),
    (
        "Trends in objects of study in Cosmology since 2019",
        """
        SELECT
            substr(papers.updated_on, 2, 4) as year,
            predictions.concept as object,
            COUNT(DISTINCT papers.paper_id) as paper_count
        FROM
            papers
        JOIN
            predictions ON papers.paper_id = predictions.paper_id
        WHERE
            predictions.tag_type = 'object'
            AND CAST(SUBSTR(papers.updated_on, 2, 4) AS INTEGER) >= 2019
        GROUP BY
            year, object
        ORDER BY
            year DESC, paper_count DESC;
        """,
    ),
    (
        "New datasets in fluid dynamics since 2020",
        """
        WITH ranked_datasets AS (
            SELECT
                p.paper_id,
                p.url,
                pred.concept AS dataset,
                p.updated_on,
                ROW_NUMBER() OVER (PARTITION BY pred.concept ORDER BY p.updated_on ASC) AS rn
            FROM
                papers p
            JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                pred.tag_type = 'dataset'
                AND p.primary_category LIKE '%physics.flu-dyn%'
                AND CAST(SUBSTR(p.updated_on, 2, 4) AS INTEGER) >= 2020
        )
        SELECT
            paper_id,
            url,
            dataset,
            updated_on
        FROM
            ranked_datasets
        WHERE
            rn = 1
        ORDER BY
            updated_on ASC
        """,
    ),
    (
        "Evolutionary biology datasets that use spatiotemporal dynamics",
        """
        WITH evo_bio_papers AS (
            SELECT paper_id
            FROM papers
            WHERE primary_category LIKE '%q-bio.PE%'
        ),
        spatiotemporal_keywords AS (
            SELECT 'spatio-temporal' AS keyword
            UNION SELECT 'spatiotemporal'
            UNION SELECT 'spatio-temporal'
            UNION SELECT 'spatial and temporal'
            UNION SELECT 'space-time'
            UNION SELECT 'geographic distribution'
            UNION SELECT 'phylogeograph'
            UNION SELECT 'biogeograph'
            UNION SELECT 'dispersal'
            UNION SELECT 'migration'
            UNION SELECT 'range expansion'
            UNION SELECT 'population dynamics'
        )
        SELECT DISTINCT
            p.paper_id,
            p.updated_on,
            p.abstract,
            d.concept AS dataset,
            GROUP_CONCAT(DISTINCT stk.keyword) AS spatiotemporal_keywords_found
        FROM
            evo_bio_papers ebp
        JOIN
            papers p ON ebp.paper_id = p.paper_id
        JOIN
            predictions d ON p.paper_id = d.paper_id
        JOIN
            predictions st ON p.paper_id = st.paper_id
        JOIN
            spatiotemporal_keywords stk
        WHERE
            d.tag_type = 'dataset'
            AND st.tag_type = 'modality'
            AND LOWER(st.concept) LIKE '%' || stk.keyword || '%'
        GROUP BY
            p.paper_id, p.updated_on, p.abstract, d.concept
        ORDER BY
            p.updated_on DESC
        """,
    ),
    (
        "What percentage of papers use only galaxy or spectra, or both or neither?",
        """
        WITH paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        categorized_papers AS (
            SELECT
                CASE
                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
                    WHEN uses_spectra = 1 THEN 'Only Spectra'
                    ELSE 'Neither'
                END AS category,
                COUNT(*) AS paper_count
            FROM
                paper_modalities
            GROUP BY
                CASE
                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
                    WHEN uses_spectra = 1 THEN 'Only Spectra'
                    ELSE 'Neither'
                END
        )
        SELECT
            category,
            paper_count,
            ROUND(CAST(paper_count AS FLOAT) / (SELECT SUM(paper_count) FROM categorized_papers) * 100, 2) AS percentage
        FROM
            categorized_papers
        ORDER BY
            paper_count DESC
        """,
    ),
    (
        "What are all the next highest data modalities after images and spectra?",
        """
        SELECT
            LOWER(concept) AS modality,
            COUNT(DISTINCT paper_id) AS usage_count
        FROM
            predictions
        WHERE
            tag_type = 'modality'
            AND LOWER(concept) NOT LIKE '%imag%'
            AND LOWER(concept) NOT LIKE '%spectr%'
        GROUP BY
            LOWER(concept)
        ORDER BY
            usage_count DESC
        """,
    ),
    (
        "If we include the next biggest data modality, how much does coverage change?",
        """
        WITH modality_counts AS (
            SELECT
                LOWER(concept) AS modality,
                COUNT(DISTINCT paper_id) AS usage_count
            FROM
                predictions
            WHERE
                tag_type = 'modality'
                AND LOWER(concept) NOT LIKE '%imag%'
                AND LOWER(concept) NOT LIKE '%spectr%'
            GROUP BY
                LOWER(concept)
            ORDER BY
                usage_count DESC
            LIMIT 1
        ),
        paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
                MAX(CASE WHEN LOWER(pred.concept) LIKE (SELECT '%' || modality || '%' FROM modality_counts) THEN 1 ELSE 0 END) AS uses_third_modality
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        coverage_before AS (
            SELECT
                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        ),
        coverage_after AS (
            SELECT
                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 OR uses_third_modality = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        )
        SELECT
            (SELECT modality FROM modality_counts) AS third_modality,
            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_before_percent,
            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) AS coverage_after_percent,
            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) -
            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_increase_percent
        FROM
            coverage_before
        """,
    ),
    (
        "Coverage if we select the next 5 highest modalities?",
        """
        WITH ranked_modalities AS (
            SELECT
                LOWER(concept) AS modality,
                COUNT(DISTINCT paper_id) AS usage_count,
                ROW_NUMBER() OVER (ORDER BY COUNT(DISTINCT paper_id) DESC) AS rank
            FROM
                predictions
            WHERE
                tag_type = 'modality'
                AND LOWER(concept) NOT LIKE '%imag%'
                AND LOWER(concept) NOT LIKE '%spectr%'
            GROUP BY
                LOWER(concept)
        ),
        paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
                MAX(CASE WHEN rm.rank = 1 THEN 1 ELSE 0 END) AS uses_modality_1,
                MAX(CASE WHEN rm.rank = 2 THEN 1 ELSE 0 END) AS uses_modality_2,
                MAX(CASE WHEN rm.rank = 3 THEN 1 ELSE 0 END) AS uses_modality_3,
                MAX(CASE WHEN rm.rank = 4 THEN 1 ELSE 0 END) AS uses_modality_4,
                MAX(CASE WHEN rm.rank = 5 THEN 1 ELSE 0 END) AS uses_modality_5
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            LEFT JOIN
                ranked_modalities rm ON LOWER(pred.concept) = rm.modality
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        cumulative_coverage AS (
            SELECT
                'Images and Spectra' AS modalities,
                0 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, and Modality 1' AS modalities,
                1 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, and 2' AS modalities,
                2 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, and 3' AS modalities,
                3 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, 3, and 4' AS modalities,
                4 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, 3, 4, and 5' AS modalities,
                5 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 OR uses_modality_5 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        )
        SELECT
            cc.modalities,
            COALESCE(rm.modality, 'N/A') AS added_modality,
            rm.usage_count AS added_modality_usage,
            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) AS coverage_percent,
            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) -
            LAG(ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2), 1, 0) OVER (ORDER BY cc.added_modality_rank) AS coverage_increase_percent
        FROM
            cumulative_coverage cc
        LEFT JOIN
            ranked_modalities rm ON cc.added_modality_rank = rm.rank
        ORDER BY
            cc.added_modality_rank
        """,
    ),
    (
        "List all papers",
        "SELECT paper_id, abstract AS abstract_preview, authors, primary_category FROM papers",
    ),
    (
        "Count papers by category",
        "SELECT primary_category, COUNT(*) as paper_count FROM papers GROUP BY primary_category ORDER BY paper_count DESC",
    ),
    (
        "Top authors with most papers",
        """
        WITH author_papers AS (
            SELECT json_each.value AS author
            FROM papers, json_each(papers.authors)
        )
        SELECT author, COUNT(*) as paper_count
        FROM author_papers
        GROUP BY author
        ORDER BY paper_count DESC
        """,
    ),
    (
        "Papers with 'quantum' in abstract",
        "SELECT paper_id, abstract AS abstract_preview FROM papers WHERE abstract LIKE '%quantum%'",
    ),
    (
        "Most common concepts",
        "SELECT concept, COUNT(*) as concept_count FROM predictions GROUP BY concept ORDER BY concept_count DESC",
    ),
    (
        "Papers with multiple authors",
        """
        SELECT paper_id, json_array_length(authors) as author_count, authors
        FROM papers
        WHERE json_array_length(authors) > 1
        ORDER BY author_count DESC
        """,
    ),
]
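As a quick sanity check of this module, a minimal sketch below runs `COOCCURRENCE_QUERY` against a local database with pandas, the same way `scripts/run_db_interface.py` does. The database path is an assumption; use whatever `scripts/create_db.py` produced.

    # Minimal sketch: execute COOCCURRENCE_QUERY against a local SQLite file.
    # "data/databases/arxiv.db" is a hypothetical path, not fixed by this commit.
    import sqlite3

    import pandas as pd

    from config import COOCCURRENCE_QUERY

    conn = sqlite3.connect("data/databases/arxiv.db")
    # Columns: concept1, concept2, tag_type, co_occurrences
    df = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
    conn.close()
    print(df.head())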
requirements.txt
ADDED

@@ -0,0 +1,8 @@

gradio==4.40.0
networkx==3.3
pandas==2.2.2
plotly==5.23.0
tabulate==0.9.0
fastapi==0.104.1
pydantic==2.5.3
uvicorn==0.27.1
scripts/__init__.py
ADDED

@@ -0,0 +1,6 @@

import os
import sys

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
scripts/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (566 Bytes).

scripts/__pycache__/create_db.cpython-311.pyc
ADDED
Binary file (14.6 kB).

scripts/__pycache__/run_db_interface.cpython-311.pyc
ADDED
Binary file (29 kB).

scripts/__pycache__/run_db_interface_improved.cpython-311.pyc
ADDED
Binary file (29.2 kB).
scripts/create_db.py
ADDED

@@ -0,0 +1,246 @@

import click
import json
import os
import sqlite3
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
from src.processing.generate import get_sentences, generate_prediction
from src.utils.utils import load_model_and_tokenizer


class ArxivDatabase:
    def __init__(self, db_path, model_id=None):
        self.conn = None
        self.cursor = None
        self.db_path = db_path
        self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
        self.model = None
        self.tokenizer = None
        self.is_db_empty = True
        self.paper_table = """CREATE TABLE IF NOT EXISTS papers
            (paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT,
            primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
        self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
            (id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER,
            tag_type TEXT, concept TEXT,
            FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""

    # def init_db(self):
    #     self.cursor.execute(self.paper_table)
    #     self.cursor.execute(self.pred_table)

    #     print("Database and tables created successfully.")
    #     self.is_db_empty = self.is_empty()

    def init_db(self):
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(self.paper_table)
        self.cursor.execute(self.pred_table)
        self.conn.commit()
        self.is_db_empty = self.is_empty()
        if not self.is_db_empty:
            print("Database already contains data.")
        else:
            print("Database and tables created successfully.")

    def is_empty(self):
        try:
            self.cursor.execute("SELECT COUNT(*) FROM papers")
            count = self.cursor.fetchone()[0]
            return count == 0
        except sqlite3.OperationalError:
            return True

    def get_connection(self):
        # Open a fresh connection to the same database file.
        return sqlite3.connect(self.db_path)

    def populate_db(self, data_path, pred_path):
        papers_info = self._insert_papers(data_path)
        self._insert_predictions(pred_path, papers_info)
        print("Database population completed.")

    def _insert_papers(self, data_path):
        papers_info = []
        seen_papers = set()
        with open(data_path, "r") as f:
            for line in f:
                paper = json.loads(line)
                if paper["id"] in seen_papers:
                    continue
                seen_papers.add(paper["id"])
                sentence_count = len(get_sentences(paper["id"])) + len(
                    get_sentences(paper["abstract"])
                )
                papers_info.append((paper["id"], sentence_count))
                self.cursor.execute(
                    """INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (
                        paper["id"],
                        paper["abstract"],
                        json.dumps(paper["authors"]),
                        json.dumps(paper["primary_category"]),
                        json.dumps(paper["url"]),
                        json.dumps(paper["updated"]),
                        sentence_count,
                    ),
                )
        print(f"Inserted {len(papers_info)} papers.")
        return papers_info

    def _insert_predictions(self, pred_path, papers_info):
        with open(pred_path, "r") as f:
            predictions = json.load(f)
        predicted_tags = predictions["predicted_tags"]

        k = 0
        papers_with_predictions = set()
        papers_without_predictions = []
        for paper_id, sentence_count in papers_info:
            paper_predictions = predicted_tags[k : k + sentence_count]

            has_predictions = False
            for sentence_index, pred in enumerate(paper_predictions):
                if pred:  # If the prediction is not an empty dictionary
                    has_predictions = True
                    for tag_type, concepts in pred.items():
                        for concept in concepts:
                            self.cursor.execute(
                                """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                                VALUES (?, ?, ?, ?)""",
                                (paper_id, sentence_index, tag_type, concept),
                            )
                else:
                    # Insert a null prediction to ensure the paper is counted
                    self.cursor.execute(
                        """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                        VALUES (?, ?, ?, ?)""",
                        (paper_id, sentence_index, "null", "null"),
                    )

            if has_predictions:
                papers_with_predictions.add(paper_id)
            else:
                papers_without_predictions.append(paper_id)

            k += sentence_count

        print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
        print(f"Papers without any predictions: {len(papers_without_predictions)}")

        if k < len(predicted_tags):
            print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")

    def load_model(self):
        if self.model is None:
            try:
                self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
                return f"Model {self.model_id} loaded successfully."
            except Exception as e:
                return f"Error loading model: {str(e)}"
        else:
            return "Model is already loaded."

    def natural_language_to_sql(self, question):
        system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
        table = self.paper_table + "; " + self.pred_table
        prefix = (
            f"[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using "
            f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer to the question: {question}: ``` "
        )

        sql_query = generate_prediction(
            self.model, self.tokenizer, prefix, question, "sql", system_prompt
        )

        sql_query = sql_query.split("```")[1]

        return sql_query

    def execute_query(self, sql_query):
        try:
            self.cursor.execute(sql_query)
            results = self.cursor.fetchall()
            return results if results else []
        except sqlite3.Error as e:
            return [(f"An error occurred: {e}",)]

    def query_db(self, question, is_sql):
        if self.is_db_empty:
            return "The database is empty. Please populate it with data first."

        try:
            if is_sql:
                sql_query = question.strip()
            else:
                nl_to_sql = self.natural_language_to_sql(question)
                sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()

            results = self.execute_query(sql_query)

            output = f"SQL Query: {sql_query}\n\nResults:\n"
            if isinstance(results, list):
                if len(results) > 0:
                    for row in results:
                        output += str(row) + "\n"
                else:
                    output += "No results found."
            else:
                output += str(results)  # In case of an error message

            return output
        except Exception as e:
            return f"An error occurred: {str(e)}"

    def close(self):
        self.conn.commit()
        self.conn.close()


def check_db_exists(db_path):
    return os.path.exists(db_path) and os.path.getsize(db_path) > 0


@click.command()
@click.option(
    "--data_path", help="Path to the data file containing the papers information."
)
@click.option("--pred_path", help="Path to the predictions file.")
@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
@click.option(
    "--force", is_flag=True, help="Force overwrite if database already exists"
)
def main(data_path, pred_path, db_name, force):
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
    os.makedirs(tables_dir, exist_ok=True)
    db_path = os.path.join(tables_dir, db_name)

    db_exists = check_db_exists(db_path)

    db = ArxivDatabase(db_path)
    db.init_db()

    if db_exists and not db.is_db_empty:
        if not force:
            print(f"Warning: The database '{db_name}' already exists and is not empty.")
            overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
            if overwrite != "y":
                print("Operation cancelled.")
                db.close()
                return
        else:
            print(
                f"Warning: Overwriting existing database '{db_name}' due to --force flag."
            )

    db.populate_db(data_path, pred_path)
    db.close()

    print(f"Database created and populated at: {db_path}")


if __name__ == "__main__":
    main()
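The script is click-driven, so a typical invocation would look like `python scripts/create_db.py --data_path data/papers.jsonl --pred_path data/predictions.json --db_name arxiv.db` (the input file names here are illustrative, not from the commit). Note that `populate_db` imports `src.processing.generate` and `src.utils.utils`, which are not among the files in this deploy, so building a database presumably requires the full repository rather than just the Space files.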
scripts/run_db_interface.py
ADDED
|
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import networkx as nx
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
+
import sqlite3
|
| 10 |
+
import tempfile
|
| 11 |
+
import time
|
| 12 |
+
import uvicorn
|
| 13 |
+
|
| 14 |
+
from contextlib import contextmanager
|
| 15 |
+
from fastapi import FastAPI, Request
|
| 16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
+
from gradio.routes import mount_gradio_app
|
| 18 |
+
from plotly.subplots import make_subplots
|
| 19 |
+
from tabulate import tabulate
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 23 |
+
if ROOT_DIR not in sys.path:
|
| 24 |
+
sys.path.insert(0, ROOT_DIR)
|
| 25 |
+
|
| 26 |
+
from scripts.create_db import ArxivDatabase
|
| 27 |
+
from config import (
|
| 28 |
+
DEFAULT_TABLES_DIR,
|
| 29 |
+
DEFAULT_INTERFACE_MODEL_ID,
|
| 30 |
+
COOCCURRENCE_QUERY,
|
| 31 |
+
canned_queries,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
app = FastAPI()
|
| 35 |
+
|
| 36 |
+
# Add CORS middleware
|
| 37 |
+
app.add_middleware(
|
| 38 |
+
CORSMiddleware,
|
| 39 |
+
allow_origins=["*"],
|
| 40 |
+
allow_credentials=True,
|
| 41 |
+
allow_methods=["*"],
|
| 42 |
+
allow_headers=["*"],
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
db: Optional[ArxivDatabase] = None
|
| 46 |
+
|
| 47 |
+
last_update_time = 0
|
| 48 |
+
update_delay = 0.5 # Delay in seconds
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def truncate_or_wrap_text(text, max_length=50, wrap=False):
|
| 52 |
+
"""Truncate text to a maximum length, adding ellipsis if truncated, or wrap if specified."""
|
| 53 |
+
if wrap:
|
| 54 |
+
return "\n".join(
|
| 55 |
+
text[i : i + max_length] for i in range(0, len(text), max_length)
|
| 56 |
+
)
|
| 57 |
+
return text[:max_length] + "..." if len(text) > max_length else text
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def format_url(url):
|
| 61 |
+
"""Format URL to be more compact in the table."""
|
| 62 |
+
return url.split("/")[-1] if url.startswith("http") else url
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_db_path():
|
| 66 |
+
"""Get the database directory path based on environment"""
|
| 67 |
+
# First try local path
|
| 68 |
+
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 69 |
+
tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
|
| 70 |
+
|
| 71 |
+
if not os.path.exists(tables_dir):
|
| 72 |
+
# If running on Spaces, try the root directory
|
| 73 |
+
tables_dir = os.path.join(ROOT, "data", "databases")
|
| 74 |
+
if not os.path.exists(tables_dir):
|
| 75 |
+
print(f"No database directory found")
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
print(f"Using database directory: {tables_dir}")
|
| 79 |
+
return tables_dir
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_available_databases():
|
| 83 |
+
"""Get available databases from either local path or Hugging Face cache."""
|
| 84 |
+
tables_dir = get_db_path()
|
| 85 |
+
if not tables_dir:
|
| 86 |
+
return []
|
| 87 |
+
|
| 88 |
+
files = os.listdir(tables_dir)
|
| 89 |
+
print(f"All files found: {files}")
|
| 90 |
+
|
| 91 |
+
# Include all files except .md files
|
| 92 |
+
databases = [f for f in files if not f.endswith(".md")]
|
| 93 |
+
print(f"Database files: {databases}")
|
| 94 |
+
|
| 95 |
+
return databases
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def query_db(query, is_sql, limit=None, wrap=False):
|
| 99 |
+
global db
|
| 100 |
+
if db is None:
|
| 101 |
+
return pd.DataFrame({"Error": ["Please load a database first."]})
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
with sqlite3.connect(db.db_path) as conn:
|
| 105 |
+
cursor = conn.cursor()
|
| 106 |
+
|
| 107 |
+
query = " ".join(query.strip().split("\n")).rstrip(";")
|
| 108 |
+
|
| 109 |
+
if limit is not None:
|
| 110 |
+
if "LIMIT" in query.upper():
|
| 111 |
+
# Replace existing LIMIT clause
|
| 112 |
+
query = re.sub(
|
| 113 |
+
r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
query += f" LIMIT {limit}"
|
| 117 |
+
|
| 118 |
+
cursor.execute(query)
|
| 119 |
+
|
| 120 |
+
column_names = [description[0] for description in cursor.description]
|
| 121 |
+
|
| 122 |
+
results = cursor.fetchall()
|
| 123 |
+
|
| 124 |
+
df = pd.DataFrame(results, columns=column_names)
|
| 125 |
+
|
| 126 |
+
for column in df.columns:
|
| 127 |
+
if df[column].dtype == "object":
|
| 128 |
+
df[column] = df[column].apply(
|
| 129 |
+
lambda x: (
|
| 130 |
+
format_url(x)
|
| 131 |
+
if column == "url"
|
| 132 |
+
else truncate_or_wrap_text(x, wrap=wrap)
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
return df
|
| 137 |
+
|
| 138 |
+
except sqlite3.Error as e:
|
| 139 |
+
return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
|
| 140 |
+
except Exception as e:
|
| 141 |
+
return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_concept_cooccurrence_graph(db_path, tag_type=None):
|
| 145 |
+
conn = sqlite3.connect(db_path)
|
| 146 |
+
|
| 147 |
+
query = COOCCURRENCE_QUERY
|
| 148 |
+
if tag_type and tag_type != "All":
|
| 149 |
+
query = query.replace(
|
| 150 |
+
"WHERE p1.tag_type = p2.tag_type",
|
| 151 |
+
f"WHERE p1.tag_type = p2.tag_type AND p1.tag_type = '{tag_type}'",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
df = pd.read_sql_query(query, conn)
|
| 155 |
+
conn.close()
|
| 156 |
+
|
| 157 |
+
G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
|
| 158 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
| 159 |
+
|
| 160 |
+
edge_trace = go.Scatter(
|
| 161 |
+
x=[], y=[], line=dict(width=0.5, color="#888"), hoverinfo="none", mode="lines"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
node_trace = go.Scatter(
|
| 165 |
+
x=[],
|
| 166 |
+
y=[],
|
| 167 |
+
mode="markers",
|
| 168 |
+
hoverinfo="text",
|
| 169 |
+
marker=dict(
|
| 170 |
+
showscale=True,
|
| 171 |
+
colorscale="YlGnBu",
|
| 172 |
+
size=10,
|
| 173 |
+
colorbar=dict(
|
| 174 |
+
thickness=15,
|
| 175 |
+
title="Node Connections",
|
| 176 |
+
xanchor="left",
|
| 177 |
+
titleside="right",
|
| 178 |
+
),
|
| 179 |
+
),
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def update_traces(selected_node=None, depth=0):
|
| 183 |
+
nonlocal edge_trace, node_trace
|
| 184 |
+
|
| 185 |
+
if selected_node and depth > 0:
|
| 186 |
+
nodes_to_show = set([selected_node])
|
| 187 |
+
frontier = set([selected_node])
|
| 188 |
+
for _ in range(depth):
|
| 189 |
+
new_frontier = set()
|
| 190 |
+
for node in frontier:
|
| 191 |
+
new_frontier.update(G.neighbors(node))
|
| 192 |
+
nodes_to_show.update(new_frontier)
|
| 193 |
+
frontier = new_frontier
|
| 194 |
+
sub_G = G.subgraph(nodes_to_show)
|
| 195 |
+
else:
|
| 196 |
+
sub_G = G
|
| 197 |
+
|
| 198 |
+
edge_x, edge_y = [], []
|
| 199 |
+
for edge in sub_G.edges():
|
| 200 |
+
x0, y0 = pos[edge[0]]
|
| 201 |
+
x1, y1 = pos[edge[1]]
|
| 202 |
+
edge_x.extend([x0, x1, None])
|
| 203 |
+
edge_y.extend([y0, y1, None])
|
| 204 |
+
|
| 205 |
+
edge_trace.x = edge_x
|
| 206 |
+
edge_trace.y = edge_y
|
| 207 |
+
|
| 208 |
+
node_x, node_y = [], []
|
| 209 |
+
for node in sub_G.nodes():
|
| 210 |
+
x, y = pos[node]
|
| 211 |
+
node_x.append(x)
|
| 212 |
+
node_y.append(y)
|
| 213 |
+
|
| 214 |
+
node_trace.x = node_x
|
| 215 |
+
node_trace.y = node_y
|
| 216 |
+
|
| 217 |
+
node_adjacencies = []
|
| 218 |
+
node_text = []
|
| 219 |
+
for node in sub_G.nodes():
|
| 220 |
+
adjacencies = list(G.adj[node])
|
| 221 |
+
node_adjacencies.append(len(adjacencies))
|
| 222 |
+
node_text.append(f"{node}<br># of connections: {len(adjacencies)}")
|
| 223 |
+
|
| 224 |
+
node_trace.marker.color = node_adjacencies
|
| 225 |
+
node_trace.text = node_text
|
| 226 |
+
|
| 227 |
+
update_traces()
|
| 228 |
+
|
| 229 |
+
fig = go.Figure(
|
| 230 |
+
data=[edge_trace, node_trace],
|
| 231 |
+
layout=go.Layout(
|
| 232 |
+
title=f'Concept Co-occurrence Network {f"({tag_type})" if tag_type and tag_type != "All" else ""}',
|
| 233 |
+
titlefont_size=16,
|
| 234 |
+
showlegend=False,
|
| 235 |
+
hovermode="closest",
|
| 236 |
+
margin=dict(b=20, l=5, r=5, t=40),
|
| 237 |
+
annotations=[
|
| 238 |
+
dict(
|
| 239 |
+
text="",
|
| 240 |
+
showarrow=False,
|
| 241 |
+
xref="paper",
|
| 242 |
+
yref="paper",
|
| 243 |
+
x=0.005,
|
| 244 |
+
y=-0.002,
|
| 245 |
+
)
|
| 246 |
+
],
|
| 247 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 248 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 249 |
+
),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
fig.update_layout(
|
| 253 |
+
updatemenus=[
|
| 254 |
+
dict(
|
| 255 |
+
type="buttons",
|
| 256 |
+
direction="left",
|
| 257 |
+
buttons=[
|
| 258 |
+
dict(
|
| 259 |
+
args=[{"visible": [True, True]}],
|
| 260 |
+
label="Full Graph",
|
| 261 |
+
method="update",
|
| 262 |
+
),
|
| 263 |
+
dict(
|
| 264 |
+
args=[
|
| 265 |
+
{
|
| 266 |
+
"visible": [True, True],
|
| 267 |
+
"xaxis.range": [-1, 1],
|
| 268 |
+
"yaxis.range": [-1, 1],
|
| 269 |
+
}
|
| 270 |
+
],
|
| 271 |
+
label="Core View",
|
| 272 |
+
method="relayout",
|
| 273 |
+
),
|
| 274 |
+
dict(
|
| 275 |
+
args=[
|
| 276 |
+
{
|
| 277 |
+
"visible": [True, True],
|
| 278 |
+
"xaxis.range": [-0.2, 0.2],
|
| 279 |
+
"yaxis.range": [-0.2, 0.2],
|
| 280 |
+
}
|
| 281 |
+
],
|
| 282 |
+
label="Detailed View",
|
| 283 |
+
method="relayout",
|
| 284 |
+
),
|
| 285 |
+
],
|
| 286 |
+
pad={"r": 10, "t": 10},
|
| 287 |
+
showactive=True,
|
| 288 |
+
x=0.11,
|
| 289 |
+
xanchor="left",
|
| 290 |
+
y=1.1,
|
| 291 |
+
yanchor="top",
|
| 292 |
+
),
|
| 293 |
+
]
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
return fig, G, pos, update_traces
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def load_database_with_graphs(db_name):
|
| 300 |
+
"""Load database from either local path or Hugging Face cache."""
|
| 301 |
+
global db
|
| 302 |
+
tables_dir = get_db_path()
|
| 303 |
+
if not tables_dir:
|
| 304 |
+
return f"No database directory found.", None
|
| 305 |
+
|
| 306 |
+
db_path = os.path.join(tables_dir, db_name)
|
| 307 |
+
if not os.path.exists(db_path):
|
| 308 |
+
return f"Database {db_name} does not exist.", None
|
| 309 |
+
|
| 310 |
+
db = ArxivDatabase(db_path)
|
| 311 |
+
db.init_db()
|
| 312 |
+
|
| 313 |
+
if db.is_db_empty:
|
| 314 |
+
return (
|
| 315 |
+
f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
|
| 316 |
+
None,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
graph, _, _, _ = generate_concept_cooccurrence_graph(db_path)
|
| 320 |
+
return f"Database loaded from {db_path}", graph
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
css = """
|
| 324 |
+
#selected-query {
|
| 325 |
+
max-height: 100px;
|
| 326 |
+
overflow-y: auto;
|
| 327 |
+
white-space: pre-wrap;
|
| 328 |
+
word-break: break-word;
|
| 329 |
+
}
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def create_demo():
|
| 334 |
+
with gr.Blocks() as demo:
|
| 335 |
+
gr.Markdown("# ArXiv Database Query Interface")
|
| 336 |
+
|
| 337 |
+
with gr.Row():
|
| 338 |
+
db_dropdown = gr.Dropdown(
|
| 339 |
+
choices=get_available_databases(),
|
| 340 |
+
label="Select Database",
|
| 341 |
+
value=get_available_databases(),
|
| 342 |
+
)
|
| 343 |
+
# load_db_btn = gr.Button("Load Database", size="sm")
|
| 344 |
+
status = gr.Textbox(label="Status")
|
| 345 |
+
|
| 346 |
+
with gr.Row():
|
| 347 |
+
graph_output = gr.Plot(label="Concept Co-occurrence Graph")
|
| 348 |
+
|
| 349 |
+
with gr.Row():
|
| 350 |
+
tag_type_dropdown = gr.Dropdown(
|
| 351 |
+
choices=[
|
| 352 |
+
"All",
|
| 353 |
+
"model",
|
| 354 |
+
"task",
|
| 355 |
+
"dataset",
|
| 356 |
+
"field",
|
| 357 |
+
"modality",
|
| 358 |
+
"method",
|
| 359 |
+
"object",
|
| 360 |
+
"property",
|
| 361 |
+
"instrument",
|
| 362 |
+
],
|
| 363 |
+
label="Select Tag Type",
|
| 364 |
+
value="All",
|
| 365 |
+
)
|
| 366 |
+
highlight_input = gr.Textbox(label="Highlight Concepts (comma-separated)")
|
| 367 |
+
|
| 368 |
+
with gr.Row():
|
| 369 |
+
node_dropdown = gr.Dropdown(label="Select Node", choices=[])
|
| 370 |
+
depth_slider = gr.Slider(
|
| 371 |
+
minimum=0, maximum=5, step=1, value=0, label="Connection Depth"
|
| 372 |
+
)
|
| 373 |
+
update_graph_button = gr.Button("Update Graph")
|
| 374 |
+
|
| 375 |
+
with gr.Row():
|
| 376 |
+
wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
|
| 377 |
+
canned_query_dropdown = gr.Dropdown(
|
| 378 |
+
choices=[q[0] for q in canned_queries], label="Select Query", scale=3
|
| 379 |
+
)
|
| 380 |
+
limit_input = gr.Number(
|
| 381 |
+
label="Limit", value=10000, step=1, minimum=1, scale=1
|
| 382 |
+
)
|
| 383 |
+
selected_query = gr.Textbox(
|
| 384 |
+
label="Selected Query",
|
| 385 |
+
interactive=False,
|
| 386 |
+
scale=2,
|
| 387 |
+
show_label=True,
|
| 388 |
+
show_copy_button=True,
|
| 389 |
+
elem_id="selected-query",
|
| 390 |
+
)
|
| 391 |
+
canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)
|
| 392 |
+
|
| 393 |
+
with gr.Row():
|
| 394 |
+
sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
|
| 395 |
+
sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)
|
| 396 |
+
|
| 397 |
+
# with gr.Row():
|
| 398 |
+
# nl_query_input = gr.Textbox(
|
| 399 |
+
# label="Natural Language Query", lines=2, scale=4
|
| 400 |
+
# )
|
| 401 |
+
# nl_query_submit = gr.Button("Convert to SQL", size="sm", scale=1)
|
| 402 |
+
|
| 403 |
+
output = gr.DataFrame(label="Results", wrap=True)
|
| 404 |
+
|
| 405 |
+
with gr.Row():
|
| 406 |
+
copy_button = gr.Button("Copy as Markdown")
|
| 407 |
+
download_button = gr.Button("Download as CSV")
|
| 408 |
+
|
| 409 |
+
def debounced_update_graph(
|
| 410 |
+
db_name, tag_type, highlight_concepts, selected_node, depth
|
| 411 |
+
):
|
| 412 |
+
global last_update_time
|
| 413 |
+
|
| 414 |
+
current_time = time.time()
|
| 415 |
+
if current_time - last_update_time < update_delay:
|
| 416 |
+
return None, [] # Return early if not enough time has passed
|
| 417 |
+
|
| 418 |
+
last_update_time = current_time
|
| 419 |
+
|
| 420 |
+
if not db_name:
|
| 421 |
+
return None, []
|
| 422 |
+
|
| 423 |
+
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 424 |
+
db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
|
| 425 |
+
fig, G, pos, update_traces = generate_concept_cooccurrence_graph(
|
| 426 |
+
db_path, tag_type
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
if isinstance(selected_node, list):
|
| 430 |
+
selected_node = selected_node[0] if selected_node else None
|
| 431 |
+
|
| 432 |
+
highlight_nodes = (
|
| 433 |
+
[node.strip() for node in highlight_concepts.split(",")]
|
| 434 |
+
if highlight_concepts
|
| 435 |
+
else []
|
| 436 |
+
)
|
| 437 |
+
primary_node = highlight_nodes[0] if highlight_nodes else None
|
| 438 |
+
|
| 439 |
+
if primary_node and primary_node in G.nodes():
|
| 440 |
+
# Apply node selection and depth filter
|
| 441 |
+
nodes_to_show = set([primary_node])
|
| 442 |
+
if depth > 0:
|
| 443 |
+
frontier = set([primary_node])
|
| 444 |
+
for _ in range(depth):
|
| 445 |
+
new_frontier = set()
|
| 446 |
+
for node in frontier:
|
| 447 |
+
new_frontier.update(G.neighbors(node))
|
| 448 |
+
nodes_to_show.update(new_frontier)
|
| 449 |
+
frontier = new_frontier
|
| 450 |
+
|
| 451 |
+
sub_G = G.subgraph(nodes_to_show)
|
| 452 |
+
|
| 453 |
+
# Update traces with the filtered graph
|
| 454 |
+
edge_x, edge_y = [], []
|
| 455 |
+
for edge in sub_G.edges():
|
| 456 |
+
x0, y0 = pos[edge[0]]
|
| 457 |
+
x1, y1 = pos[edge[1]]
|
| 458 |
+
edge_x.extend([x0, x1, None])
|
| 459 |
+
edge_y.extend([y0, y1, None])
|
| 460 |
+
|
| 461 |
+
fig.data[0].x = edge_x
|
| 462 |
+
fig.data[0].y = edge_y
|
| 463 |
+
|
| 464 |
+
node_x, node_y = [], []
|
| 465 |
+
for node in sub_G.nodes():
|
| 466 |
+
x, y = pos[node]
|
| 467 |
+
node_x.append(x)
|
| 468 |
+
node_y.append(y)
|
| 469 |
+
|
| 470 |
+
fig.data[1].x = node_x
|
| 471 |
+
fig.data[1].y = node_y
|
| 472 |
+
|
| 473 |
+
# Color nodes based on their distance from the primary node and highlight status
|
| 474 |
+
node_colors = []
|
| 475 |
+
node_sizes = []
|
| 476 |
+
for node in sub_G.nodes():
|
| 477 |
+
if node in highlight_nodes:
|
| 478 |
+
node_colors.append(
|
| 479 |
+
"rgba(255,0,0,1)"
|
| 480 |
+
) # Red for highlighted nodes
|
| 481 |
+
node_sizes.append(15)
|
| 482 |
+
else:
|
| 483 |
+
distance = nx.shortest_path_length(
|
| 484 |
+
sub_G, source=primary_node, target=node
|
| 485 |
+
)
|
| 486 |
+
intensity = max(0, 1 - (distance / (depth + 1)))
|
| 487 |
+
node_colors.append(f"rgba(0,0,255,{intensity})")
|
| 488 |
+
node_sizes.append(10)
|
| 489 |
+
|
| 490 |
+
fig.data[1].marker.color = node_colors
|
| 491 |
+
fig.data[1].marker.size = node_sizes
|
| 492 |
+
|
| 493 |
+
# Update node text
|
| 494 |
+
node_text = [
|
| 495 |
+
f"{node}<br># of connections: {len(list(G.neighbors(node)))}"
|
| 496 |
+
for node in sub_G.nodes()
|
| 497 |
+
]
|
| 498 |
+
fig.data[1].text = node_text
|
| 499 |
+
|
| 500 |
+
# Get connected nodes for dropdown
|
| 501 |
+
connected_nodes = sorted(list(G.neighbors(primary_node)))
|
| 502 |
+
else:
|
| 503 |
+
# If no primary node or it's not in the graph, show the full graph
|
| 504 |
+
connected_nodes = sorted(list(G.nodes()))
|
| 505 |
+
|
| 506 |
+
return fig, connected_nodes
|
| 507 |
+
|
| 508 |
+
        def update_node_dropdown(highlight_concepts):
            if not highlight_concepts or not db:
                return gr.Dropdown(choices=[])

            ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
            db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db.db_path)
            _, G, _, _ = generate_concept_cooccurrence_graph(db_path)

            primary_node = highlight_concepts.split(",")[0].strip()
            if primary_node in G.nodes():
                connected_nodes = sorted(list(G.neighbors(primary_node)))
                return gr.Dropdown(choices=connected_nodes)
            else:
                return gr.Dropdown(choices=[])

        def update_selected_query(query_description):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return sql
            return ""

        def submit_canned_query(query_description, limit, wrap):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return query_db(sql, True, limit, wrap)
            return pd.DataFrame({"Error": ["Selected query not found."]})

        def copy_as_markdown(df):
            return df.to_markdown()

        def download_as_csv(df):
            if df is None or df.empty:
                return None

            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".csv"
            ) as temp_file:
                df.to_csv(temp_file.name, index=False)
                temp_file_path = temp_file.name

            return temp_file_path
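
        # Editor's sketch (hypothetical helper, not in the original commit):
        # the same tempfile pattern generalizes to other export formats.
        # `download_as_json` is illustrative only and is never wired up.
        def download_as_json(df):
            if df is None or df.empty:
                return None
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".json"
            ) as temp_file:
                df.to_json(temp_file.name, orient="records")
                return temp_file.name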

        # def nl_to_sql(nl_query):
        #     # Placeholder function for natural language to SQL conversion
        #     return f"SELECT * FROM papers WHERE abstract LIKE '%{nl_query}%' LIMIT 10;"

        db_dropdown.change(
            load_database_with_graphs,
            inputs=[db_dropdown],
            outputs=[status, graph_output],
        )

        # db_dropdown.change(
        #     debounced_update_graph,
        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
        #     outputs=[graph_output, node_dropdown],
        # )

        tag_type_dropdown.change(
            debounced_update_graph,
            inputs=[
                db_dropdown,
                tag_type_dropdown,
                highlight_input,
                node_dropdown,
                depth_slider,
            ],
            outputs=[graph_output, node_dropdown],
        )

        highlight_input.change(
            update_node_dropdown,
            inputs=[highlight_input],
            outputs=[node_dropdown],
        )
        # node_dropdown.change(
        #     debounced_update_graph,
        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
        #     outputs=[graph_output, node_dropdown],
        # )

        # depth_slider.change(
        #     debounced_update_graph,
        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
        #     outputs=[graph_output, node_dropdown],
        # )
        update_graph_button.click(
            debounced_update_graph,
            inputs=[
                db_dropdown,
                tag_type_dropdown,
                highlight_input,
                node_dropdown,
                depth_slider,
            ],
            outputs=[graph_output, node_dropdown],
        )
        canned_query_dropdown.change(
            update_selected_query,
            inputs=[canned_query_dropdown],
            outputs=[selected_query],
        )
        canned_query_submit.click(
            submit_canned_query,
            inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
            outputs=output,
        )
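        # Editor's comment (not in the original commit): the gr.Checkbox
        # constructed inline below is never rendered in the UI; it appears
        # intended to pass a constant True for query_db's `is_sql` argument.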
        sql_submit.click(
            query_db,
            inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
            outputs=output,
        )
        copy_button.click(
            copy_as_markdown,
            inputs=[output],
            outputs=[gr.Textbox(label="Markdown Output", show_copy_button=True)],
        )
        download_button.click(
            download_as_csv, inputs=[output], outputs=[gr.File(label="CSV Output")]
        )
        # nl_query_submit.click(nl_to_sql, inputs=[nl_query_input], outputs=[sql_input])

    return demo

demo = create_demo()


def close_db():
    global db
    if db is not None:
        db.close()
        db = None

def launch():
    print("Launching Gradio app...", flush=True)
    shared_demo = demo.launch(share=True, prevent_thread_lock=True)

    # Blocks.launch() returns an (app, local_url, share_url) tuple, so skip
    # the FastAPI app object when unpacking.
    if isinstance(shared_demo, tuple) and len(shared_demo) >= 3:
        _, local_url, share_url = shared_demo[:3]
    elif isinstance(shared_demo, tuple):
        local_url, share_url = shared_demo[0], "N/A"
    else:
        local_url = getattr(shared_demo, "local_url", "N/A")
        share_url = getattr(shared_demo, "share_url", "N/A")

    print(f"Local URL: {local_url}", flush=True)
    print(f"Shareable link: {share_url}", flush=True)

    print("Gradio app launched.", flush=True)

    # Keep the script running
    demo.block_thread()

if __name__ == "__main__":
    launch()

# Mount the Gradio app
# app = mount_gradio_app(app, demo, path="/")

# print(f"Shareable link: {demo.share_url}")

# @app.exception_handler(Exception)
# async def exception_handler(request: Request, exc: Exception):
#     print(f"An error occurred: {str(exc)}")
#     return {"error": str(exc)}

# @contextmanager
# def get_db_connection():
#     global db
#     conn = db.conn.cursor().connection
#     try:
#         yield conn
#     finally:
#         conn.close()

# @app.on_event("startup")
# async def startup_event():
#     global db
#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, get_available_databases()[0])  # Use the first available database
#     db = ArxivDatabase(db_path)
#     db.init_db()

# @app.on_event("shutdown")
# async def shutdown_event():
#     if db is not None:
#         db.close()


# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=7860)

scripts/run_db_interface_basic.py
ADDED
@@ -0,0 +1,361 @@
import gradio as gr
import os
import json
import networkx as nx
import pandas as pd
import plotly.graph_objects as go
import re
import sys
import sqlite3
import time
import uvicorn

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
from plotly.subplots import make_subplots
from tabulate import tabulate
from typing import Optional


ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from scripts.create_db import ArxivDatabase
from config import (
    DEFAULT_TABLES_DIR,
    DEFAULT_INTERFACE_MODEL_ID,
    COOCCURRENCE_QUERY,
    canned_queries,
)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

db: Optional[ArxivDatabase] = None

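
# Editor's note (illustrative, not in the original commit): wildcard origins
# are convenient for local development; a deployment would typically restrict
# them, e.g.:
#
#     app.add_middleware(
#         CORSMiddleware,
#         allow_origins=["https://example.org"],  # hypothetical origin
#         allow_credentials=True,
#         allow_methods=["GET", "POST"],
#         allow_headers=["*"],
#     )
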
def truncate_or_wrap_text(text, max_length=50, wrap=False):
    """Truncate text to a maximum length, adding ellipsis if truncated, or wrap if specified."""
    if wrap:
        return "\n".join(
            text[i : i + max_length] for i in range(0, len(text), max_length)
        )
    return text[:max_length] + "..." if len(text) > max_length else text

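
# Illustrative examples (editor's note, not in the original commit):
#
#     truncate_or_wrap_text("a" * 60)             -> first 50 chars + "..."
#     truncate_or_wrap_text("a" * 60, wrap=True)  -> two lines of 50 and 10 chars
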
def format_url(url):
    """Format URL to be more compact in the table."""
    return url.split("/")[-1] if url.startswith("http") else url

def get_available_databases():
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
    return [f for f in os.listdir(tables_dir) if f.endswith(".db")]

def query_db(query, is_sql, limit=None, wrap=False):
    global db
    if db is None:
        return pd.DataFrame({"Error": ["Please load a database first."]})

    try:
        cursor = db.conn.cursor()

        query = " ".join(query.strip().split("\n")).rstrip(";")

        if limit is not None:
            if "LIMIT" in query.upper():
                # Replace existing LIMIT clause
                query = re.sub(
                    r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
                )
            else:
                query += f" LIMIT {limit}"
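
        # For example (editor's note, illustrative): with limit=100,
        #   "SELECT * FROM papers LIMIT 5"  ->  "SELECT * FROM papers LIMIT 100"
        #   "SELECT * FROM papers"          ->  "SELECT * FROM papers LIMIT 100"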

        cursor.execute(query)

        column_names = [description[0] for description in cursor.description]

        results = cursor.fetchall()

        df = pd.DataFrame(results, columns=column_names)

        for column in df.columns:
            if df[column].dtype == "object":
                df[column] = df[column].apply(
                    lambda x: (
                        format_url(x)
                        if column == "url"
                        else truncate_or_wrap_text(x, wrap=wrap)
                    )
                )

        return df

    except sqlite3.Error as e:
        return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})

def generate_concept_cooccurrence_graph(db_path):
    conn = sqlite3.connect(db_path)
    df = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
    conn.close()

    G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
    pos = nx.spring_layout(G)

    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_x = [pos[node][0] for node in G.nodes()]
    node_y = [pos[node][1] for node in G.nodes()]

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="YlGnBu",
            size=10,
            colorbar=dict(
                thickness=15,
                title="Node Connections",
                xanchor="left",
                titleside="right",
            ),
        ),
    )

    node_adjacencies = []
    node_text = []
    for node, adjacencies in G.adjacency():
        node_adjacencies.append(len(adjacencies))
        node_text.append(f"{node}<br># of connections: {len(adjacencies)}")

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title="Concept Co-occurrence Network",
            titlefont_size=16,
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.005,
                    y=-0.002,
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    return fig

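
# Illustrative usage (editor's note; the .db filename is hypothetical):
#
#     fig = generate_concept_cooccurrence_graph("tables/papers.db")
#     fig.show()
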
# def load_database_with_graphs(db_name):
#     global db
#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
#     if not os.path.exists(db_path):
#         return f"Database {db_name} does not exist.", None
#     db = ArxivDatabase(db_path)
#     db.init_db()
#     if db.is_db_empty:
#         return (
#             f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
#             None,
#         )
#
#     # Generate graph
#     graph = generate_concept_cooccurrence_graph(db_path)
#
#     return f"Database loaded from {db_path}", graph

def load_database_with_graphs(db_name):
    global db
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
    if not os.path.exists(db_path):
        return f"Database {db_name} does not exist.", None

    if db is None or db.db_path != db_path:
        db = ArxivDatabase(db_path)
        db.init_db()

    if db.is_db_empty:
        return (
            f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
            None,
        )

    graph = generate_concept_cooccurrence_graph(db_path)
    return f"Database loaded from {db_path}", graph

css = """
|
| 233 |
+
#selected-query {
|
| 234 |
+
max-height: 100px;
|
| 235 |
+
overflow-y: auto;
|
| 236 |
+
white-space: pre-wrap;
|
| 237 |
+
word-break: break-word;
|
| 238 |
+
}
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def create_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# ArXiv Database Query Interface")

        with gr.Row():
            db_dropdown = gr.Dropdown(
                choices=get_available_databases(), label="Select Database"
            )
            load_db_btn = gr.Button("Load Database", size="sm")
            status = gr.Textbox(label="Status")

        with gr.Row():
            graph_output = gr.Plot(label="Concept Co-occurrence Graph")

        with gr.Row():
            wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
            canned_query_dropdown = gr.Dropdown(
                choices=[q[0] for q in canned_queries], label="Select Query", scale=3
            )
            limit_input = gr.Number(
                label="Limit", value=10000, step=1, minimum=1, scale=1
            )
            selected_query = gr.Textbox(
                label="Selected Query",
                interactive=False,
                scale=2,
                show_label=True,
                show_copy_button=True,
                elem_id="selected-query",
            )
            canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)

        with gr.Row():
            sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
            sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)

        output = gr.DataFrame(label="Results", wrap=True)

        def update_selected_query(query_description):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return sql
            return ""

        def submit_canned_query(query_description, limit, wrap):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return query_db(sql, True, limit, wrap)
            return pd.DataFrame({"Error": ["Selected query not found."]})

        load_db_btn.click(
            load_database_with_graphs,
            inputs=[db_dropdown],
            outputs=[status, graph_output],
        )
        canned_query_dropdown.change(
            update_selected_query,
            inputs=[canned_query_dropdown],
            outputs=[selected_query],
        )
        canned_query_submit.click(
            submit_canned_query,
            inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
            outputs=output,
        )
        sql_submit.click(
            query_db,
            inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
            outputs=output,
        )

    return demo

demo = create_demo()


def close_db():
    global db
    if db is not None:
        db.close()
        db = None

# def launch():
#     print("Launching Gradio app...", flush=True)
#     demo.launch(share=True)
#     print(
#         "Gradio app launched. If you don't see a URL above, there might be network restrictions.",
#         flush=True,
#     )

#     close_db()

# if __name__ == "__main__":
#     launch()

# Mount the Gradio app
app = mount_gradio_app(app, demo, path="/")


@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    print(f"An error occurred: {str(exc)}")
    return {"error": str(exc)}


@app.on_event("startup")
async def startup_event():
    # You can initialize the database here if needed
    pass


@app.on_event("shutdown")
async def shutdown_event():
    close_db()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
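
# Editor's note (illustrative): with the Gradio app mounted at "/", the server
# can also be started from the repo root via the uvicorn CLI instead of the
# __main__ branch above:
#
#     uvicorn scripts.run_db_interface_basic:app --host 0.0.0.0 --port 7860
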
scripts/run_db_interface_js.py
ADDED
File without changes