Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| st.set_page_config(layout="wide") | |
| import numpy as np | |
| from abc import ABC, abstractmethod | |
| from typing import List, Dict, Any, Tuple | |
| from collections import defaultdict | |
| from tqdm import tqdm | |
| import pandas as pd | |
| from datetime import datetime, date | |
| from datasets import load_dataset, load_from_disk | |
| from collections import Counter | |
| import yaml, json, requests, sys, os, time | |
| import concurrent.futures | |
| from langchain import hub | |
| from langchain_openai import ChatOpenAI as openai_llm | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_core.runnables import RunnableConfig, RunnablePassthrough, RunnableParallel | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_community.callbacks import StreamlitCallbackHandler | |
| from langchain_community.utilities import DuckDuckGoSearchAPIWrapper | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain.agents import create_react_agent, Tool, AgentExecutor | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain.callbacks import FileCallbackHandler | |
| from langchain.callbacks.manager import CallbackManager | |
| import instructor | |
| from pydantic import BaseModel, Field | |
| from typing import List, Literal | |
| from nltk.corpus import stopwords | |
| import nltk | |
| from openai import OpenAI | |
| # import anthropic | |
| import cohere | |
| import faiss | |
| import spacy | |
| from string import punctuation | |
| import pytextrank | |
# spaCy pipeline used for keyword extraction (get_keywords) and key-phrase
# ranking (parse_doc); the "textrank" component comes from the pytextrank import.
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textrank")
# Make sure the NLTK English stopword list is present, downloading it on first run.
# NLTK reports a missing corpus with LookupError; catching only that keeps
# KeyboardInterrupt/SystemExit and unrelated failures visible.
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')
    stopwords.words('english')
| from bokeh.plotting import figure | |
| from bokeh.models import ColumnDataSource | |
| from bokeh.io import output_notebook | |
| from bokeh.palettes import Spectral5 | |
| from bokeh.transform import linear_cmap | |
# Wall-clock start; reported once the corpus columns have been cached.
ts = time.time()

# --- credentials (Streamlit secrets) ---
# anthropic_key = st.secrets["anthropic_key"]
openai_key = st.secrets["openai_key"]
cohere_key = st.secrets["cohere_key"]

# --- model clients ---
# Chat model used by the agent and the RAG chain.
gen_llm = openai_llm(
    model_name='gpt-4o-mini',
    temperature=0,
    openai_api_key=openai_key,
)
# instructor-patched client used for structured consensus evaluation.
consensus_client = instructor.patch(OpenAI(api_key=openai_key))
# Embedding access: raw OpenAI client plus the LangChain wrapper.
embed_client = OpenAI(api_key=openai_key)
embed_model = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(model=embed_model, api_key=openai_key)
# Header logo followed by a collapsible "About" panel describing the tool.
st.image('local_files/pathfinder_logo.png')

_about_text = """
Pathfinder v2.0 is a framework for searching and visualizing astronomy papers on the [arXiv](https://arxiv.org/) and [ADS](https://ui.adsabs.harvard.edu/) using the context
sensitivity from modern large language models (LLMs) to better parse patterns in paper contexts.
This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.
**π Use the sidebar to tweak the search parameters to get better results**.
### Tool summary:
- Please wait while the initial data loads and compiles, this takes about a minute initially.
This is not meant to be a replacement to existing tools like the
[ADS](https://ui.adsabs.harvard.edu/),
[arxivsorter](https://www.arxivsorter.org/), semantic search or google scholar, but rather a supplement to find papers
that otherwise might be missed during a literature survey.
It is trained on astro-ph (astrophysics of galaxies) papers up to last-year-ish mined from arxiv and supplemented with ADS metadata,
if you are interested in extending it please reach out!
Also add: feedback form, socials, literature, contact us, copyright, collaboration, etc.
The image below shows a representation of all the astro-ph.GA papers that can be explored in more detail
using the `Arxiv embedding` page. The papers tend to cluster together by similarity, and result in an
atlas that shows well studied (forests) and currently uncharted areas (water).
"""
about_panel = st.expander("About", expanded=False)
about_panel.write(_about_text)
# ---------------- get data and set up session state ---------------------------
# Load the corpus once per session: prefer the local copy under data/, otherwise
# download it from the Hugging Face Hub and save it there for next time.
if 'arxiv_corpus' not in st.session_state:
    with st.spinner('loading data...'):
        try:
            arxiv_corpus = load_from_disk('data/')
        except Exception:
            # Best-effort fallback for a missing/unreadable local copy. Catch
            # Exception rather than using a bare except so KeyboardInterrupt and
            # SystemExit still propagate.
            st.write('downloading data')
            arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
            # arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data_galaxy',split='train')
            arxiv_corpus.save_to_disk('data/')
        # FAISS index over the 'embed' column; queried via dataset.search('embed', ...).
        arxiv_corpus.add_faiss_index('embed')
        st.session_state.arxiv_corpus = arxiv_corpus
        st.toast('loaded arxiv corpus')
else:
    arxiv_corpus = st.session_state.arxiv_corpus
# Copy the frequently used corpus columns into session state once per session.
# Keys are session-state names, values are dataset column names.
if 'ids' not in st.session_state:
    _columns_by_state_key = {
        'ids': 'ads_id',
        'titles': 'title',
        'abstracts': 'abstract',
        'cites': 'cites',
        'years': 'date',
        'kws': 'keywords',
        'ads_kws': 'ads_keywords',
        'bibcode': 'bibcode',
        'umap_x': 'umap_x',
        'umap_y': 'umap_y',
    }
    for _state_key, _column in _columns_by_state_key.items():
        st.session_state[_state_key] = arxiv_corpus[_column]
    st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
#---------------------------------------------------------------
# A hack to "clear" the previous result when submitting a new prompt. This avoids
# the "previous run's text is grayed-out but visible during rerun" Streamlit behavior.
class DirtyState:
    """String markers stored in st.session_state["dirty_state"] by with_clear_container."""

    # Set when a submit arrives while DIRTY, just before forcing a rerun.
    UNHANDLED_SUBMIT = "UNHANDLED_SUBMIT"
    # Set when with_clear_container lets the caller render a result.
    DIRTY = "DIRTY"
    # Default / reset state.
    NOT_DIRTY = "NOT_DIRTY"
def get_dirty_state() -> str:
    """Return the session's dirty-state marker, or NOT_DIRTY if none was stored yet."""
    if "dirty_state" in st.session_state:
        return st.session_state["dirty_state"]
    return DirtyState.NOT_DIRTY
def set_dirty_state(state: str) -> None:
    """Store ``state`` (one of the DirtyState markers) in the session."""
    session = st.session_state
    session["dirty_state"] = state
def with_clear_container(submit_clicked: bool) -> bool:
    """Decide whether the caller should render a fresh result on this run.

    If a result is already on screen (DIRTY) and the user submits again, the
    submit is remembered as UNHANDLED_SUBMIT and the script is rerun so the old
    output is cleared first; the pending submit is then handled on that rerun.

    Args:
        submit_clicked: whether the submit button was pressed on this run.

    Returns:
        True if the caller should render a new result now, else False.
    """
    if get_dirty_state() == DirtyState.DIRTY:
        if submit_clicked:
            set_dirty_state(DirtyState.UNHANDLED_SUBMIT)
            # st.experimental_rerun was deprecated in favour of st.rerun and is
            # gone from newer Streamlit releases; prefer the stable API and fall
            # back only on versions that predate it.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        else:
            set_dirty_state(DirtyState.NOT_DIRTY)
    if submit_clicked or get_dirty_state() == DirtyState.UNHANDLED_SUBMIT:
        set_dirty_state(DirtyState.DIRTY)
        return True
    return False
# ---------------- define embedding retrieval systems --------------------------
def get_keywords(text):
    """Return candidate keywords from ``text``.

    The text is lower-cased and run through the spaCy pipeline; tokens tagged
    PROPN, ADJ or NOUN are kept unless they are spaCy stop words or occur in
    ``string.punctuation`` (a substring test).
    """
    wanted_pos = {'PROPN', 'ADJ', 'NOUN'}
    stop_words = nlp.Defaults.stop_words
    return [
        tok.text
        for tok in nlp(text.lower())
        if tok.text not in stop_words
        and tok.text not in punctuation
        and tok.pos_ in wanted_pos
    ]
def parse_doc(text, nret = 10):
    """Return the text of the first ``nret`` pytextrank phrases of ``text``."""
    ranked_phrases = nlp(text)._.phrases
    return [phrase.text for phrase in ranked_phrases[:nret]]
class EmbeddingRetrievalSystem():
    """Dense-embedding retrieval over the arXiv corpus cached in ``st.session_state``.

    Candidates come from the FAISS index on the dataset's ``embed`` column
    (``calc_faiss``); ``rank_and_filter`` multiplies their scores by optional
    keyword, recency and citation weights and keeps the best ones.

    ``rank_and_filter`` also honours two attributes that are not set here and
    may be attached by the caller: ``toggles`` (a mapping with the keys
    "Keyword weighting", "Time weighting" and "Citation weighting" that
    overrides the weight_* flags) and ``query_input_keywords`` (extra keywords
    appended to those extracted from the query). When they are absent the
    current weight_* flags and no extra keywords are used.
    """

    def __init__(self, weight_citation = False, weight_date = False, weight_keywords = False):
        self.ids = st.session_state.ids
        self.years = st.session_state.years
        self.abstract = st.session_state.abstracts
        self.client = OpenAI(api_key = openai_key)
        self.embed_model = "text-embedding-3-small"
        self.dataset = st.session_state.arxiv_corpus
        self.kws = st.session_state.kws
        self.cites = st.session_state.cites
        self.weight_citation = weight_citation
        self.weight_date = weight_date
        self.weight_keywords = weight_keywords
        # ADS id -> corpus row index.
        self.id_to_index = {paper_id: row for row, paper_id in enumerate(self.ids)}

    def parse_date(self, id):
        """Return the stored date for corpus row ``id`` (a row index, not an ADS id)."""
        return self.years[id]

    def make_embedding(self, text):
        """Embed one string with the OpenAI embeddings API and return the raw vector."""
        return self.client.embeddings.create(input = [text], model = self.embed_model).data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed several strings in one API call, returning float32 numpy vectors."""
        data = self.client.embeddings.create(input=texts, model=self.embed_model).data
        return [np.array(item.embedding, dtype=np.float32) for item in data]

    def get_query_embedding(self, query):
        """Embed the query text."""
        return self.make_embedding(query)

    def analyze_temporal_query(self, query):
        """Placeholder for detecting temporal intent in ``query``.

        Not implemented: always returns None, which ``retrieve`` treats as
        "no temporal aspect".
        """
        return None

    def calc_faiss(self, query_embedding, top_k = 100):
        """Search the dataset's FAISS index; returns ``[indices, scores]`` for the top_k rows."""
        hits = self.dataset.search('embed', query_embedding, k=top_k)
        return [hits.indices, hits.scores]

    def rank_and_filter(self, query, query_embedding, query_date, top_k = 10, return_scores=False, time_result=None):
        """Score the 1000 nearest FAISS candidates and return the best ``top_k``.

        Args:
            query: raw query text, used for keyword extraction.
            query_embedding: query vector (anything ``np.array`` accepts).
            query_date: currently unused.
            top_k: number of results to return.
            return_scores: if True return ``{row_index: score}``, else a list of row indices.
            time_result: currently unused; accepted for interface compatibility.

        Returns:
            Row indices (or an index -> score dict) in descending score order.
            Also stores the keywords used on ``self.query_kws``.
        """
        toggles = getattr(self, "toggles", None)
        if toggles is not None:
            self.weight_keywords = toggles["Keyword weighting"]
            self.weight_date = toggles["Time weighting"]
            self.weight_citation = toggles["Citation weighting"]

        topk_indices, distances = self.calc_faiss(np.array(query_embedding), top_k = 1000)
        # FAISS returns distances (smaller is better); invert them into a
        # similarity (larger is better). The floor avoids dividing by zero on
        # an exact match.
        similarities = 1 / np.maximum(distances, 1e-12)

        query_kws = get_keywords(query) + getattr(self, "query_input_keywords", [])
        self.query_kws = query_kws

        n_candidates = len(topk_indices)
        kw_weight = (self._keyword_weights(query_kws, topk_indices)
                     if self.weight_keywords else np.ones((n_candidates,)))
        age_weight = (self._age_weights(topk_indices)
                      if self.weight_date else np.ones((n_candidates,)))
        cite_weight = (self._citation_weights(topk_indices)
                       if self.weight_citation else np.ones((n_candidates,)))
        similarities = similarities * kw_weight * age_weight * cite_weight

        # Stable descending order: ties keep FAISS order, like sorted(..., reverse=True).
        order = np.argsort(-similarities, kind="stable")[:top_k]
        if return_scores:
            return {topk_indices[i]: similarities[i] for i in order}
        return [topk_indices[i] for i in order]

    def _keyword_weights(self, query_kws, topk_indices):
        """Keyword-overlap weight per candidate, scaled so the best candidate gets 1.

        Each candidate starts at 0.1 and gains 0.1 for every (query keyword,
        paper keyword) pair where the query keyword is a case-insensitive
        substring of the paper keyword.
        """
        lowered_query = [kw.lower() for kw in query_kws]
        weights = np.full((len(topk_indices),), 0.1)
        for pos, row in enumerate(topk_indices):
            paper_kws = [kw.lower() for kw in self.kws[row]]
            hits = sum(q in kw for q in lowered_query for kw in paper_kws)
            weights[pos] += 0.1 * hits
        return weights / np.amax(weights) if weights.size else weights

    def _age_weights(self, topk_indices):
        """Recency weight per candidate: logistic decay in age (years, 0.7 scale), scaled to max 1."""
        today = datetime.now().date()
        ages = np.array([(today - self.years[row]).days / 365. for row in topk_indices])
        weights = (1 + np.exp(ages / 0.7)) ** (-1)
        return weights / np.amax(weights) if weights.size else weights

    def _citation_weights(self, topk_indices):
        """Citation weight per candidate: logistic in citations capped at 300, scaled to max 1."""
        cites = np.minimum(np.array([self.cites[row] for row in topk_indices]), 300)
        weights = (1 + np.exp((300 - cites) / 42.0)) ** (-1.)
        return weights / np.amax(weights) if weights.size else weights

    def retrieve(self, query, top_k, time_result=None, query_date = None, return_scores = False):
        """Embed ``query`` and return its ``top_k`` ranked corpus rows (see rank_and_filter)."""
        query_embedding = self.get_query_embedding(query)
        if time_result is None:
            # analyze_temporal_query is currently a stub returning None; in that
            # case fall back to "no temporal aspect".
            if self.weight_date:
                time_result = self.analyze_temporal_query(query)
            if time_result is None:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
        return self.rank_and_filter(query,
                                    query_embedding,
                                    query_date,
                                    top_k,
                                    return_scores = return_scores,
                                    time_result = time_result)
class HydeRetrievalSystem(EmbeddingRetrievalSystem):
    """HyDE retrieval: rank papers by the embedding of LLM-written hypothetical abstracts.

    For each query, ``generate_n`` hypothetical abstracts are generated and
    embedded; if ``embed_query`` is set the query's own embedding is added, and
    the mean vector is ranked with the inherited ``rank_and_filter``.
    """

    def __init__(self, generation_model: str = "claude-3-haiku-20240307",
                 embedding_model: str = "text-embedding-3-small",
                 temperature: float = 0.5,
                 max_doclen: int = 500,
                 generate_n: int = 1,
                 embed_query = True,
                 conclusion = False, **kwargs):
        # Remaining kwargs (weight_citation / weight_date / weight_keywords) go to the base class.
        super().__init__(**kwargs)
        # Guard on total generated length; presumably tied to the embedding
        # model's 8191-token input limit.
        if max_doclen * generate_n > 8191:
            raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")
        self.embedding_model = embedding_model
        self.generation_model = generation_model
        # HYPERPARAMETERS
        self.temperature = temperature    # generation temperature
        self.max_doclen = max_doclen      # max tokens for generation
        self.generate_n = generate_n      # how many documents
        self.embed_query = embed_query    # embed the query vector?
        self.conclusion = conclusion      # generate conclusion as well?
        # NOTE(review): generation_model, temperature and conclusion are stored
        # but not used for generation: this client is fixed to gpt-4o-mini at
        # temperature 0.
        self.generation_client = openai_llm(temperature=0, model_name='gpt-4o-mini', openai_api_key=openai_key)

    def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
        """Retrieve ``top_k`` rows for ``query`` using HyDE embeddings.

        Shows the generated abstracts in a Streamlit expander. Despite the
        annotation, returns what rank_and_filter returns: a list of row
        indices, or ``{row_index: score}`` when ``return_scores`` is True.
        """
        if time_result is None:
            # analyze_temporal_query (inherited) is currently a stub returning
            # None; in that case fall back to "no temporal aspect".
            if self.weight_date:
                time_result = self.analyze_temporal_query(query)
            if time_result is None:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
        docs = self.generate_docs(query)
        st.expander('Abstract generated with hyde', expanded=False).write(docs)
        doc_embeddings = self.embed_docs(docs)
        if self.embed_query:
            doc_embeddings.append(self.embed_docs([query])[0])
        embedding = np.mean(np.array(doc_embeddings), axis = 0)
        return self.rank_and_filter(query, embedding, query_date=None, top_k = top_k, return_scores = return_scores, time_result = time_result)

    def generate_doc(self, query: str):
        """Generate one hypothetical expert abstract answering ``query``."""
        prompt = """You are an expert astronomer. Given a scientific query, generate the abstract of an expert-level research paper
that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
        messages = [("system", prompt), ("human", query)]
        return self.generation_client.invoke(messages).content

    def generate_docs(self, query: str):
        """Generate ``self.generate_n`` hypothetical abstracts, one LLM call each, sequentially."""
        return [self.generate_doc(query) for _ in range(self.generate_n)]

    def embed_docs(self, docs: List[str]):
        """Embed the given texts in one batch call."""
        return self.embed_batch(docs)
class HydeCohereRetrievalSystem(HydeRetrievalSystem):
    """HyDE retrieval followed by Cohere re-ranking of the candidate abstracts."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.cohere_key = cohere_key
        self.cohere_client = cohere.Client(self.cohere_key)

    def retrieve(self, query: str,
                 top_k: int = 10,
                 rerank_top_k: int = 250,
                 return_scores = False, time_result = None,
                 reweight = False) -> List[Tuple[str, str, float]]:
        """Take ``rerank_top_k`` HyDE candidates, re-rank with Cohere, keep ``top_k``.

        Returns ``{row_index: relevance_score}`` if ``return_scores`` else a
        list of row indices, in the order Cohere returns them. (The return
        annotation is kept for compatibility only.)
        """
        if time_result is None:
            # analyze_temporal_query (inherited) is currently a stub returning
            # None; in that case fall back to "no temporal aspect".
            if self.weight_date:
                time_result = self.analyze_temporal_query(query)
            if time_result is None:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
        top_results = super().retrieve(query, top_k = rerank_top_k, time_result = time_result)
        docs_for_rerank = [self.abstract[i] for i in top_results]
        if not docs_for_rerank:
            return {} if return_scores else []
        reranked_results = self.cohere_client.rerank(
            query=query,
            documents=docs_for_rerank,
            model='rerank-english-v3.0',
            top_n=min(top_k, len(docs_for_rerank)),
        )
        final_results = [
            [top_results[result.index], "", float(result.relevance_score)]
            for result in reranked_results.results
        ]
        if reweight:
            # The date/citation filters are not created in the base __init__
            # (their setup is commented out), so only apply them if a caller
            # attached them; otherwise reweight=True raised AttributeError.
            date_filter = getattr(self, "date_filter", None)
            if time_result['has_temporal_aspect'] and date_filter is not None:
                final_results = date_filter.filter(final_results, time_score = time_result['expected_recency_weight'])
            citation_filter = getattr(self, "citation_filter", None)
            if self.weight_citation and citation_filter is not None:
                # NOTE(review): return value ignored, as before; presumably filters in place.
                citation_filter.filter(final_results)
        if return_scores:
            return {result[0]: result[2] for result in final_results}
        return [doc[0] for doc in final_results]

    def embed_docs(self, docs: List[str]):
        """Embed the given texts in one batch call."""
        return self.embed_batch(docs)
# ----------------------------------------------------------------
# Build the retrieval system once per session and reuse it on reruns.
if 'ec' in st.session_state:
    ec = st.session_state.ec
else:
    ec = HydeCohereRetrievalSystem(weight_keywords=True)
    st.session_state.ec = ec
    st.toast('loaded retrieval system')
def get_topk(query, top_k):
    """Query the session's retrieval system for ``top_k`` papers, with scores."""
    print('running retrieval')
    retriever = st.session_state.ec
    return retriever.retrieve(query, top_k, return_scores=True)
def Library(query, top_k = 7):
    """Return the ``top_k`` retrieved papers for ``query`` as one text block.

    Used as the agent's "Library" tool. Each entry reads
    "Paper N: (published in YYYY) <title>\\n<abstract>\\n\\n", where YYYY is the
    first four characters of the bibcode.
    """
    state = st.session_state
    entries = [
        'Paper %.0f:' % rank
        + ' (published in ' + state.bibcode[i][0:4] + ') '
        + state.titles[i] + '\n'
        + state.abstracts[i] + '\n\n'
        for rank, i in enumerate(get_topk(query, top_k=top_k), start=1)
    ]
    return ''.join(entries)
def Library2(query, top_k = 7):
    """Retrieve ``top_k`` papers for ``query``.

    Returns:
        (abstracts, bibcodes, rs): parallel lists of abstracts and bibcodes in
        retrieval order, plus the raw retrieval result from get_topk.
    """
    rs = get_topk(query, top_k = top_k)
    state = st.session_state
    absts = [state.abstracts[i] for i in rs]
    fnames = [state.bibcode[i] for i in rs]
    return absts, fnames, rs
def get_paper_df(ids):
    """Build a results table from ``ids`` (corpus row index -> relevance score).

    One row per paper, in iteration order of ``ids``, with title, relevance,
    year (first four bibcode characters), ADS link, citations and ADS keywords.
    """
    state = st.session_state
    rows = list(ids)
    return pd.DataFrame({
        'Title': [state.titles[i] for i in rows],
        'Relevance': [ids[i] for i in rows],
        'Year': [state.bibcode[i][0:4] for i in rows],
        'ADS Link': ['https://ui.adsabs.harvard.edu/abs/' + state.bibcode[i] + '/abstract' for i in rows],
        'Citations': [state.cites[i] for i in rows],
        'Keywords': [state.ads_kws[i] for i in rows],
    })
| # def find_outliers(inp_simids, arxiv_cutoff_distance = 0.8): | |
| # | |
| # inp_simids = np.array(inp_simids) | |
| # | |
| # # Calculate the centroid for each point, excluding itself | |
| # orange_black_points = st.session_state.embed[inp_simids] | |
| # | |
| # topk_dists = [] | |
| # for i, point in enumerate(orange_black_points): | |
| # # Exclude the current point | |
| # other_points = np.delete(orange_black_points, i, axis=0) | |
| # # Calculate centroid of other points | |
| # centroid = np.mean(other_points, axis=0) | |
| # # Calculate distance from the point to this centroid | |
| # dist = np.sqrt(np.sum((point - centroid)**2)) | |
| # topk_dists.append(dist) | |
| # | |
| # topk_dists = np.array(topk_dists) | |
| # | |
| # # Separate distances for orange and black points | |
| # orange_distances = topk_dists[:len(inp_simids)] | |
| # black_distances = topk_dists[len(inp_simids):] | |
| # | |
| # # Calculate the median of distances | |
| # orange_black_distances = topk_dists | |
| # median_topk_distance = np.median(orange_black_distances) | |
| # | |
| # # def get_sims_and_dists(inp_data): | |
| # | |
| # # all_sims, all_dists = [], [] | |
| # | |
| # # np.random.seed(12) | |
| # # rand_indices = np.random.choice(inp_data.shape[0], size=return_n, replace=False) | |
| # | |
| # # for j in tqdm(range(len(rand_indices))): | |
| # | |
| # # i = rand_indices[j] | |
| # # inferred_vector = inp_data[i,:] | |
| # # sims, dists = find_closest_dists(i, inp_data, return_n + 1) | |
| # # all_sims.append(sims[1:]) | |
| # # all_dists.append(dists[1:]) | |
| # | |
| # # return np.array(all_sims), np.array(all_dists) | |
| # | |
| # # # Identify papers with distances greater than the 95th percentile | |
| # # _, all_dists = get_sims_and_dists(arxiv_ada_embeddings) | |
| # # arxiv_cutoff_distance = find_cutoff_dist(all_dists) | |
| # # hardcoding for now | |
| # outlier_indices = inp_simids[np.where(orange_black_distances > arxiv_cutoff_distance)[0]] | |
| # # outlier_titles = [titles[i] for i in outlier_indices] | |
| # | |
| # return outlier_indices #, outlier_titles | |
def create_embedding_plot(rs):
    """
    Build a Bokeh UMAP scatter of the whole corpus with the retrieved papers highlighted.

    Args:
        rs: iterable of corpus row indices to highlight (e.g. the dict from get_topk).

    Returns:
        A bokeh figure; highlighted points are drawn larger and in a different color.
    """
    pltsource = ColumnDataSource(data=dict(
        x=st.session_state.umap_x,
        y=st.session_state.umap_y,
        title=st.session_state.titles,
        link=st.session_state.bibcode,
    ))
    rsflag = np.zeros((len(st.session_state.ids),))
    # dtype=int so an empty result still gives a valid (empty) index array;
    # np.array([]) is float and raises IndexError when used for indexing.
    hit_rows = np.fromiter(rs, dtype=int)
    rsflag[hit_rows] = 1
    pltsource.data['colors'] = rsflag * 0.8 + 0.1
    pltsource.data['sizes'] = (rsflag + 1)**5 / 100
    TOOLTIPS = """
<div style="width:300px;">
ID: $index
($x, $y)
@title <br>
@link <br> <br>
</div>
"""
    mapper = linear_cmap(field_name="colors", palette=Spectral5, low=0., high=1.)
    p = figure(width=700, height=900, tooltips=TOOLTIPS, x_range=(0, 20), y_range=(-4.2,18),
               title="UMAP projection of embeddings for the astro-ph corpus")
    p.axis.visible=False
    p.grid.visible=False
    p.outline_line_alpha = 0.
    p.circle('x', 'y', radius='sizes', source=pltsource, alpha=0.3, fill_color=mapper, fill_alpha='colors', line_color="lightgrey",line_alpha=0.1)
    return p
def extract_keywords(question, ec):
    """Placeholder keyword extractor: ignores its arguments and returns fixed dummy keywords."""
    placeholder_keywords = ('keyword1', 'keyword2', 'keyword3')
    return list(placeholder_keywords)
# Function to estimate consensus (replace with actual implementation)
def estimate_consensus():
    """Placeholder consensus estimate: always returns the fixed value 0.75."""
    placeholder_score = 0.75
    return placeholder_score
def run_agent_qa(query, top_k):
    """Answer ``query`` with a ReAct agent that can use the paper Library and web search.

    The tools and the agent are built once and cached in st.session_state.
    The agent's intermediate trace is written to "agent_trace.txt", which is
    replaced on every call.

    Args:
        query: the user's question.
        top_k: unused; the Library tool uses its own default.

    Returns:
        Whatever AgentExecutor.invoke returns for {"input": query}.
    """
    # define tools (once per session)
    if 'tools' not in st.session_state:
        search = DuckDuckGoSearchAPIWrapper()
        st.session_state.tools = [
            Tool(
                name="Library",
                func=Library,
                description="A source of information pertinent to your question. Do not answer a question without consulting this!"
            ),
            Tool(
                name="Search",
                func=search.run,
                description="useful for when you need to look up knowledge about common topics or current events",
            ),
        ]
    tools = st.session_state.tools

    if 'agent' not in st.session_state:
        # for another question type:
        # First, find the quotes from the document that are most relevant to answering the question, and then print them in numbered order.
        # Quotes should be relatively short. If there are no relevant quotes, write "No relevant quotes" instead.
        template = """You are an expert astronomer and cosmologist.
Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
If you can not come up with an answer, say you do not know.
Try to break the question down into smaller steps and solve it in a logical manner.
You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question. provide information about how you arrived at the answer, and any nuances or uncertainties the reader should be aware of
Begin! Remember to speak in a pedagogical and factual manner."
Question: {input}
Thought:{agent_scratchpad}"""
        # Fetch the base ReAct prompt from the hub only when the agent must be
        # built; the cached agent already carries its prompt.
        prompt = hub.pull("hwchase17/react")
        prompt.template = template
        st.session_state.agent = create_react_agent(llm=gen_llm, tools=tools, prompt=prompt)

    # path to write intermediate trace to
    file_path = "agent_trace.txt"
    try:
        os.remove(file_path)
    except OSError:
        # Best effort: a missing file is the normal first-run case.
        pass
    file_handler = FileCallbackHandler(file_path)
    # Rebuild the (lightweight) executor on every call so it logs through the
    # handler opened for *this* call. A cached executor would keep the first
    # call's handler, whose file is deleted above on later calls.
    agent_executor = AgentExecutor(agent=st.session_state.agent, tools=tools, verbose=True,
                                   handle_parsing_errors=True, callbacks=CallbackManager([file_handler]))
    st.session_state.agent_executor = agent_executor
    return agent_executor.invoke({"input": query,})
def make_rag_qa_answer(query, top_k = 10):
    """Answer ``query`` with a one-off RAG chain over the ``top_k`` retrieved abstracts.

    The abstracts are written to absts/<bibcode>.txt and loaded back with
    TextLoader (so each document's metadata records that path). They are then
    split into 150-character chunks and indexed in a temporary Chroma
    collection, whose 6 most similar chunks feed the answer prompt. The
    temporary files and the collection are removed again, also on failure.

    Returns:
        (rag_answer, rs): the chain output dict with "context", "question" and
        "answer" keys, and the raw retrieval result from Library2.
    """
    absts, fhdrs, rs = Library2(query, top_k = top_k)

    os.makedirs('absts', exist_ok=True)
    paths = ["absts/" + fhdr + ".txt" for fhdr in fhdrs]
    try:
        # Explicit UTF-8 on both sides: abstracts can contain non-ASCII text.
        for abst, path in zip(absts, paths):
            with open(path, "w", encoding="utf-8") as text_file:
                text_file.write(abst)
        documents = [TextLoader(path, encoding="utf-8").load()[0] for path in paths]
    finally:
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Duplicate bibcodes share one file, which is already gone.
                pass

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=50, add_start_index=True)
    splits = text_splitter.split_documents(documents)
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, collection_name='retdoc4')
    try:
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})
        template = """You are an expert astronomer and cosmologist.
Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
If you can not come up with an answer, say you do not know.
Try to break the question down into smaller steps and solve it in a logical manner.
Provide information about how you arrived at the answer, and any nuances or uncertainties the reader should be aware of.
Begin! Remember to speak in a pedagogical and factual manner."
Relevant documents:{context}
Question: {question}
Answer:"""
        prompt = PromptTemplate.from_template(template)

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain_from_docs = (
            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
            | prompt
            | gen_llm
            | StrOutputParser()
        )
        rag_chain_with_source = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain_from_docs)
        rag_answer = rag_chain_with_source.invoke(query)
    finally:
        # The collection name is fixed, so always drop it; a failed call must
        # not leave its chunks behind for the next question.
        vectorstore.delete_collection()
    return rag_answer, rs
def guess_question_type(query: str):
    """Ask the session's generation LLM to classify ``query`` into one of nine question types.

    Returns the model's raw reply text, which the prompt asks to contain a
    <categorization> block with a category and an explanation.
    """
    categorization_prompt = """You are an expert astrophysicist and computer scientist specializing in linguistics and semantics. Your task is to categorize a given query into one of the following categories:
1. Summarization
2. Single-paper factual
3. Multi-paper factual
4. Named entity recognition
5. Jargon-specific questions / overloaded words
6. Time-sensitive
7. Consensus evaluation
8. What-ifs and counterfactuals
9. Compositional
Analyze the query carefully, considering its content, structure, and implications. Then, determine which of the above categories best fits the query.
In your analysis, consider the following:
- Does the query ask for a well-known datapoint or mechanism?
- Can it be answered by a single paper or does it require multiple sources?
- Does it involve proper nouns or specific scientific terms?
- Is it time-dependent or likely to change in the near future?
- Does it require evaluating consensus across multiple sources?
- Is it a hypothetical or counterfactual question?
- Does it need to be broken down into sub-queries (i.e. compositional)?
After your analysis, categorize the query into one of the nine categories listed above.
Provide a brief explanation for your categorization, highlighting the key aspects of the query that led to your decision.
Present your final answer in the following format:
<categorization>
Category: [Selected category]
Explanation: [Your explanation for the categorization]
</categorization>"""
    llm = st.session_state.ec.generation_client
    reply = llm.invoke([("system", categorization_prompt), ("human", query)])
    return reply.content
# Structured-output schema for evaluate_overall_consensus: it is passed as
# response_model to the instructor-patched OpenAI client, which parses the
# model's reply into this class.
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a class docstring into the JSON-schema description, which
# would presumably change what instructor sends to the model.
class OverallConsensusEvaluation(BaseModel):
    # One of seven fixed consensus levels; `...` makes the field required.
    consensus: Literal["Strong Agreement", "Moderate Agreement", "Weak Agreement", "No Clear Consensus", "Weak Disagreement", "Moderate Disagreement", "Strong Disagreement"] = Field(
        ...,
        description="The overall level of consensus between the query and the abstracts"
    )
    # Free-text justification of the chosen consensus level (required).
    explanation: str = Field(
        ...,
        description="A detailed explanation of the consensus evaluation"
    )
    # Required; pydantic validation rejects values outside [0, 1] (ge/le bounds).
    relevance_score: float = Field(
        ...,
        description="A score from 0 to 1 indicating how relevant the abstracts are to the query overall",
        ge=0,
        le=1
    )
def evaluate_overall_consensus(
    query: str,
    abstracts: List[str],
    model: str = "gpt-4o-mini",
    temperature: float = 0,
) -> OverallConsensusEvaluation:
    """
    Evaluate the overall consensus of the abstracts in relation to the query in a single LLM call.

    All abstracts are placed in one prompt and sent through the module-level
    ``consensus_client`` with ``response_model=OverallConsensusEvaluation``, so the
    reply comes back as that structured object.

    Args:
        query: The user's question.
        abstracts: Abstract texts to compare against the query.
        model: Chat model name passed to the client. Defaults to "gpt-4o-mini".
        temperature: Sampling temperature passed to the client. Defaults to 0.

    Returns:
        The OverallConsensusEvaluation produced for ``query`` and ``abstracts``.
        If ``abstracts`` is empty, no request is made and a "No Clear Consensus"
        result with a relevance score of 0 is returned.
    """
    if not abstracts:
        # Nothing to compare against: skip the API call, which could only
        # produce an invented evaluation.
        return OverallConsensusEvaluation(
            consensus="No Clear Consensus",
            explanation="No abstracts were retrieved, so consensus could not be evaluated.",
            relevance_score=0.0,
        )

    # One abstract per line so the model can tell where each abstract ends.
    # Built outside the f-string to keep the template readable.
    numbered_abstracts = "\n".join(
        f"Abstract {i}: {abstract}" for i, abstract in enumerate(abstracts, start=1)
    )
    prompt = f"""
Query: {query}
You will be provided with {len(abstracts)} scientific abstracts. Your task is to:
1. Evaluate the overall consensus between the query and the abstracts.
2. Provide a detailed explanation of your consensus evaluation.
3. Assign an overall relevance score from 0 to 1, where 0 means completely irrelevant and 1 means highly relevant.
For the consensus evaluation, use one of the following levels:
Strong Agreement, Moderate Agreement, Weak Agreement, No Clear Consensus, Weak Disagreement, Moderate Disagreement, Strong Disagreement
Here are the abstracts:
{numbered_abstracts}
Provide your evaluation in a structured format.
"""
    system_prompt = """You are an assistant with expertise in astrophysics for question-answering tasks.
Evaluate the overall consensus of the retrieved scientific abstracts in relation to a given query.
If you don't know the answer, just say that you don't know.
Use six sentences maximum and keep the answer concise."""

    return consensus_client.chat.completions.create(
        model=model,
        response_model=OverallConsensusEvaluation,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
    )
# Streamlit app


def _build_retrieval_system(method):
    """Instantiate the retrieval system chosen in the sidebar and store it as ``st.session_state.ec``.

    ``method`` is one of the "Retrieval method:" radio labels offered in ``main``.
    """
    retrieval_classes = {
        "Semantic search": EmbeddingRetrievalSystem,
        "Semantic search + HyDE": HydeRetrievalSystem,
        "Semantic search + HyDE + CoHERE": HydeCohereRetrievalSystem,
    }
    with st.spinner(f"set retrieval method to {method}"):
        st.session_state.ec = retrieval_classes[method](weight_keywords=True)


def _extract_categorization(text):
    """Return the part of an LLM reply between ``<categorization>`` tags, formatted for ``st.markdown``.

    If the tags are present, text outside them is dropped; otherwise the reply is
    used as is. Newlines become Markdown hard line breaks (two trailing spaces) so
    "Category:" and "Explanation:" render on separate lines.
    """
    if '<categorization>' in text:
        text = text.split('<categorization>')[1]
    if '</categorization>' in text:
        text = text.split('</categorization>')[0]
    return text.strip().replace('\n', '  \n')


def _read_agent_trace(file_path="agent_trace.txt"):
    """Return the contents of the agent trace log at ``file_path``.

    If the file does not exist, a placeholder message is returned instead of
    raising, so the papers and consensus sections still render.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as trace_file:
            return trace_file.read()
    except FileNotFoundError:
        return "No agent trace was recorded."


def _render_consensus(query, rs):
    """Evaluate and display the consensus of the retrieved abstracts with respect to ``query``.

    ``rs`` holds the indices/keys of the retrieved papers in ``st.session_state.abstracts``.
    """
    consensus_answer = evaluate_overall_consensus(query, [st.session_state.abstracts[i] for i in rs])
    st.subheader("Consensus: " + consensus_answer.consensus)
    st.markdown(consensus_answer.explanation)
    st.markdown(f'Relevance of retrieved papers to answer: {consensus_answer.relevance_score:.1f}')


def main():
    """Build the Streamlit page.

    The page has sidebar search settings and a question box. After Submit it shows
    the answer, the retrieved papers, the embedding plot, a question-type suggestion
    and a consensus evaluation.
    """
    # Sidebar (Inputs)
    st.sidebar.header("Fine-tune the search")
    top_k = st.sidebar.slider("Number of papers to retrieve:", 3, 30, 10)
    extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
    st.sidebar.subheader("Toggles")
    toggle_a = st.sidebar.toggle("Weight by keywords", value=False)
    toggle_b = st.sidebar.toggle("Weight by date", value=False)
    toggle_c = st.sidebar.toggle("Weight by citations", value=False)
    method = st.sidebar.radio(
        "Retrieval method:",
        ["Semantic search", "Semantic search + HyDE", "Semantic search + HyDE + CoHERE"],
        index=2,
    )
    # NOTE(review): Streamlit re-executes this script on every widget interaction,
    # so the retrieval system is rebuilt on each rerun. Consider reusing it while
    # `method` is unchanged if construction turns out to be expensive.
    _build_retrieval_system(method)

    # Sidebar label -> generation mode checked when answering below.
    gen_modes = {"Basic RAG": "rag", "ReAct Agent": "agent"}
    method2 = st.sidebar.radio("Generation complexity:", list(gen_modes))
    st.session_state.gen_method = gen_modes[method2]
    question_type = st.sidebar.selectbox("Select question type:", ["Single paper", "Multi-paper", "Summary"])
    store_output = st.sidebar.button("Save output")

    # Main page (Outputs)
    query = st.text_input("Ask me anything:")
    submit_button = st.button("Submit")

    if submit_button and not query.strip():
        # Don't run retrieval and LLM calls for a blank question.
        st.warning("Please enter a question before submitting.")
    elif submit_button:
        search_text_list = ['rooting around in the paper pile...', 'looking for clarity...', 'scanning the event horizon...', 'peering into the abyss...', 'potatoes power this ongoing search...']
        with st.spinner(search_text_list[np.random.choice(len(search_text_list))]):
            # Process inputs; drop blank entries such as those left by a trailing comma.
            keywords = [kw.strip() for kw in extra_keywords.split(',') if kw.strip()]
            toggles = {'Keyword weighting': toggle_a, 'Time weighting': toggle_b, 'Citation weighting': toggle_c}

            # Hand the sidebar settings to the retrieval system before generating.
            ec = st.session_state.ec
            ec.query_input_keywords = keywords
            ec.toggles = toggles
            ec.question_type = question_type
            ec.rag_method = method
            ec.gen_method = method2

            # Generate and display the answer.
            if st.session_state.gen_method == 'agent':
                answer = run_agent_qa(query, top_k)
                rs = get_topk(query, top_k)
                st.write(answer["output"])
                st.expander('Intermediate steps', expanded=False).write(_read_agent_trace())
            else:  # 'rag'
                answer, rs = make_rag_qa_answer(query, top_k)
                st.write(answer['answer'])

            papers_df = get_paper_df(rs)
            embedding_plot = create_embedding_plot(rs)
            triggered_keywords = st.session_state.ec.query_kws
            # Each keyword as inline code. Fall back to "none" so an empty list
            # doesn't render a stray pair of backticks.
            keyword_text = ", ".join(f"`{kw}`" for kw in triggered_keywords) or "none"
            st.write('**Triggered keywords:** ' + keyword_text)

            with st.expander("Relevant papers", expanded=True):
                st.data_editor(
                    papers_df,
                    column_config={'ADS Link': st.column_config.LinkColumn(display_text='https://ui.adsabs.harvard.edu/abs/(.*?)/abstract')},
                )
            st.bokeh_chart(embedding_plot)

            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Question type suggestion")
                st.markdown(_extract_categorization(guess_question_type(query)))
            with col2:
                _render_consensus(query, rs)
    else:
        st.info("Use the sidebar to tweak the search parameters to get better results.")

    if store_output:
        # NOTE(review): nothing is persisted here; the toast is the only effect of "Save output".
        st.toast("Output stored successfully!")


if __name__ == "__main__":
    main()