import time

s2 = time.time()

import numpy as np
import streamlit as st
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple
from collections import defaultdict, Counter
# import wandb
from tqdm import tqdm
from datetime import datetime, date
import pickle
from datasets import load_dataset
import os
from string import punctuation
from nltk.corpus import stopwords
import nltk
from openai import OpenAI
import anthropic

try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')
    stopwords.words('english')

# NOTE: KeywordFilter.parse_doc / get_propn below rely on a spaCy pipeline (`nlp`)
# with pytextrank's `doc._.phrases` extension; this setup (and the model name)
# is an assumption, in case it is not provided elsewhere in the app.
import spacy
import pytextrank  # noqa: F401 -- registers the "textrank" pipeline factory

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textrank")

openai_key = st.secrets['openai_key']
anthropic_key = st.secrets['anthropic_key']

def load_astro_meta():
    print('load astro meta')
    return load_dataset('arxiv_corpus/', split="train")


def load_index_mapping(index_mapping_path):
    print("Loading index mapping...")
    with open(index_mapping_path, 'rb') as f:
        temp = pickle.load(f)
    return temp


def load_embeddings(embeddings_path):
    print("Loading embeddings...")
    return np.load(embeddings_path)


def load_metadata(meta_path):
    print("Loading metadata...")
    with open(meta_path, 'r') as f:
        metadata = json.load(f)
    return metadata


# @st.cache_data
def load_umapcoords(umap_path):
    print('loading umap coords')
    with open(umap_path, "rb") as fp:  # unpickle precomputed UMAP coordinates
        umap = pickle.load(fp)
    return umap

class EmbeddingClient:
    def __init__(self, client: OpenAI, model: str = "text-embedding-3-small"):
        self.client = client
        self.model = model

    def embed(self, text: str) -> np.ndarray:
        embedding = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
        return np.array(embedding, dtype=np.float32)

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        embeddings = self.client.embeddings.create(input=texts, model=self.model).data
        return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
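
# Minimal usage sketch for EmbeddingClient (illustrative only; assumes `openai_key`
# above holds a valid OpenAI API key):
#
#   client = EmbeddingClient(OpenAI(api_key=openai_key))
#   vec = client.embed("dark matter halo profiles")                      # np.float32 vector
#   vecs = client.embed_batch(["galaxy rotation curves", "JWST photometry"])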

class RetrievalSystem(ABC):
    @abstractmethod
    def retrieve(self, query: str, arxiv_id: str, top_k: int = 100) -> List[str]:
        pass

    def parse_date(self, arxiv_id: str) -> date:
        if arxiv_id is None:
            return date.today()
        if arxiv_id.startswith('astro-ph'):
            arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
        try:
            year = int("20" + arxiv_id[:2])
            month = int(arxiv_id[2:4])
        except (ValueError, IndexError):
            year = 2023
            month = 1
        return date(year, month, 1)
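
# Example of the parse_date heuristic above: a new-style identifier such as
# "2304.01234" (hypothetical) yields date(2023, 4, 1); anything that cannot be
# parsed falls back to date(2023, 1, 1).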

class EmbeddingRetrievalSystem(RetrievalSystem):
    def __init__(self, embeddings_path: str = "local_files/embeddings_matrix.npy",
                 documents_path: str = "local_files/documents.pkl",
                 index_mapping_path: str = "local_files/index_mapping.pkl",
                 metadata_path: str = "local_files/metadata.json",
                 weight_citation=False, weight_date=False, weight_keywords=False):
        self.embeddings_path = embeddings_path
        self.documents_path = documents_path
        self.index_mapping_path = index_mapping_path
        self.metadata_path = metadata_path
        self.weight_citation = weight_citation
        self.weight_date = weight_date
        self.weight_keywords = weight_keywords

        self.embeddings = None
        self.documents = None
        self.index_mapping = None
        self.metadata = None
        self.document_dates = []

        self.load_data()
        self.init_filters()

        # config = yaml.safe_load(open('../config.yaml', 'r'))
        self.client = EmbeddingClient(OpenAI(api_key=openai_key))
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)

    def generate_metadata(self):
        astro_meta = load_astro_meta()
        # dataset = load_dataset('arxiv_corpus/')
        keys = list(astro_meta[0].keys())
        keys.remove('abstract')
        keys.remove('introduction')
        keys.remove('conclusions')

        self.metadata = {}
        for paper in astro_meta:
            id_str = paper['arxiv_id']
            self.metadata[id_str] = {key: paper[key] for key in keys}

        with open(self.metadata_path, 'w') as f:
            json.dump(self.metadata, f)
        st.markdown("Wrote metadata to {}".format(self.metadata_path))

    def load_data(self):
        # print("Loading embeddings...")
        # self.embeddings = np.load(self.embeddings_path)
        self.embeddings = load_embeddings(self.embeddings_path)
        st.sidebar.success("Loaded embeddings")

        # with open(self.index_mapping_path, 'rb') as f:
        #     self.index_mapping = pickle.load(f)
        self.index_mapping = load_index_mapping(self.index_mapping_path)
        st.sidebar.success("Loaded index mapping")

        # print("Loading documents...")
        # with open(self.documents_path, 'rb') as f:
        #     self.documents = pickle.load(f)
        dataset = load_astro_meta()
        st.sidebar.success("Loaded documents")

        print("Processing document dates...")
        # self.document_dates = {doc.id: self.parse_date(doc.arxiv_id) for doc in self.documents}
        aids = dataset['arxiv_id']
        adsids = dataset['id']
        self.document_dates = {adsids[i]: self.parse_date(aids[i]) for i in range(len(aids))}

        if os.path.exists(self.metadata_path):
            self.metadata = load_metadata(self.metadata_path)
            print("Loaded metadata.")
        else:
            print("Could not find path; generating metadata.")
            self.generate_metadata()

        print("Data loaded successfully.")

    def init_filters(self):
        print("Loading filters...")
        self.citation_filter = CitationFilter(metadata=self.metadata)
        self.date_filter = DateFilter(document_dates=self.document_dates)
        self.keyword_filter = KeywordFilter(index_path="local_files/keyword_index.json",
                                            metadata=self.metadata, remove_capitals=True)

    def retrieve(self, query: str, arxiv_id: str = None, top_k: int = 10, return_scores=False,
                 time_result=None) -> List[Tuple[str, str, float]]:
        query_date = self.parse_date(arxiv_id)
        query_embedding = self.get_query_embedding(query)

        # Judge time relevance
        if time_result is None:
            if self.weight_date:
                time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
            else:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None,
                               'expected_recency_weight': None}

        top_results = self.rank_and_filter(query, query_embedding, query_date, top_k,
                                           return_scores=return_scores, time_result=time_result)
        return top_results

    def rank_and_filter(self, query, query_embedding: np.ndarray, query_date, top_k: int = 10,
                        return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
        # Calculate similarities (dot product; equivalent to cosine similarity when the
        # stored embeddings and the query embedding are unit-normalized)
        similarities = np.dot(self.embeddings, query_embedding)

        # Filter and rank results
        if self.weight_keywords:
            keyword_matches = self.keyword_filter.filter(query)

        results = []
        for doc_id, mappings in self.index_mapping.items():
            if not self.weight_keywords or doc_id in keyword_matches:
                abstract_sim = similarities[mappings['abstract']] if 'abstract' in mappings else -np.inf
                conclusions_sim = similarities[mappings['conclusions']] if 'conclusions' in mappings else -np.inf
                if abstract_sim > conclusions_sim:
                    results.append([doc_id, "abstract", abstract_sim])
                else:
                    results.append([doc_id, "conclusions", conclusions_sim])

        # Sort, weight, and take the top-k results
        if time_result['has_temporal_aspect']:
            filtered_results = self.date_filter.filter(results,
                                                       boolean_date=time_result['expected_year_filter'],
                                                       time_score=time_result['expected_recency_weight'],
                                                       max_date=query_date)
        else:
            filtered_results = self.date_filter.filter(results, max_date=query_date)

        if self.weight_citation:
            self.citation_filter.filter(filtered_results)

        top_results = sorted(filtered_results, key=lambda x: x[2], reverse=True)[:top_k]

        if return_scores:
            return {doc[0]: doc[2] for doc in top_results}

        # Only keep the document IDs
        top_results = [doc[0] for doc in top_results]
        return top_results

    def get_query_embedding(self, query: str) -> np.ndarray:
        embedding = self.client.embed(query)
        return np.array(embedding, dtype=np.float32)

    def get_document_texts(self, doc_ids: List[str]) -> List[Dict[str, str]]:
        # NOTE: relies on self.documents, which load_data() currently leaves as None
        # (the pickle-loading code above is commented out); documents must be loaded
        # before this is called.
        results = []
        for doc_id in doc_ids:
            doc = next((d for d in self.documents if d.id == doc_id), None)
            if doc:
                results.append({
                    'id': doc.id,
                    'abstract': doc.abstract,
                    'conclusions': doc.conclusions
                })
            else:
                print(f"Warning: Document with ID {doc_id} not found.")
        return results

    def retrieve_context(self, query, top_k, sections=["abstract", "conclusions"], **kwargs):
        docs = self.retrieve(query, top_k=top_k, return_scores=True, **kwargs)
        docids = docs.keys()
        doctexts = self.get_document_texts(docids)  # avoid having to do this repetitively?

        context_str = ""
        doclist = []
        for docid, doctext in zip(docids, doctexts):
            for section in sections:
                context_str += f"{docid}: {doctext[section]}\n"
            meta_row = self.metadata[docid]
            doclist.append(Document(docid, doctext['abstract'], doctext['conclusions'], docid,
                                    title=meta_row['title'], score=docs[docid],
                                    n_citation=meta_row['citation_count'],
                                    keywords=meta_row['keyword_search']))

        return context_str, doclist

class Filter():
    def filter(self, query: str, arxiv_id: str) -> List[str]:
        pass

class CitationFilter(Filter):  # can do it with all metadata
    def __init__(self, metadata):
        self.metadata = metadata
        self.citation_counts = {doc_id: self.metadata[doc_id]['citation_count'] for doc_id in self.metadata}

    def citation_weight(self, x, shift, scale):
        return 1 / (1 + np.exp(-1 * (x - shift) / scale))  # sigmoid function
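
    # How the sigmoid behaves with the shift/scale used in filter() below
    # (shift = median citation count, scale = standard deviation): a paper at the
    # median gets a weight of 0.5, since 1 / (1 + exp(0)) = 0.5; one standard
    # deviation above the median gets ~0.73, and the weight saturates toward 1
    # for highly cited papers.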

    def filter(self, doc_scores, weight=0.1):  # additive weighting
        citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
        cmean, cstd = np.median(citation_count), np.std(citation_count)
        citation_score = self.citation_weight(citation_count, cmean, cstd)

        for i, doc in enumerate(doc_scores):
            doc_scores[i][2] += weight * citation_score[i]

class DateFilter(Filter):  # include time weighting eventually
    def __init__(self, document_dates):
        self.document_dates = document_dates

    def parse_date(self, arxiv_id: str) -> date:  # only for documents
        if arxiv_id.startswith('astro-ph'):
            arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
        try:
            year = int("20" + arxiv_id[:2])
            month = int(arxiv_id[2:4])
        except (ValueError, IndexError):
            year = 2023
            month = 1
        return date(year, month, 1)

    def weight(self, time, shift, scale):
        return 1 / (1 + np.exp((time - shift) / scale))
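
    # Shape of the recency weight with the shift=5, scale=5 (years) used in filter()
    # below: a paper published at max_date gets ~0.73, one five years older gets 0.5,
    # and the weight decays toward 0 for much older papers. The term is further scaled
    # by time_score * 0.1, so it acts as a mild nudge rather than a hard cutoff.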

    def evaluate_filter(self, year, filter_string):
        try:
            # Evaluate the year-filter expression with builtins disabled and only
            # `year` in scope (note: this is not a fully safe sandbox)
            result = eval(filter_string, {"__builtins__": None}, {"year": year})
            return result
        except Exception as e:
            print(f"Error evaluating filter: {e}")
            return False

    def filter(self, docs, boolean_date=None, min_date=None, max_date=None, time_score=0):
        filtered = []

        if boolean_date is not None:
            boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
            for doc in docs:
                if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
                    filtered.append(doc)
        else:
            if min_date is None:
                min_date = date(1990, 1, 1)
            if max_date is None:
                max_date = date(2024, 7, 3)

            for doc in docs:
                if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
                    filtered.append(doc)

        if time_score is not None:  # apply time weighting
            for i, item in enumerate(filtered):
                time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
                filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)

        return filtered

class KeywordFilter(Filter):
    def __init__(self, index_path: str = "local_files/keyword_index.json",
                 remove_capitals: bool = True, metadata=None, ne_only=True, verbose=False):
        self.index_path = index_path
        self.metadata = metadata
        self.remove_capitals = remove_capitals
        self.ne_only = ne_only
        self.stopwords = set(stopwords.words('english'))
        self.verbose = verbose
        self.index = None

        self.load_or_build_index()

    def preprocess_text(self, text: str) -> str:
        text = ''.join(char for char in text if char.isalnum() or char.isspace())
        if self.remove_capitals:
            text = text.lower()
        return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)

    def build_index(self):  # include the title in the index
        print("Building index...")
        self.index = {}

        for i, index in tqdm(enumerate(self.metadata)):
            paper = self.metadata[index]
            title = paper['title'][0]
            title_keywords = set()  # set(self.parse_doc(title) + self.get_propn(title))
            for keyword in set(paper['keyword_search']) | title_keywords:
                term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
                if term not in self.index:
                    self.index[term] = []
                self.index[term].append(paper['arxiv_id'])

        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)
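
    # The resulting index is a plain inverted map from stop-word-stripped, lower-cased
    # keyword to the arxiv_ids that carry it, e.g. (hypothetical entries):
    #   {"globular clusters": ["2304.01234", "2211.05678"], "gaia": ["2105.09876"]}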

    def load_index(self):
        print("Loading existing index...")
        with open(self.index_path, 'r') as f:
            self.index = json.load(f)
        print("Index loaded successfully.")

    def load_or_build_index(self):
        if os.path.exists(self.index_path):
            self.load_index()
        else:
            self.build_index()

    def parse_doc(self, doc):
        local_kws = []
        for phrase in doc._.phrases:
            local_kws.append(phrase.text.lower())
        return [self.preprocess_text(word) for word in local_kws]

    def get_propn(self, doc):
        result = []
        working_str = ''
        for token in doc:
            if token.text in nlp.Defaults.stop_words or token.text in punctuation:
                if working_str != '':
                    result.append(working_str.strip())
                    working_str = ''
            if token.pos_ == "PROPN":
                working_str += token.text + ' '
        if working_str != '':
            result.append(working_str.strip())
        return [self.preprocess_text(word) for word in result]

    def filter(self, query: str, doc_ids=None):
        doc = nlp(query)
        query_keywords = self.parse_doc(doc)
        nouns = self.get_propn(doc)
        if self.verbose:
            print('keywords:', query_keywords)
        if self.verbose:
            print('proper nouns:', nouns)

        filtered = set()
        if len(query_keywords) > 0 and not self.ne_only:
            for keyword in query_keywords:
                if keyword != '' and keyword in self.index.keys():
                    filtered |= set(self.index[keyword])

        if len(nouns) > 0:
            ne_results = set()
            for noun in nouns:
                if noun in self.index.keys():
                    ne_results |= set(self.index[noun])

            if self.ne_only:
                filtered = ne_results  # keep only named entity results
            else:
                filtered &= ne_results  # take the intersection

        if doc_ids is not None:
            filtered &= doc_ids  # apply filter to results

        return filtered

def get_cluster_keywords(clust_ids, all_keywords):
    # Collect the keywords of every paper in the cluster, then keep the 30 most common
    tagstr = ''
    clust_tags = []
    for i in range(len(clust_ids)):
        for j in range(len(all_keywords[clust_ids[i]])):
            clust_tags.append(all_keywords[clust_ids[i]][j])

    tags = Counter(clust_tags).most_common(30)
    for i in range(len(tags)):
        if len(tags[i][0]) > 2:
            tagstr = tagstr + tags[i][0] + ', '

    return tagstr

def get_keywords(query, ret_indices, all_keywords):
    kws = get_cluster_keywords(ret_indices, all_keywords)

    kw_prompt = """You are an expert research assistant. Here is a list of keywords corresponding to the topics that a query and its answer are about that you need to synthesize into a succinct summary:
[""" + kws + """]
First, find the keywords that are most relevant to answering the question, and then print them in numbered order. Keywords should be a few words at most. Do not list more than five keywords.
If there are no relevant keywords, write “No relevant keywords” instead.
Thus, the format of your overall response should look like what’s shown below. Make sure to follow the formatting and spacing exactly.
Keywords:
[1] Milky Way galaxy
[2] Good agreement
[3] Bayesian
[4] Observational constraints
[5] Globular clusters
[6] Kinematic data
If the question cannot be answered by the document, say so."""

    client = anthropic.Anthropic(api_key=anthropic_key)
    message = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=200,
        temperature=0,
        system=kw_prompt,
        messages=[{"role": "user", "content": [{"type": "text", "text": query}]}],
    )

    return message.content[0].text
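
# Minimal end-to-end sketch of how the classes above fit together (illustrative
# only: it assumes the local_files/ artifacts exist, and that `Document` and
# `analyze_temporal_query`, which are referenced above, are defined elsewhere in
# the app):
#
#   retrieval = EmbeddingRetrievalSystem(weight_keywords=True, weight_date=True)
#   doc_ids = retrieval.retrieve("What constrains the Milky Way's dark matter halo mass?", top_k=10)
#   context_str, doclist = retrieval.retrieve_context(
#       "What constrains the Milky Way's dark matter halo mass?", top_k=10)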