dataprincess committed on
Commit
e42c9fc
·
verified ·
1 Parent(s): d32a867

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -8,61 +8,59 @@ from groq import Groq
8
  from tqdm.auto import tqdm
9
  import streamlit as st
10
 
 
 
 
 
 
 
 
 
 
11
  # Constants (hardcoded)
12
  FILE_PATH = "anjibot_chunks.json"
13
  BATCH_SIZE = 384
14
  INDEX_NAME = "groq-llama-3-rag"
15
- PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
16
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
17
- DIMENSIONS = 768
18
-
19
- # Load data once at the start
20
- data = load_data(FILE_PATH)
21
-
22
- # Initialize Pinecone and SentenceTransformer once
23
- index = initialize_pinecone(PINECONE_API_KEY, INDEX_NAME, DIMENSIONS)
24
  encoder = SentenceTransformer('dwzhu/e5-base-4k')
25
 
26
- def load_data(file_path: str) -> dict:
27
- with open(file_path, 'r') as file:
28
- return json.load(file)
29
 
30
- def initialize_pinecone(api_key: str, index_name: str, dims: int) -> any:
31
- pc = Pinecone(api_key=api_key)
32
- spec = ServerlessSpec(cloud="aws", region='us-east-1')
 
 
 
33
 
34
- existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
 
 
35
 
36
- # Check if index already exists; if not, create it
37
- if index_name not in existing_indexes:
38
- pc.create_index(index_name, dimension=dims, metric='cosine', spec=spec)
39
 
40
- # Wait for the index to be initialized
41
- while not pc.describe_index(index_name).status['ready']:
42
- time.sleep(1)
43
 
44
- return pc.Index(index_name)
 
45
 
46
- def upsert_data_to_pinecone(index: any, data: dict):
47
- for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
48
- # Find end of batch
49
- i_end = min(len(data['id']), i + BATCH_SIZE)
50
-
51
- # Create batch
52
- batch = {k: v[i:i_end] for k, v in data.items()}
53
-
54
- # Create embeddings
55
- chunks = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
56
- embeds = encoder.encode(chunks)
57
 
58
- # Ensure correct length
59
- assert len(embeds) == (i_end - i)
60
 
61
- # Upsert to Pinecone
62
- to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
63
- index.upsert(vectors=to_upsert)
64
 
65
- def get_docs(query: str, index: any, encoder: any, top_k: int) -> list[str]:
66
  xq = encoder.encode(query)
67
  res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
68
  return [x["metadata"]['content'] for x in res["matches"]]
@@ -88,20 +86,22 @@ def get_response(query: str, docs: list[str], groq_client: any) -> str:
88
  )
89
  return chat_response.choices[0].message.content
90
 
 
 
91
  def handle_query(user_query: str):
92
- # Upsert data into Pinecone (if necessary)
93
- upsert_data_to_pinecone(index, data)
94
 
95
  # Initialize Groq client
96
  groq_client = Groq(api_key=GROQ_API_KEY)
97
 
98
  # Get relevant documents
99
- docs = get_docs(user_query, index, encoder, top_k=5)
100
 
101
  # Generate and return response
102
  response = get_response(user_query, docs, groq_client)
103
 
104
- return response
 
 
105
 
106
  def main():
107
  st.title("Ask Anjibot 2.0")
 
8
  from tqdm.auto import tqdm
9
  import streamlit as st
10
 
11
+ # Required imports
12
+ import json
13
+ import time
14
+ import os
15
+ from sentence_transformers import SentenceTransformer
16
+ from pinecone import Pinecone, ServerlessSpec
17
+ from groq import Groq
18
+ from tqdm.auto import tqdm
19
+
20
  # Constants (hardcoded)
21
  FILE_PATH = "anjibot_chunks.json"
22
  BATCH_SIZE = 384
23
  INDEX_NAME = "groq-llama-3-rag"
24
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY") # Fixed syntax here
25
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Fixed syntax here
26
+ DIMS = 768
 
 
 
 
 
 
27
  encoder = SentenceTransformer('dwzhu/e5-base-4k')
28
 
29
+ with open(FILE_PATH, 'r') as file:
30
+ data= json.load(file)
 
31
 
32
+ pc = Pinecone(api_key=PINECONE_API_KEY)
33
+ spec = ServerlessSpec(cloud="aws", region='us-east-1')
34
+ existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
35
+ # Check if index already exists; if not, create it
36
+ if INDEX_NAME not in existing_indexes:
37
+ pc.create_index(INDEX_NAME, dimension=DIMS, metric='cosine', spec=spec)
38
 
39
+ # Wait for the index to be initialized
40
+ while not pc.describe_index(INDEX_NAME).status['ready']:
41
+ time.sleep(1)
42
 
43
+ index = pc.Index(INDEX_NAME)
 
 
44
 
45
+ for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
46
+ # Find end of batch
47
+ i_end = min(len(data['id']), i + BATCH_SIZE)
48
 
49
+ # Create batch
50
+ batch = {k: v[i:i_end] for k, v in data.items()}
51
 
52
+ # Create embeddings
53
+ chunks = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
54
+ embeds = encoder.encode(chunks)
 
 
 
 
 
 
 
 
55
 
56
+ # Ensure correct length
57
+ assert len(embeds) == (i_end - i)
58
 
59
+ # Upsert to Pinecone
60
+ to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
61
+ index.upsert(vectors=to_upsert)
62
 
63
+ def get_docs(query: str, top_k: int) -> list[str]:
64
  xq = encoder.encode(query)
65
  res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
66
  return [x["metadata"]['content'] for x in res["matches"]]
 
86
  )
87
  return chat_response.choices[0].message.content
88
 
89
+
90
+
91
  def handle_query(user_query: str):
 
 
92
 
93
  # Initialize Groq client
94
  groq_client = Groq(api_key=GROQ_API_KEY)
95
 
96
  # Get relevant documents
97
+ docs = get_docs(user_query, top_k=5)
98
 
99
  # Generate and return response
100
  response = get_response(user_query, docs, groq_client)
101
 
102
+ for word in response.split():
103
+ yield word + " "
104
+ time.sleep(0.05)
105
 
106
  def main():
107
  st.title("Ask Anjibot 2.0")