# Author: Christof Bless
# Commit: first working mvp (b23f8b6, unverified)
# File size at capture: 4.17 kB
import gradio as gr
import numpy as np
import pymupdf4llm
import spacy
from transformers import AutoTokenizer, AutoModel
from adapters import AutoAdapterModel
from extract_citations import fetch_citations_for_dois
from extract_embeddings import (
prune_contexts,
embed_abstracts,
embed_contexts,
restore_inverted_abstract,
calculate_distances
)
from extract_mentions import extract_citation_contexts
def extract_text(pdf_file):
    """Convert an uploaded PDF to markdown text.

    Args:
        pdf_file: path to the PDF file (falsy when nothing was uploaded).

    Returns:
        The markdown text on success, or a user-facing error message string.
    """
    # Guard clause: nothing uploaded yet.
    if not pdf_file:
        return "Please upload a PDF file."
    try:
        markdown = pymupdf4llm.to_markdown(pdf_file)
    except Exception as err:
        # Surface the failure to the UI instead of crashing the request.
        return f"Error when processing PDF. {err}"
    return markdown
def extract_citations(doi):
    """Fetch citation metadata for a single DOI.

    Args:
        doi: DOI string of the paper whose citations should be fetched.

    Returns:
        The citation data from ``fetch_citations_for_dois``, or a
        user-facing error message string if the lookup fails.
    """
    try:
        data = fetch_citations_for_dois([doi])
    except Exception as err:
        # Invalid DOI (or upstream failure) — report instead of raising.
        return f"Please submit a valid DOI. {err}"
    return data
def get_cite_context_distance(pdf, doi):
    """Score how well each in-text citation context matches the cited paper.

    Pipeline: fetch the paper's cited works from OpenAlex, extract the text
    surrounding each citation marker from the uploaded PDF, embed both the
    contexts and the cited papers' title+abstract with SPECTER2, then report
    per-citation embedding distances and their mean.

    Args:
        pdf: uploaded file object (Gradio File); only ``pdf.name`` (the
            temporary file path) is used.
        doi: DOI string of the uploaded paper.

    Returns:
        dict with:
            "score": mean of all distances (NaNs ignored via ``nanmean``),
            "individual_citations": list of per-citation dicts containing a
                shortened context snippet, cited paper title/id, and distance.
    """
    # Load models — NOTE(review): these are re-downloaded/re-loaded on every
    # request; consider caching at module level.
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
    nlp = spacy.load("en_core_web_sm")
    # fetch cited papers from OpenAlex
    citations_data = fetch_citations_for_dois([doi])
    # get markdown text from PDF file
    text = extract_text(pdf.name)
    # get the context around citation markers — presumably a DataFrame with
    # left_context / mention / right_context / citation_id columns (used below);
    # TODO confirm against extract_citation_contexts.
    citations = extract_citation_contexts(citations_data, text)
    citations["pruned_contexts"], citations["known_tokens_fraction"] = prune_contexts(citations, nlp, tokenizer)
    # embed only contexts that survived pruning and are mostly in-vocabulary
    # (>= 70% known tokens); rows failing the filter are skipped here and get
    # no context embedding.
    citation_context_embedding = embed_contexts(
        citations[
            (citations["known_tokens_fraction"] >= 0.7) &
            (~citations["pruned_contexts"].isna())
        ]["pruned_contexts"].to_list(),
        model,
        tokenizer,
    ).detach().numpy()
    # flatten {doi: [entries]} into a lookup keyed by each entry's "id"
    citations_data = {entry["id"]:entry for cite in citations_data.values() for entry in cite}
    # embed title + abstract for each unique cited paper; OpenAlex delivers
    # abstracts as inverted indexes, so reconstruct the plain text first
    # (or pass None when no abstract is available).
    citation_abstract_embedding = embed_abstracts(
        [
            {
                "title":citations_data[cite]["title"],
                "abstract": (
                    restore_inverted_abstract(
                        citations_data[cite]["abstract_inverted_index"]
                    )
                    if citations_data[cite]["abstract_inverted_index"] is not None
                    else None
                )
            }
            for cite in citations["citation_id"].unique()
        ],
        model,
        tokenizer,
        batch_size=4,
    ).detach().numpy()
    print(citation_abstract_embedding.shape)
    # Map every citation row to a pair of matrix rows:
    #   (row in context-embedding matrix, row in abstract-embedding matrix),
    # or (None, None) for rows filtered out above.
    # NOTE(review): this relies on `citations` having a default RangeIndex so
    # that enumerate() positions coincide with the index labels collected in
    # index_left — confirm; a filtered/reindexed DataFrame would break it.
    index_left = citations.index[
        (citations["known_tokens_fraction"] >= 0.7) &
        (~citations["pruned_contexts"].isna())
    ].tolist()
    index_right = citations["citation_id"].unique().tolist()
    indices = [
        (index_left.index(i), index_right.index(cite_id))
        if i in index_left else (None, None)
        for i, cite_id in enumerate(citations["citation_id"])
    ]
    distances = np.array(calculate_distances(citation_context_embedding, citation_abstract_embedding, indices))
    # Build a human-readable entry per scored citation (NaN = unscored row).
    results = []
    for i, dist in enumerate(distances):
        if not np.isnan(dist):
            obj = {}
            # trim the context to 50 chars on each side and drop newlines
            left_context = citations.left_context[i][-50:].replace('\n', '')
            right_context = citations.right_context[i][:50].replace('\n', '')
            obj["cite_context_short"] = f"...{left_context}{citations.mention[i]}{right_context}..."
            obj["cited_paper"] = citations_data[citations.citation_id[i]]["title"]
            obj["cited_paper_id"] = citations.citation_id[i]
            obj["distance"] = dist
            results.append(obj)
    return {"score": np.nanmean(distances), "individual_citations": results}
# --- Gradio UI: DOI + PDF in, citation-integrity report out ---
with gr.Blocks() as demo:
    gr.Markdown("## Citation Integrity Score")
    doi_box = gr.Textbox(label="Enter DOI (optional)")
    pdf_box = gr.File(label="Upload PDF", file_types=[".pdf"])
    result_box = gr.Textbox(label="Extracted Citations", lines=20)
    run_button = gr.Button("Submit")
    # Wire the button to the scoring pipeline.
    run_button.click(
        fn=get_cite_context_distance,
        inputs=[pdf_box, doi_box],
        outputs=result_box,
    )

demo.launch()