BillBojangeles2000 committed on
Commit
4ac91c8
·
1 Parent(s): 6792cec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Streamlit app setup: collect the Pinecone API key, connect to the
# "abstractive-question-answering" index and load the retriever / generator
# models that the helper functions further down read as module-level names.
import pinecone
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import BartTokenizer, BartForConditionalGeneration

# Masked single-line input so the secret is not shown on screen.
API = st.text_input('Enter API key:', type='password')
# Explicit key: this script has a second 'Submit' button, and Streamlit
# rejects two otherwise identical widgets that have no distinguishing key.
res = st.button('Submit', key='submit_api_key')

# st.button is True only for the one rerun triggered by the click, so the
# confirmation is remembered in session state. Without this, pressing any
# other widget would rerun the script with `res` False and skip the setup.
if res:
    st.session_state['api_key_submitted'] = True
if not (API and st.session_state.get('api_key_submitted')):
    # Nothing below can work without a confirmed key: end this script run.
    st.stop()

# connect to pinecone environment
pinecone.init(
    api_key=API,  # the value entered by the user, not the literal string "API"
    environment="us-central1-gcp"  # find next to API key in console
)

index_name = "abstractive-question-answering"

# check if the abstractive-question-answering index exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    pinecone.create_index(
        index_name,
        dimension=768,
        metric="cosine"
    )

# connect to abstractive-question-answering index we created
index = pinecone.Index(index_name)

# NOTE: Streamlit reruns this script on every interaction, so both models
# below are reloaded each time; consider Streamlit's resource caching.

# set device to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# load the retriever model from huggingface model hub
retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)

# load bart tokenizer and model from huggingface
tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
# Kept on CPU: generate_answer passes the tokenizer's tensors to the model
# without moving them to another device.
generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
def query_pinecone(query, top_k):
    """Fetch the index entries closest to ``query``.

    Uses the module-level ``retriever`` to embed the question and the
    module-level ``index`` to run the similarity search.

    Args:
        query: Question text to embed.
        top_k: Number of matches to request from the index.

    Returns:
        The response object of ``index.query`` with metadata included; the
        caller in this file reads its ``"matches"`` entry.
    """
    # generate embeddings for the query (encoded as a one-element batch)
    xq = retriever.encode([query]).tolist()
    # search pinecone index for context passage with the answer
    xc = index.query(xq, top_k=top_k, include_metadata=True)
    return xc
def format_query(query, context):
    """Build the generator prompt from a question and retrieved matches.

    Args:
        query: The user's question.
        context: Iterable of matches; each one must expose the passage
            text at ``match['metadata']['text']``.

    Returns:
        ``"question: <query> context: <P> passage1 <P> passage2 ..."``.
        With no matches the context part is empty.

    Raises:
        KeyError: If a match has no ``metadata`` / ``text`` entry.
    """
    # prefix every passage with the <P> tag and join them with single spaces
    passages = " ".join(f"<P> {m['metadata']['text']}" for m in context)
    # concatenate the query and the context passages
    return f"question: {query} context: {passages}"
def generate_answer(query):
    """Generate an answer for a formatted prompt and show it in the app.

    Uses the module-level ``tokenizer`` and ``generator``.

    Args:
        query: Prompt text, e.g. the output of ``format_query``.

    Returns:
        The decoded answer string (it is also written to the page).
    """
    # tokenize the query to get input_ids; `truncation` (previously
    # misspelled) is what makes max_length=1024 actually take effect
    inputs = tokenizer([query], truncation=True, max_length=1024, return_tensors="pt")
    # use generator to predict output ids
    ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
    # use tokenizer to decode the output ids; one prompt in, so take item 0
    answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # `pprint` was never imported; render the answer in the Streamlit page
    st.write(answer)
    return answer
# Question form: retrieve supporting passages, build the prompt and let
# generate_answer produce and display the answer.
query = st.text_area('Enter your question:')
# Explicit key: the API-key form earlier in this script also has a 'Submit'
# button, and identical widgets without a key collide.
s = st.button('Submit', key='submit_question')
if s:
    if query.strip():
        context = query_pinecone(query, top_k=5)
        query = format_query(query, context["matches"])
        generate_answer(query)
    else:
        # avoid embedding and generating from an empty question
        st.warning('Please enter a question first.')