"""
Build an editable user profile based recommender.
- Read the users json and read their paper reps and keyphrases into memory.
- Read the candidates document (first stage retrieval) and
sentence embeddings into memory (second stage retrieval).
- Display the keyphrases to users and ask them to check it.
- Use the keyphrases and sentence embeddings to compute keyphrase values.
- Display the keyphrase selection box to users for retrieval.
- Use the selected keyphrases for performing retrieval.
"""
import copy
import json
import pickle
import re
import joblib
import os
import collections
import streamlit as st
import numpy as np
from scipy.spatial import distance
from scipy import special
from sklearn.neighbors import NearestNeighbors
from sentence_transformers import SentenceTransformer, models
import torch
import ot
# import seaborn as sns
# import matplotlib
# matplotlib.use('Agg')
# import matplotlib.pyplot as plt
# plt.rcParams['figure.dpi'] = 400
# plt.rcParams.update({'axes.labelsize': 'small'})
in_path = './data'
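# Expected layout of `in_path`, inferred from the load calls below (an actual
# data dump may differ):
#   users/<uname>/embeds-<uname>-doc.npy       seed paper document embeddings
#   users/<uname>/pid2idx-<uname>-doc.json     pid -> row index into the above
#   users/<uname>/embeds-<uname>-sent.pickle   pid -> per-sentence embeddings
#   users/<uname>/seedset-<uname>-maple.json   seed set json (see below)
#   cands/embeds-mlconfs-18_23.npy             candidate document embeddings
#   cands/pid2idx-mlconfs-18_23.pickle         pid -> row index into the above
#   cands/abstract-mlconfs-18_23.pickle        pid -> paper metadata
#   cands/embeds-sent-mlconfs-18_23.pickle     pid -> per-sentence embeddings
# The seed set json is read as a dict with 'username', 'user_kps', and 'papers'
# keys; each paper dict has at least a 'title' and a sentence-tokenized 'abstract'.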
########################################
# BACKEND CODE #
########################################
def read_user(seed_json):
"""
Given the seed json for the user read the embedded
documents for the user.
:param seed_json:
:return:
"""
if 'doc_vectors_user' not in st.session_state:
uname = seed_json['username']
user_kps = seed_json['user_kps']
# Read document vectors.
doc_vectors_user = np.load(os.path.join(in_path, 'users', uname, f'embeds-{uname}-doc.npy'))
with open(os.path.join(in_path, 'users', uname, f'pid2idx-{uname}-doc.json'), 'r') as fp:
pid2idx_user = json.load(fp)
# Read sentence vectors.
pid2sent_vectors = joblib.load(os.path.join(in_path, 'users', uname, f'embeds-{uname}-sent.pickle'))
pid2sent_vectors_user = collections.OrderedDict()
for pid in sorted(pid2sent_vectors):
pid2sent_vectors_user[pid] = pid2sent_vectors[pid]
st.session_state['doc_vectors_user'] = doc_vectors_user
st.session_state['pid2idx_user'] = pid2idx_user
st.session_state['pid2sent_vectors_user'] = pid2sent_vectors_user
st.session_state['user_kps'] = user_kps
st.session_state['username'] = uname
st.session_state['seed_titles'] = []
for pd in seed_json['papers']:
norm_title = " ".join(pd['title'].lower().strip().split())
st.session_state.seed_titles.append(norm_title)
        return doc_vectors_user, pid2idx_user, pid2sent_vectors_user, user_kps
else:
return st.session_state.doc_vectors_user, st.session_state.pid2idx_user, \
st.session_state.pid2sent_vectors_user, st.session_state.user_kps
def first_stage_ranked_docs(user_doc_queries, per_doc_to_rank, total_to_rank=2000):
"""
Return a list of ranked documents given a set of queries.
:param user_doc_queries: read the cached query embeddings
:return:
"""
if 'first_stage_ret_pids' not in st.session_state:
# read the document vectors
doc_vectors = np.load(os.path.join(in_path, 'cands', 'embeds-mlconfs-18_23.npy'))
with open(os.path.join(in_path, 'cands', 'pid2idx-mlconfs-18_23.pickle'), 'rb') as fp:
pid2idx_cands = pickle.load(fp)
idx2pid_cands = dict([(v, k) for k, v in pid2idx_cands.items()])
# index the vectors into a nearest neighbors structure
neighbors = NearestNeighbors(n_neighbors=per_doc_to_rank)
neighbors.fit(doc_vectors)
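        # NearestNeighbors defaults to the minkowski metric with p=2, i.e. exact
        # Euclidean nearest-neighbor search over the candidate document vectors.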
st.session_state['neighbors'] = neighbors
st.session_state['idx2pid_cands'] = idx2pid_cands
# Get the dists for all the query docs.
nearest_dists, nearest_idxs = neighbors.kneighbors(user_doc_queries, return_distance=True)
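        # nearest_idxs has shape (n_seed_docs, per_doc_to_rank); row qi holds the
        # candidate indices for seed doc qi in increasing order of distance.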
# Get the docs
top_pids = []
uniq_top = set()
        for ranki in range(per_doc_to_rank):  # Rank-major: interleave results across the seed papers.
            for qi in range(user_doc_queries.shape[0]):
                idx = nearest_idxs[qi, ranki]
                pid = idx2pid_cands[idx]
                if pid not in uniq_top:  # Only keep unique papers; ignore repeat retrievals of the same paper.
                    top_pids.append(pid)
                    uniq_top.add(pid)
top_pids = top_pids[:total_to_rank]
st.session_state['first_stage_ret_pids'] = top_pids
return top_pids
else:
return st.session_state.first_stage_ret_pids
def read_kp_encoder(in_path):
"""
Read the kp encoder model from disk.
:param in_path: string;
:return:
"""
if 'kp_enc_model' not in st.session_state:
word_embedding_model = models.Transformer('Sheshera/lace-kp-encoder-compsci',
max_seq_length=512)
# trained_model_fname = os.path.join(in_path, 'models', 'kp_encoder_cur_best.pt')
# if torch.cuda.is_available():
# saved_model = torch.load(trained_model_fname)
# else:
# saved_model = torch.load(trained_model_fname, map_location=torch.device('cpu'))
# word_embedding_model.auto_model.load_state_dict(saved_model)
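        # Mean-pool token embeddings so each keyphrase maps to a single fixed-size
        # vector comparable with the precomputed sentence embeddings.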
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean')
kp_enc_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        st.session_state['kp_enc_model'] = kp_enc_model
        return kp_enc_model
else:
return st.session_state.kp_enc_model
def read_candidates(in_path):
"""
Read candidate papers into pandas dataframe.
:param in_path:
:return:
"""
if 'pid2abstract' not in st.session_state:
with open(os.path.join(in_path, 'cands', 'abstract-mlconfs-18_23.pickle'), 'rb') as fp:
pid2abstract = pickle.load(fp)
# read the sentence vectors
        pid2sent_vectors = joblib.load(os.path.join(in_path, 'cands', 'embeds-sent-mlconfs-18_23.pickle'))
st.session_state['pid2sent_vectors_cands'] = pid2sent_vectors
st.session_state['pid2abstract'] = pid2abstract
return pid2abstract, pid2sent_vectors
else:
return st.session_state.pid2abstract, st.session_state.pid2sent_vectors_cands
def get_kp_embeddings(profile_keyphrases):
"""
    Embed the passed profile keyphrases, caching embeddings in the session state.
:param profile_keyphrases: list(string)
:return:
"""
kp_enc_model = st.session_state['kp_enc_model']
if 'kp_vectors_user' not in st.session_state:
kp_embeddings = kp_enc_model.encode(profile_keyphrases)
kp_vectors_user = collections.OrderedDict()
for i, kp in enumerate(profile_keyphrases):
kp_vectors_user[kp] = kp_embeddings[i, :]
st.session_state['kp_vectors_user'] = kp_vectors_user
return kp_vectors_user
else:
uncached_kps = [kp for kp in profile_keyphrases if kp not in st.session_state.kp_vectors_user]
kp_embeddings = kp_enc_model.encode(uncached_kps)
for i, kp in enumerate(uncached_kps):
st.session_state.kp_vectors_user[kp] = kp_embeddings[i, :]
return st.session_state.kp_vectors_user
def generate_profile_values(profile_keyphrases):
"""
- Read sentence embeddings
- Read profile keyphrase embeddings
- Compute alignment from sentences to keyphrases
- Barycenter project the keyphrases to sentences to get kp values
- Return the kp values
:param profile_keyphrases: list(string)
:return:
"""
kp_embeddings = get_kp_embeddings(profile_keyphrases)
# Read sentence embeddings.
user_seed_sentembeds = np.vstack(list(st.session_state.pid2sent_vectors_user.values()))
# Read keyphrase embeddings.
kps_embeds_flat = []
for kp in profile_keyphrases:
kps_embeds_flat.append(kp_embeddings[kp])
kps_embeds_flat = np.vstack(kps_embeds_flat)
# Compute transport plan from sentence to keyphrases.
pair_dists = distance.cdist(user_seed_sentembeds, kps_embeds_flat, 'euclidean')
a_distr = [1 / user_seed_sentembeds.shape[0]] * user_seed_sentembeds.shape[0]
b_distr = [1 / kps_embeds_flat.shape[0]] * kps_embeds_flat.shape[0]
# tplan = ot.bregman.sinkhorn_epsilon_scaling(a_distr, b_distr, pair_dists, 0.05, numItermax=2000)
tplan = ot.partial.entropic_partial_wasserstein(a_distr, b_distr, pair_dists, 0.05, m=0.8)
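    # Entropic partial OT (reg=0.05) transports only m=0.8 of the total mass, so
    # sentences or keyphrases without a good counterpart can be left unmatched.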
    # Barycenter-project the keyphrases onto the sentences: len(profile_keyphrases) x embedding_dim
proj_kp_vectors = np.matmul(user_seed_sentembeds.T, tplan).T
norm = np.sum(tplan, axis=0)
kp_value_vectors = proj_kp_vectors/norm[:, np.newaxis]
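    # Each row of kp_value_vectors is the transport-weighted average of the seed
    # sentence embeddings aligned to that keyphrase, i.e. the keyphrase "value"
    # grounded in the user's own papers. With partial OT a keyphrase column can
    # receive near-zero mass, so this division can be numerically sensitive.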
# Return as a dict.
kp2valvectors = {}
for i, kp in enumerate(profile_keyphrases):
kp2valvectors[kp] = kp_value_vectors[i, :]
return kp2valvectors, tplan
def second_stage_ranked_docs(selected_query_kps, first_stage_pids, pid2abstract, pid2sent_reps_cand, to_rank=30):
"""
Return a list of ranked documents given a set of queries.
:param first_stage_pids: list(string)
:param pid2abstract: dict(pid: paperd)
:param query_paper_idxs: list(int);
:return:
"""
if len(selected_query_kps) < 3:
topk = len(selected_query_kps)
    else:  # Score with 20% of all profile keyphrases or 3, whichever is larger.
topk = max(int(len(st.session_state.kp2val_vectors)*0.2), 3)
query_kp_values = np.vstack([st.session_state.kp2val_vectors[kp] for kp in selected_query_kps])
pid2topkdist = dict()
pid2kp_expls = collections.defaultdict(list)
for i, pid in enumerate(first_stage_pids):
sent_reps = pid2sent_reps_cand[pid]
pair_dists = distance.cdist(query_kp_values, sent_reps)
        # Pick the topk profile keyphrases closest to the paper's sentences.
        kp_ind = np.argsort(pair_dists.min(axis=1))[:topk]
        # Record the picked keyphrases so they can be surfaced as explanations.
        for k in kp_ind:
            pid2kp_expls[pid].append(selected_query_kps[k])
sub_pair_dists = pair_dists[kp_ind, :]
# sub_kp_reps = query_kp_values[kp_ind, :]
# a_distr = special.softmax(-1*np.min(sub_pair_dists, axis=1))
# b_distr = [1 / sent_reps.shape[0]] * sent_reps.shape[0]
# tplan = ot.bregman.sinkhorn_epsilon_scaling(a_distr, b_distr, sub_pair_dists, 0.05)
# Use attention instead of OT for distance computation
tplan = special.softmax(-1 * sub_pair_dists)
wd = np.sum(sub_pair_dists * tplan)
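        # With no axis given, scipy's softmax normalizes over all kp-sentence pairs
        # jointly, so wd is the expected kp-sentence distance under the soft
        # alignment (smaller wd = better match).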
# topk_dist = 0
# for k in range(topk):
# topk_dist += pair_dists[kp_ind[k], sent_ind[k]]
# pid2kp_expls[pid].append(selected_query_kps[kp_ind[k]])
# pid2topkdist[pid] = topk_dist
pid2topkdist[pid] = wd
top_pids = sorted(pid2topkdist, key=pid2topkdist.get)
# Get the docs
retrieved_papers = collections.OrderedDict()
for pid in top_pids:
# Exclude papers from the seed set in the result set.
norm_title = " ".join(pid2abstract[pid]['title'].lower().strip().split())
        # The mlconf pid2abstract sometimes has braces in the titles - remove them.
        norm_title = re.sub(r'\{', '', norm_title)
        norm_title = re.sub(r'\}', '', norm_title)
if norm_title in st.session_state.seed_titles:
continue
retrieved_papers[pid2abstract[pid]['title']] = {
'title': pid2abstract[pid]['title'],
'kp_explanations': pid2kp_expls[pid],
'abstract': pid2abstract[pid]['abstract'],
'author_names': pid2abstract[pid]['author_names'],
'url': pid2abstract[pid]['url'],
}
if len(retrieved_papers) == to_rank:
break
return retrieved_papers
########################################
# HELPER CODE #
########################################
def parse_input_kps(unparsed_kps, initial_user_kps):
"""
Function to parse the input keyphrase string.
:return:
"""
if unparsed_kps.strip():
kps = unparsed_kps.split(',')
parsed_user_kps = []
uniq_kps = set()
for kp in kps:
kp = kp.strip()
if kp not in uniq_kps:
parsed_user_kps.append(kp)
uniq_kps.add(kp)
    else:  # If it's an empty string, use the initial kps.
parsed_user_kps = copy.copy(initial_user_kps)
return parsed_user_kps
# def plot_sent_kp_alignment(tplan, kp_labels, sent_labels):
# """
# Plot the sentence keyphrase alignment.
# :return:
# """
# fig, ax = plt.subplots()
# h = sns.heatmap(tplan.T, linewidths=.3, xticklabels=sent_labels,
# yticklabels=kp_labels, cmap='Blues')
# h.tick_params('y', labelsize=5)
# h.tick_params('x', labelsize=2)
# plt.tight_layout()
# return fig
def multiselect_title_formatter(title):
"""
Format the multi-select titles.
:param title: string
:return: string: formatted title
"""
ftitle = title.split()[:5]
return ' '.join(ftitle) + '...'
def format_abstract(paperd, to_display=3, markdown=True):
"""
Given a dict with title and abstract return
a formatted text for rendering with markdown.
:param paperd:
:param to_display:
:return:
"""
    if len(paperd['abstract']) <= to_display:
sents = ' '.join(paperd['abstract'])
else:
sents = ' '.join(paperd['abstract'][:to_display]) + '...'
try:
kp_expl = ', '.join(paperd['kp_explanations'])
except KeyError:
kp_expl = ''
    title = re.sub(r'\{', '', paperd['title'])
    title = re.sub(r'\}', '', title)
    sents = re.sub(r'\{', '', sents)
    sents = re.sub(r'\}', '', sents)
if markdown:
try:
url = paperd['url']
par = '<p><b>Title</b>: <i><a href="{:s}">{:s}</a></i><br><b>Abstract</b>: {:s}<br><i>{:s}</i></p>'. \
format(url, title, sents, kp_expl)
except KeyError:
            par = '<p><b>Title</b>: <i>{:s}</i><br><b>Abstract</b>: {:s}<br><i>{:s}</i></p>'. \
                format(title, sents, kp_expl)
else:
        par = 'Title: {:s}; Abstract: {:s}'.format(title, sents)
return par
def perp_result_json():
"""
Create a json with the results retrieved for each
iteration and the papers users choose to save at
each step.
:return:
"""
result_json = {}
# print(len(st.session_state.i_selections))
# print(len(st.session_state.i_resultps))
# print(len(st.session_state.i_savedps))
# print(st.session_state.tuning_i)
assert(len(st.session_state.i_selections) == len(st.session_state.i_resultps)
== len(st.session_state.i_savedps) == st.session_state.tuning_i)
for tuning_i, i_pselects, (_, i_savedps) in zip(range(st.session_state.tuning_i), st.session_state.i_selections,
st.session_state.i_savedps.items()):
iterdict = {
'iteration': tuning_i,
'profile_selections': copy.deepcopy(i_pselects),
'saved_papers': copy.deepcopy(list(i_savedps.items()))
}
result_json[tuning_i] = iterdict
result_json['condition'] = 'maple'
result_json['username'] = st.session_state.username
return json.dumps(result_json)
########################################
# APP CODE #
########################################
st.title('\U0001F341 Maple Paper Recommender \U0001F341')
st.markdown(
    '\U0001F341 Maple \U0001F341 uses a seed set of authored papers to make paper recommendations from ML and NLP conferences (NeurIPS, ICLR, ICML, UAI, AISTATS, ACL*, and EMNLP) from 2018 to 2023.'
'\n1. :white_check_mark: Select your username on the left\n2. :eyes: Verify keyphrases inferred for the papers and click '
'"\U0001F9D1 Generate profile \U0001F9D1"\n3. :mag: Request recommendations\n4. :repeat: Tune recommendations')
# Load candidate documents and models.
pid2abstract_cands, pid2sent_vectors_cands = read_candidates(in_path)
kp_encoding_model = read_kp_encoder(in_path)
# Initialize the session state:
if 'tuning_i' not in st.session_state:
st.session_state['tuning_i'] = 0
# Save the profile keyphrases at every run
# (run is every time the script runs, iteration is every time recs are requested)
st.session_state['run_user_kps'] = []
# Save the profile selections at each iteration
st.session_state['i_selections'] = []
# dict of dicts: tuning_i: dict(paper_title: paper)
st.session_state['i_resultps'] = {}
# dict of dicts: tuning_i: dict(paper_title: saved or not bool)
st.session_state['i_savedps'] = collections.defaultdict(dict)
# Ask user to upload a set of seed query papers.
with st.sidebar:
available_users = os.listdir(os.path.join(in_path, 'users'))
available_users.sort()
available_users = (None,) + tuple(available_users)
# uploaded_file = st.file_uploader("\U0001F331 Upload seed papers",
# type='json',
# help='Upload a json file with titles and abstracts of the papers to '
# 'include in your profile.')
# st.markdown(f"<b style='color:red;'>Select your username from the drop-down:</b>", unsafe_allow_html=True)
selected_user = st.selectbox('Select your username from the drop-down',
available_users)
if selected_user is not None:
user_papers = json.load(
open(os.path.join(in_path, 'users', selected_user, f'seedset-{selected_user}-maple.json')))
# user_papers = json.load(uploaded_file)
# Read user data.
doc_vectors_user, pid2idx_user, pid2sent_vectors_user, user_kps = read_user(user_papers)
st.session_state.run_user_kps.append(copy.copy(user_kps))
display_profile_kps = ', '.join(user_kps)
# Perform first stage retrieval.
first_stage_ret_pids = first_stage_ranked_docs(user_doc_queries=doc_vectors_user, per_doc_to_rank=500)
with st.expander("Examine seed papers"):
st.markdown(f'**Initial profile keyphrases**:')
st.markdown(display_profile_kps)
st.markdown('**Seed papers**: {:d}'.format(len(user_papers['papers'])))
for paper in user_papers['papers']:
par = format_abstract(paperd=paper, to_display=6)
st.markdown(par, unsafe_allow_html=True)
st.markdown('\u2b50 Saved papers')
if selected_user is not None:
# Create a text box where users can see their profile keyphrases.
st.subheader('\U0001F4DD Seed paper keyphrases')
with st.form('profile_kps'):
        input_kps = st.text_area(
            'Add/remove keyphrases to fix redundant, inaccurate, incomplete, or nonsensical entries:',
            display_profile_kps,
            help="Edit the profile keyphrases if they are redundant, incomplete, nonsensical, "
                 "or don't accurately describe the seed papers. You can also add keyphrases to "
                 "capture aspects of the seed papers that the keyphrases don't currently capture.",
            placeholder='If left empty, the initial profile keyphrases will be used...')
input_user_kps = parse_input_kps(unparsed_kps=input_kps, initial_user_kps=user_kps)
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
generate_profile = st.form_submit_button('\U0001F9D1 Generate profile \U0001F9D1')
if generate_profile:
prev_run_input_kps = st.session_state.run_user_kps[-1]
        if set(prev_run_input_kps) == set(input_user_kps):  # No change; reuse the cached profile.
if 'kp2val_vectors' in st.session_state: # This happens all the time except the first run.
kp2val_vectors = st.session_state.kp2val_vectors
user_tplan = st.session_state.user_tplan
else: # This happens on the first run.
with st.spinner(text="Generating profile..."):
kp2val_vectors, user_tplan = generate_profile_values(profile_keyphrases=input_user_kps)
st.session_state['kp2val_vectors'] = kp2val_vectors
st.session_state['user_tplan'] = user_tplan
else:
with st.spinner(text="Generating profile..."):
kp2val_vectors, user_tplan = generate_profile_values(profile_keyphrases=input_user_kps)
st.session_state['kp2val_vectors'] = kp2val_vectors
st.session_state['user_tplan'] = user_tplan
st.session_state.run_user_kps.append(copy.copy(input_user_kps))
# Create a multiselect dropdown
if 'kp2val_vectors' in st.session_state:
# with st.expander("Examine paper-keyphrase alignment"):
# user_tplan = st.session_state.user_tplan
# fig = plot_sent_kp_alignment(tplan=user_tplan, kp_labels=input_user_kps,
# sent_labels=range(user_tplan.shape[0]))
# st.write(fig)
st.subheader('\U0001F9D1 Profile keyphrases for ranking')
with st.form('profile_input'):
st.markdown("""
<style>
.stMultiSelect [data-baseweb=select] span{
max-width: 500px;
}
</style>
""", unsafe_allow_html=True)
profile_selections = st.multiselect(label='Include or exclude profile keyphrases to use for recommendations:',
default=input_user_kps, # Use all the values by default.
options=input_user_kps,
help='Items selected here will be used for creating your '
'recommended list')
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
generate_recs = st.form_submit_button('\U0001F9ED Recommend papers \U0001F9ED')
# Use the uploaded files to create a ranked list of items.
if generate_recs and profile_selections:
# st.write('Generating recs...')
st.session_state.tuning_i += 1
st.session_state.i_selections.append(copy.deepcopy(profile_selections))
with st.spinner(text="Recommending papers..."):
top_papers = second_stage_ranked_docs(first_stage_pids=first_stage_ret_pids,
selected_query_kps=profile_selections,
pid2abstract=pid2abstract_cands,
pid2sent_reps_cand=pid2sent_vectors_cands,
to_rank=30)
st.session_state.i_resultps[st.session_state.tuning_i] = copy.deepcopy(top_papers)
# Read off from the result cache and allow users to save some papers.
if st.session_state.tuning_i in st.session_state.i_resultps:
# st.write('Waiting for selections...')
cached_top_papers = st.session_state.i_resultps[st.session_state.tuning_i]
for paper in cached_top_papers.values():
# This statement ensures correctness for when users unselect a previously selected item.
st.session_state.i_savedps[st.session_state.tuning_i][paper['title']] = False
dcol1, dcol2 = st.columns([1, 16])
with dcol1:
save_paper = st.checkbox('\u2b50', key=paper['title'])
with dcol2:
plabel = format_abstract(paperd=paper, to_display=2, markdown=True)
st.markdown(plabel, unsafe_allow_html=True)
with st.expander('See more..'):
full_abstract = ' '.join(paper['abstract'])
st.markdown(full_abstract, unsafe_allow_html=True)
if save_paper:
st.session_state.i_savedps[st.session_state.tuning_i].update({paper['title']: True})
# Print the saved papers across iterations in the sidebar.
with st.sidebar:
with st.expander("Examine saved papers"):
# st.write('Later write..')
# st.write(st.session_state.i_savedps)
for iteration, savedps in st.session_state.i_savedps.items():
st.markdown('Iteration: {:}'.format(iteration))
for papert, saved in savedps.items():
if saved:
fpapert = '<p style=color:Gray; ">- {:}</p>'.format(papert)
st.markdown('{:}'.format(fpapert), unsafe_allow_html=True)
if st.session_state.tuning_i > 0:
        st.download_button('Download papers', perp_result_json(), mime='application/json',
                           help='Download the papers saved in the session.')
with st.expander("Copy saved papers to clipboard"):
st.write(json.loads(perp_result_json()))