import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from transformers import AutoModel
import json
from numpy.linalg import norm
import sqlite3
import urllib.request
from django.conf import settings
import Levenshtein
# this module acts as a singleton class
class JinaAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = self.model.encode(input)
        return embeddings.tolist()

# instance of embedding_model
embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                            trust_remote_code=True,
                                            cache_dir='models')

# instance of JinaAIEmbeddingFunction
ef = JinaAIEmbeddingFunction(embedding_model)
# list of topics and their description embeddings
topic_descriptions = json.load(open("topic_descriptions.txt"))
topics = list(topic_descriptions.keys())
embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topic_descriptions]

# similarity helpers: cosine similarity on embeddings, Levenshtein distance on strings
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
def lev_sim(a, b): return Levenshtein.distance(a, b)
def choose_topic(summary):
    """Return the topic whose description embedding is most similar to the summary."""
    embed = embedding_model.encode(summary)
    topic = ""
    max_sim = 0.
    for i, key in enumerate(topics):
        sim = cos_sim(embed, embeddings[i])
        if sim > max_sim:
            topic = key
            max_sim = sim
    return topic
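
# Usage sketch (not part of the original module): choose_topic embeds a free-text summary
# and returns the topic key with the highest cosine similarity. The summary below is a
# made-up example and assumes topic_descriptions.txt has been loaded above.
def _demo_choose_topic():
    sample_summary = ("We introduce a retrieval-augmented generation pipeline that grounds "
                      "large language model answers in a vector database of papers.")
    return choose_topic(sample_summary)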
def authors_list_to_str(authors):
    """input a list of authors, return a string representing the authors"""
    text = ""
    for author in authors:
        text += author + ", "
    # drop the trailing ", " separator
    return text[:-2]
def authors_str_to_list(string):
    """input a string of authors separated by " and ", return a list of authors"""
    authors = []
    list_auth = string.split(" and ")
    for author in list_auth:
        if author.strip() != "et al.":
            authors.append(author.strip())
    return authors
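
# Usage sketch (not part of the original module): round-trip between the two author helpers.
# authors_list_to_str joins with ", ", while authors_str_to_list expects " and " separators.
def _demo_author_helpers():
    as_str = authors_list_to_str(["Jane Doe", "John Smith"])    # "Jane Doe, John Smith"
    as_list = authors_str_to_list("Jane Doe and John Smith")    # ["Jane Doe", "John Smith"]
    return as_str, as_list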
def chunk_texts(text, max_char=400):
    """
    Chunk a long text into several chunks, with each chunk about 300-400 characters long,
    but make sure no word is cut in half.
    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
    Returns:
        A list of chunks.
    """
    chunks = []
    current_chunk = ""
    words = text.split()
    for word in words:
        if len(current_chunk) + len(word) + 1 >= max_char:
            # current chunk is full: flush it and start the next chunk with this word
            chunks.append(current_chunk.strip())
            current_chunk = word
        else:
            current_chunk += " " + word
    chunks.append(current_chunk.strip())
    return chunks
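
# Usage sketch (not part of the original module): chunk a long string into pieces of fewer
# than 400 characters each; no word is split across pieces.
def _demo_chunk_texts():
    long_text = "retrieval augmented generation " * 40   # roughly 1240 characters
    pieces = chunk_texts(long_text, max_char=400)
    return [len(p) for p in pieces]                      # every length stays below 400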
def trimming(txt):
    start = txt.find("{")
    end = txt.rfind("}")
    return txt[start:end+1].replace("\n", " ")
# crawl data
def extract_tag(txt, tagname):
    """Return the text between <tagname> and </tagname> in txt."""
    return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]

def get_record(extract):
    """Parse one Atom <entry> block into [id, updated, published, title, authors, link, summary]."""
    id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    title = extract_tag(extract, "title").replace("\n ", "").strip()
    summary = extract_tag(extract, "summary").replace("\n", "").strip()
    authors = []
    while extract.find("<author>") != -1:
        author = extract_tag(extract, "name")
        extract = extract[extract.find("</author>")+9:]
        authors.append(author)
    pattern = '<link title="pdf" href="'
    link_start = extract.find(pattern)
    link = extract[link_start+len(pattern):extract.find("rel=", link_start)-2]
    return [id, updated, published, title, authors, link, summary]
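
# Usage sketch (not part of the original module): get_record parses one Atom <entry> block
# as returned by the arXiv API. The entry below is a minimal hand-written example.
def _demo_get_record():
    entry = (
        "<id>http://arxiv.org/abs/0000.00000v1</id>"
        "<updated>2024-01-01T00:00:00Z</updated>"
        "<published>2024-01-01T00:00:00Z</published>"
        "<title>A Sample Title</title>"
        "<summary>A short sample abstract.</summary>"
        "<author><name>Jane Doe</name></author>"
        '<link title="pdf" href="http://arxiv.org/pdf/0000.00000v1" rel="related"/>'
    )
    return get_record(entry)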
def crawl_exact_paper(title, author, max_results=3):
    """Query the arXiv API for a specific paper by title and author list."""
    authors = authors_list_to_str(author)
    records = []
    url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(title=title, author=authors, max_results=max_results)
    url = url.replace(" ", "%20")
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)
def crawl_arxiv(keyword_list, max_results=100):
    """Query the arXiv API for papers matching any keyword in keyword_list."""
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    records = []
    for i, keyword in enumerate(keyword_list):
        if i == 0:
            url = baseurl + 'all:' + keyword
        else:
            url = url + '+OR+' + 'all:' + keyword
    url = url + '&max_results=' + str(max_results)
    url = url.replace(' ', '%20')
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)
# This class acts as a module
class ArxivChroma:
    """
    Create an interface to arxivdb, which only supports querying and addition.
    This interface does not support editing or deletion.
    """
    client = None
    model = None
    collection = None

    def connect(table="arxiv_records", name="arxivdb/"):
        ArxivChroma.client = chromadb.PersistentClient(name)
        ArxivChroma.model = embedding_model
        ArxivChroma.collection = ArxivChroma.client.get_or_create_collection(table,
                                    embedding_function=JinaAIEmbeddingFunction(
                                        model=ArxivChroma.model
                                    ))
    def query_relevant(keywords, query_texts, n_results=3):
        """
        Query the collection: keep only documents that contain at least one of the
        keywords, then rank them by embedding similarity to query_texts.
        """
        contains = []
        for keyword in keywords:
            contains.append({"$contains": keyword.lower()})
        return ArxivChroma.collection.query(
            query_texts=query_texts,
            where_document={
                "$or": contains
            },
            n_results=n_results,
        )

    def query_exact(id):
        """Return the stored chunks of a paper; chunk ids follow the pattern <paper_id>_<j>."""
        ids = ["{}_{}".format(id, j) for j in range(0, 10)]
        return ArxivChroma.collection.get(ids=ids)
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return the final length of the database table
        """
        for record in crawl_records:
            embed_text = """
            Topic: {},
            Title: {},
            Summary: {}
            """.format(record[0], record[4], record[7])
            chunks = chunk_texts(embed_text)
            # record[1] is the arXiv id URL; [21:] strips "http://arxiv.org/abs/" to get the bare paper id
            ids = [record[1][21:] + "_" + str(j) for j in range(len(chunks))]
            paper_ids = [{"paper_id": record[1][21:]} for _ in range(len(chunks))]
            ArxivChroma.collection.add(
                documents=chunks,
                metadatas=paper_ids,
                ids=ids
            )
        return ArxivChroma.collection.count()

    def close_connection():
        pass
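
# Usage sketch (not part of the original module): ArxivChroma is used as a module-level
# singleton; connect() once, then add crawled records and query them by keyword. The
# keywords and query text below are made-up examples.
def _demo_arxiv_chroma():
    ArxivChroma.connect()                                   # opens/creates arxivdb/
    records = crawl_arxiv(["vector database", "retrieval"], max_results=3)
    if not isinstance(records, str):
        ArxivChroma.add(records)
    return ArxivChroma.query_relevant(keywords=["vector", "retrieval"],
                                      query_texts="vector databases for retrieval")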
# This class acts as a module
class ArxivSQL:
    table = "arxivsql"
    con = None
    cur = None

    def connect(name="db.sqlite3"):
        ArxivSQL.con = sqlite3.connect(name, check_same_thread=False)
        ArxivSQL.cur = ArxivSQL.con.cursor()
    def query(title="", author=[], threshold=15):
        """Filter rows by author substrings, then rank candidates by Levenshtein distance to title."""
        if len(author) > 0:
            query_author = " OR ".join([f"author LIKE '%{a}%'" for a in author])
        else:
            query_author = "True"
        # Execute the query
        query = f"select * from {ArxivSQL.table} where {query_author}"
        results = ArxivSQL.cur.execute(query).fetchall()
        if len(title) == 0:
            return results
        else:
            sim_score = {}
            for row in results:
                row_title = row[2]
                row_id = row[0]
                score = lev_sim(title, row_title)
                if score < threshold:
                    sim_score[row_id] = score
            sorted_results = sorted(sim_score.items(), key=lambda x: x[1])
            # query_id expects bare ids, so drop the scores before the lookup
            return ArxivSQL.query_id([row_id for row_id, _ in sorted_results])
    def query_id(ids=[]):
        try:
            if len(ids) == 0:
                return None
            query = "select * from {} where id in (".format(ArxivSQL.table)
            for id in ids:
                query += "'" + id + "',"
            query = query[:-1] + ")"
            result = ArxivSQL.cur.execute(query)
            return result.fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", query)
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return a string of accumulated error messages (empty if every insert succeeded)
        """
        results = ""
        for record in crawl_records:
            try:
                query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    record[1][21:],
                    record[0],
                    record[4].replace('"', "'"),
                    authors_list_to_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6]
                )
                ArxivSQL.cur.execute(query)
                ArxivSQL.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        return results
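
# Usage sketch (not part of the original module): the SQLite side mirrors ArxivChroma and
# stores one row per paper. This assumes db.sqlite3 already contains an "arxivsql" table
# whose columns match the insert statement above; the title/author below name a real paper
# but are only illustrative.
def _demo_arxiv_sql():
    ArxivSQL.connect()                                      # opens db.sqlite3
    records = crawl_exact_paper("Attention Is All You Need", ["Vaswani"], max_results=3)
    if not isinstance(records, str):
        ArxivSQL.add(records)
    return ArxivSQL.query(title="Attention Is All You Need", author=["Vaswani"])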