Yozhikoff's picture
Update app.py
be36e61
import re
import urllib
import urllib.request
import xml.etree.ElementTree as ET
import gradio as gr
import torch
from transformers import pipeline
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
def get_arxiv_title_and_abstract(link):
try:
# Validate the arxiv link
pattern = r'^https?://arxiv.org/(abs|pdf)/(\d{4}\.\d{4,5})(\.pdf)?$'
match = re.match(pattern, link)
if not match:
raise ValueError('Invalid arxiv link')
# Construct the arxiv API URL
arxiv_id = match.group(2)
api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
# Send a request to the arxiv API
response = urllib.request.urlopen(api_url)
xml_data = response.read()
# Parse the XML data
root = ET.fromstring(xml_data)
entry = root.find('{http://www.w3.org/2005/Atom}entry')
title = entry.find('{http://www.w3.org/2005/Atom}title').text
summary = entry.find('{http://www.w3.org/2005/Atom}summary').text
return title, summary
except:
raise gr.Error('Invalid arXiv URL!')
def classify_paper(title, abstract):
if title == '' and abstract == '':
raise gr.Error('Fill Title or/and Abstract')
text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"
item = classifier.tokenizer(text)
input_tensor = torch.tensor(item['input_ids'])[None]
logits = classifier.model(input_tensor).logits[0]
preds = torch.sigmoid(logits).detach().cpu().numpy()
result = {}
for num, prob in enumerate(preds):
if prob < 0.25:
continue
if classifier.model.config.id2label[num] in result:
if result[classifier.model.config.id2label[num]] > prob:
continue
result[classifier.model.config.id2label[num]] = float(prob)
return result
with gr.Blocks(title='Paper classifier') as demo:
gr.Markdown('# Paper Topic Classifier')
with gr.Row():
with gr.Column():
gr.Markdown('## Inputs')
gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")
b1 = gr.Button("Parse Link")
title = gr.Textbox(label="Paper title", placeholder="Title text")
abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")
b2 = gr.Button("Classify Paper", variant='primary')
b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
with gr.Column():
gr.Markdown('## Topics')
gr.Markdown('## ')
gr.Markdown('## ')
out = gr.Label(label="Topics")
b2.click(classify_paper, inputs=[title, abstract], outputs=out)
gr.Markdown('## Examples')
gr.Examples(
examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
inputs=arxiv_link,
outputs=[title, abstract],
fn=get_arxiv_title_and_abstract,
cache_examples=True,
)
demo.launch()
with gr.Blocks(title='Paper classifier') as demo:
gr.Markdown('# Paper Topic Classifier')
with gr.Row():
with gr.Column():
gr.Markdown('## Inputs')
gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")
b1 = gr.Button("Parse Link")
title = gr.Textbox(label="Paper title", placeholder="Title text")
abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")
b2 = gr.Button("Classify Paper", variant='primary')
b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
with gr.Column():
gr.Markdown('## Topics')
gr.Markdown('## ')
gr.Markdown('## ')
out = gr.Label(label="Topics")
b2.click(classify_paper, inputs=[title, abstract], outputs=out)
gr.Markdown('## Examples')
gr.Examples(
examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
inputs=arxiv_link,
outputs=[title, abstract],
fn=get_arxiv_title_and_abstract,
cache_examples=True,
)
demo.launch()