suwesh committed on
Commit
c543e79
·
verified ·
1 Parent(s): b1fee52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -4
app.py CHANGED
@@ -1,11 +1,57 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def respond(
11
  message,
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
 
4
+ import json
5
+ import pandas as pd
6
+ import numpy as np
7
+
8
+ import torch
9
+ from sentence_transformers import SentenceTransformer
10
+ import nltk
11
+ from nltk.tokenize import sent_tokenize
12
+ import faiss
13
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
14
+
15
# Sentence-embedding model shared by the whole module.
optimus = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
textsplitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

# NOTE(review): both paths are empty placeholders; they must be filled in
# before this module can be imported successfully.
dbpath = r''
gridpath = r''

# The paper database is read as JSON-lines: one JSON object per line.
# Blank lines are skipped so a trailing newline does not break json.loads.
with open(dbpath, encoding='utf-8') as f:
    papers = [json.loads(line) for line in f if line.strip()]
df = pd.DataFrame(papers)
reqdf = df[['id', 'title', 'categories', 'abstract']]

# Dimensionality handed to the FAISS index; it has to equal the length of
# the vectors produced by `optimus` (384 for all-MiniLM-L6-v2).
d = 384
index = faiss.IndexFlatL2(d)
# Spelled `thegrid` so that it matches the name the search code reads.
thegrid = []  # load the grid and index from json file here
29
+
30
def gen_embeddings(text):
    """Embed *text* sentence by sentence.

    The text is split with NLTK's ``sent_tokenize`` and the resulting list
    of sentences is passed to the module-level ``optimus`` encoder.

    Returns:
        The value returned by ``optimus.encode`` for the sentence list.
    """
    return optimus.encode(sent_tokenize(text))
34
+
35
# Semantic search: embed the query, fetch its nearest neighbours from the
# FAISS index, then map the hits back to paper titles.
# NOTE(review): `query` is not defined anywhere in the visible part of this
# file; it has to be bound (presumably to the user's message) before this
# code can run.
query_list = np.asarray(gen_embeddings(query), dtype='float32')
if len(query_list) > 1:
    # Collapse the per-sentence embeddings into a single query vector.
    # NumPy is used because the embeddings are converted to an ndarray above.
    query_list = query_list.mean(axis=0)
# FAISS expects a float32 matrix of shape (n_queries, d): one row per query.
query_matrix = query_list.reshape(1, -1)
k = 10
distances, indices = index.search(query_matrix, k)
# FAISS pads the result with -1 when the index holds fewer than k vectors;
# those must be dropped, otherwise thegrid[-1] would silently be returned.
result_texts = [thegrid[idx]['text'] for idx in indices[0] if idx >= 0]
# NOTE(review): printres is overwritten on every iteration, so only the
# last formatted match survives the loop.
for i, text in enumerate(result_texts):
    printres = f"Match {i+1}: {text}"

# Resolve each hit to its paper title, keeping one entry per paper id.
searched_topics = []
idcache = []
for text in result_texts:
    # Each grid text is read as "<paper id>|||<rest>".
    rowid = text.split("|||")[0]
    if rowid in idcache:
        # Skip repeated papers but keep scanning the remaining hits.
        continue
    titles = reqdf.loc[reqdf['id'] == rowid, 'title'].values
    if len(titles) == 0:
        # The id is missing from the paper table; nothing to report.
        continue
    topic = titles[0]
    searched_topics.append(topic)
    idcache.append(rowid)
55
 
56
  def respond(
57
  message,