# sk / app.py
# naqibhakimi's picture
# initial
# 797a2e2
import contextlib
import streamlit as st
import streamlit.components.v1 as components
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import utils
from kb import KB
import wikipedia
# Cap on the number of topics: limits how many comma-separated search
# subjects are looked up and how many wiki texts may be collected.
MAX_TOPICS = 5
# Maximum number of suggestion buttons placed on one row.
# (Name kept as spelled; other code in this file refers to it.)
BUTTON_COLUMS = 4
# Page heading rendered at the top of the app on every script run.
st.header("Extracting a Knowledge Graph from text")
# Loading the model
def load_model():
    """Load the tokenizer and seq2seq model for "Babelscape/rebel-large".

    Returns:
        A ``(tokenizer, model)`` tuple created with ``from_pretrained``.
    """
    # NOTE(review): nothing here caches the result, so every call repeats
    # both ``from_pretrained`` loads -- consider a Streamlit cache decorator.
    checkpoint = "Babelscape/rebel-large"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def generate_kb():
    """Build a knowledge base from the collected wiki texts and store it.

    Loads the model, passes all entries of ``st.session_state['wiki_text']``
    (joined with single spaces) to ``utils.from_text_to_kb``, saves the
    resulting network as an HTML file, and records the file path and the
    KB's textual representation in the session state. Also clears
    ``error_url``.
    """
    network_path = "networks/network.html"

    # Temporary status line, blanked again once the model is available.
    status_line = st.text('Loading NER model... It may take a while.')
    tokenizer, model = load_model()
    st.success('Model loaded!')
    status_line.text("")

    combined_text = ' '.join(st.session_state['wiki_text'])
    kb = utils.from_text_to_kb(combined_text, model, tokenizer, "", verbose=True)
    utils.save_network_html(kb, filename=network_path)

    session = st.session_state
    session.kb_chart = network_path
    session.kb_text = kb.get_textual_representation()
    session.error_url = None
def show_textbox():
    """Render each collected wiki text inside its own expander.

    The expander label is the first 30 characters of the text followed by
    "..."; only the first expander starts expanded. Nothing is rendered
    when no texts have been collected.
    """
    collected = st.session_state['wiki_text']
    for position, body in enumerate(collected):
        is_first = position == 0
        with st.expander(label=f"{body[:30]}...", expanded=is_first):
            st.markdown(body)
def wiki_show_text(page_title):
    """Fetch a suggested Wikipedia page and add its summary to the session.

    Args:
        page_title: Title of the suggestion to fetch (looked up verbatim,
            ``auto_suggest`` is disabled).

    Outcomes:
        * Success: the page summary is appended to ``wiki_text``, the
          lower-cased title to ``topics``, the title is dropped from
          ``wiki_suggestions`` and the collected texts are rendered.
        * ``wikipedia.DisambiguationError``: the title is dropped from the
          suggestions and replaced by up to three of the offered options;
          the collected texts are rendered.
        * Any other ``wikipedia.WikipediaException``: the title is dropped
          from the suggestions and nothing else happens.
    """

    def _suggestions_without_title():
        # Filtering (instead of list.remove) never raises when the title is
        # missing and also removes repeated occurrences of the same title.
        return [s for s in st.session_state['wiki_suggestions'] if s != page_title]

    with st.spinner(text="Fetching wiki page..."):
        try:
            # Only the calls that talk to Wikipedia live inside the try.
            page = wikipedia.page(title=page_title, auto_suggest=False)
            summary = page.summary
        except wikipedia.DisambiguationError as e:
            with st.spinner(text="Woops, ambigious term, recalculating options..."):
                candidates = _suggestions_without_title() + e.options[:3]
                # dict.fromkeys de-duplicates while keeping insertion order,
                # so the suggestion buttons do not get shuffled.
                st.session_state['wiki_suggestions'] = list(dict.fromkeys(candidates))
                show_textbox()
        except wikipedia.WikipediaException:
            st.session_state['wiki_suggestions'] = _suggestions_without_title()
        else:
            st.session_state['wiki_text'].append(summary)
            st.session_state['topics'].append(page_title.lower())
            st.session_state['wiki_suggestions'] = _suggestions_without_title()
            show_textbox()
def wiki_add_text(term):
    """Fetch the Wikipedia summary for ``term`` and add it to the session.

    Args:
        term: Page title to fetch (looked up verbatim, ``auto_suggest`` is
            disabled).

    Does nothing once ``MAX_TOPICS`` texts have been collected.

    Outcomes:
        * Success: the summary is appended to ``wiki_text``, the lower-cased
          term to ``topics``, and the term is dropped from ``nodes``.
        * ``wikipedia.DisambiguationError``: the term is dropped from
          ``nodes`` and replaced by up to three of the offered options.
        * Any other ``wikipedia.WikipediaException``: the term is dropped
          from ``nodes``.
    """
    # ">=" so that at most MAX_TOPICS texts are ever held; the previous ">"
    # check let one extra text through.
    if len(st.session_state['wiki_text']) >= MAX_TOPICS:
        return

    def _nodes_without_term():
        # Filtering (instead of list.remove) never raises when the term is
        # missing from the node list.
        return [node for node in st.session_state['nodes'] if node != term]

    try:
        # Only the calls that talk to Wikipedia live inside the try.
        page = wikipedia.page(title=term, auto_suggest=False)
        extra_text = page.summary
    except wikipedia.DisambiguationError as e:
        with st.spinner(text="Woops, ambigious term, recalculating options..."):
            candidates = _nodes_without_term() + e.options[:3]
            # Order-preserving de-duplication.
            st.session_state['nodes'] = list(dict.fromkeys(candidates))
    except wikipedia.WikipediaException:
        st.session_state['nodes'] = _nodes_without_term()
    else:
        st.session_state['wiki_text'].append(extra_text)
        st.session_state['topics'].append(term.lower())
        st.session_state['nodes'] = _nodes_without_term()
def reset_thread():
    """Reset the wiki-related session-state entries to their empty values.

    Clears the collected texts, topics, nodes and suggestions, empties
    ``html_wiki`` and sets ``has_run_wiki`` to ``False``.
    """
    # Built on every call so each list is a fresh object.
    cleared_state = {
        'wiki_text': [],
        'topics': [],
        'nodes': [],
        'has_run_wiki': False,
        'wiki_suggestions': [],
        'html_wiki': '',
    }
    for key, value in cleared_state.items():
        st.session_state[key] = value
def show_wiki_hub_page():
    """Lay out the search field and the Search / Generate KB / Reset buttons."""
    # First row: wide search field next to a narrow "Search" button.
    search_col, search_button_col = st.columns([7, 1])
    # Second row: two action buttons; the third, widest column is left empty.
    generate_col, reset_col, _unused_col = st.columns([2, 1.2, 8])

    with search_col:
        st.text_input(
            "Search",
            key="text",
            value="graphs, are, awesome",
            on_change=wiki_show_suggestion,
        )

    with search_button_col:
        # Two empty text elements sit above the button -- presumably to line
        # it up vertically with the text input.
        for _ in range(2):
            st.text('')
        st.button("Search", key="show_suggestion_key", on_click=wiki_show_suggestion)

    with generate_col:
        st.button("Generate KB", on_click=generate_kb)

    with reset_col:
        st.button("Reset", on_click=reset_thread)
def wiki_show_suggestion():
    """Search Wikipedia for the entered subjects and show suggestion buttons.

    Reads the comma-separated search text from ``st.session_state.text``,
    looks up at most ``MAX_TOPICS`` subjects (three results each), adds the
    new titles to ``st.session_state['wiki_suggestions']`` and renders the
    suggestion buttons. Does nothing when the search text is empty.
    """
    with st.spinner(text="Fetching wiki topics..."):
        text = st.session_state.text
        if not text:
            return

        # Strip surrounding whitespace and skip blank entries (e.g. from a
        # trailing comma) so no empty query is sent to the search API.
        cleaned = (part.strip() for part in text.split(","))
        subjects = [part for part in cleaned if part][:MAX_TOPICS]

        suggestions = st.session_state['wiki_suggestions']
        for subj in subjects:
            for title in wikipedia.search(subj, results=3):
                # Keep each title once so repeated or overlapping searches
                # do not produce duplicate buttons.
                if title not in suggestions:
                    suggestions.append(title)
        st.session_state['wiki_suggestions'] = suggestions

        show_wiki_suggestions_buttons()
def show_wiki_suggestions_buttons():
    """Render one button per Wikipedia suggestion, laid out in rows.

    At most ``BUTTON_COLUMS`` buttons go on a row. Clicking a button calls
    ``wiki_show_text`` with that suggestion's title. Nothing is rendered
    when there are no suggestions.
    """
    suggestions = st.session_state['wiki_suggestions']
    if not suggestions:
        return

    total = len(suggestions)
    per_row = min(total, BUTTON_COLUMS)
    columns = st.columns([1] * per_row)

    # The same column objects are reused for every row of buttons.
    for row_start in range(0, total, per_row):
        row_titles = suggestions[row_start:row_start + per_row]
        for col_index, (column, title) in enumerate(zip(columns, row_titles)):
            # NOTE(review): the broad suppress hides any error raised while
            # creating a button -- presumably duplicate widget keys when the
            # same title repeats in one column; confirm and narrow if possible.
            with column, contextlib.suppress(Exception):
                st.button(
                    title,
                    on_click=wiki_show_text,
                    args=(title,),
                    key=str(col_index) + title + "wiki_suggestion",
                )
def init_variables():
    """Seed the wiki-related session-state entries if they are missing.

    The presence of ``'wiki_suggestions'`` is used as the marker: when it
    already exists, no entry is touched.
    """
    if 'wiki_suggestions' in st.session_state:
        return

    # Built on every call so each list is a fresh object.
    initial_state = (
        ('wiki_text', []),
        ('topics', []),
        ('nodes', []),
        ('has_run_wiki', True),
        ('wiki_suggestions', []),
        ('html_wiki', ''),
    )
    for key, value in initial_state:
        st.session_state[key] = value
# Script body: runs top to bottom each time the app script is executed.
init_variables()
show_wiki_hub_page()

# kb chart session state
for _state_key in ('kb_chart', 'kb_text', 'error_url'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# show graph
if st.session_state.error_url:
    st.markdown(st.session_state.error_url)
elif st.session_state.kb_chart:
    with st.container():
        st.subheader("Generated KB")
        st.markdown("*You can interact with the graph and zoom.*")
        # Read the saved network HTML via a context manager so the file
        # handle is closed right away instead of being leaked.
        with open(st.session_state.kb_chart, 'r', encoding='utf-8') as html_file:
            html_source_code = html_file.read()
        components.html(html_source_code, width=700, height=700)
        st.markdown(st.session_state.kb_text)