File size: 4,388 Bytes
60d5c99
b2efd5e
 
ec4765e
b2efd5e
 
 
 
60d5c99
20768f0
60d5c99
f093d4b
b2efd5e
 
 
5cf9a77
 
e42c9fc
d32a867
5cf9a77
b2efd5e
e42c9fc
 
b2efd5e
e42c9fc
 
 
48b4cd8
e42c9fc
 
b2efd5e
e42c9fc
 
 
b2efd5e
e42c9fc
b2efd5e
e42c9fc
 
 
b2efd5e
e42c9fc
 
b2efd5e
e42c9fc
 
 
b2efd5e
e42c9fc
 
b2efd5e
e42c9fc
 
 
b2efd5e
20768f0
 
 
 
f093d4b
 
 
 
 
 
 
 
 
 
 
 
20768f0
f093d4b
20768f0
f093d4b
20768f0
f093d4b
 
 
 
 
 
 
b2efd5e
5cf9a77
b2efd5e
d32a867
 
b2efd5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9636641
b2efd5e
5cf9a77
b2efd5e
e42c9fc
 
 
60d5c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Required imports
import json
import time
import os
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec
from groq import Groq
from tqdm.auto import tqdm
import streamlit as st
import re

# Variables
FILE_PATH = "anjibot_chunks.json"
BATCH_SIZE = 384
INDEX_NAME = "groq-llama-3-rag"
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
DIMS = 768  # Pinecone index dimension; must match the encoder's embedding size

# Streamlit re-executes this whole script on every interaction (each chat
# message triggers a rerun). The expensive setup below -- loading the embedding
# model, reading the chunk file, and embedding + upserting the whole corpus --
# is therefore wrapped in ``st.cache_resource`` so it runs once per server
# process instead of on every rerun. NOTE: edits to FILE_PATH's contents are
# only picked up after a server restart (or a cache clear).


@st.cache_resource
def _load_encoder() -> SentenceTransformer:
    """Load the sentence-embedding model used for both chunks and queries."""
    return SentenceTransformer('dwzhu/e5-base-4k')


@st.cache_resource
def _load_chunks(path: str) -> dict:
    """Read the chunk corpus from the JSON file at ``path``.

    The indexing code slices every value of the returned dict in lock-step, so
    the values are presumably parallel lists ('id', 'metadata', ...) -- confirm
    against the file that produces ``anjibot_chunks.json``.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


@st.cache_resource
def _connect_index(index_name: str):
    """Return ``(client, index)``, creating the serverless index if missing."""
    client = Pinecone(api_key=PINECONE_API_KEY)
    existing = [info["name"] for info in client.list_indexes()]
    if index_name not in existing:
        client.create_index(index_name, dimension=DIMS, metric='cosine', spec=spec)

        # Wait for the index to be initialized
        while not client.describe_index(index_name).status['ready']:
            time.sleep(1)
    return client, client.Index(index_name)


@st.cache_resource(show_spinner="Indexing course material...")
def _index_chunks(index_name: str, file_path: str, _index, _encoder, _data: dict) -> int:
    """Embed every chunk in ``_data`` and upsert it into ``_index``.

    Only ``index_name`` and ``file_path`` form the cache key; Streamlit does
    not hash parameters whose names start with an underscore.

    Returns:
        The number of chunks upserted.

    Raises:
        RuntimeError: if the encoder returns a different number of embeddings
            than there are chunks in a batch.
    """
    total = len(_data['id'])
    for start in tqdm(range(0, total, BATCH_SIZE)):
        end = min(total, start + BATCH_SIZE)

        # Slice every field to the same batch window.
        batch = {key: values[start:end] for key, values in _data.items()}

        # Each chunk is embedded as "title: content".
        chunks = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
        embeds = _encoder.encode(chunks)

        # Explicit check rather than ``assert`` (which ``python -O`` strips).
        if len(embeds) != end - start:
            raise RuntimeError(
                f"Encoder returned {len(embeds)} embeddings for {end - start} chunks"
            )

        # Upsert (id, vector, metadata) tuples to Pinecone.
        vectors = list(zip(batch["id"], embeds, batch["metadata"]))
        _index.upsert(vectors=vectors)
    return total


encoder = _load_encoder()
groq_client = Groq(api_key=GROQ_API_KEY)

data = _load_chunks(FILE_PATH)

spec = ServerlessSpec(cloud="aws", region='us-east-1')
pc, index = _connect_index(INDEX_NAME)
_index_chunks(INDEX_NAME, FILE_PATH, index, encoder, data)

# Course codes: one of the known subject prefixes (ged/stat/math may carry a
# trailing "s"), optional whitespace, then exactly three digits as a whole word,
# e.g. "COSC 401", "geds101", "Stat 203". Compiled once instead of per call.
_COURSE_CODE_RE = re.compile(
    r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b', re.IGNORECASE
)


def extract_course_code(text: str) -> "list[str] | None":
    """Find every course code mentioned in ``text``.

    Args:
        text: Free-form text, typically the user's question.

    Returns:
        The matched course codes exactly as written in ``text`` (original case
        and spacing preserved), in order of appearance, or ``None`` when
        ``text`` contains no course code.
    """
    matches = _COURSE_CODE_RE.findall(text)
    return matches or None

def get_docs(query: str, top_k: int) -> list[str]:
    """Return up to ``top_k`` context passages relevant to ``query``.

    Passages whose content mentions a course code found in the query come
    first, in corpus order. Any remaining slots are filled with nearest
    neighbours from the Pinecone index, skipping passages already selected so
    the same text is never returned twice.

    Args:
        query: The user's question.
        top_k: Maximum number of passages to return; ``<= 0`` yields ``[]``.

    Returns:
        A list of passage ``content`` strings, at most ``top_k`` long.
    """
    if top_k <= 0:
        return []

    docs: list[str] = []
    course_codes = extract_course_code(query)
    if course_codes:
        # Compare with all whitespace removed so "cosc401" in the query also
        # finds "COSC 401" in a passage (and vice versa). Every passage that
        # contains the code verbatim still matches.
        wanted = {re.sub(r'\s+', '', code.lower()) for code in course_codes}
        for chunk in data['metadata']:
            squashed = re.sub(r'\s+', '', chunk['content'].lower())
            if any(code in squashed for code in wanted):
                docs.append(chunk['content'])

    if len(docs) < top_k:
        xq = encoder.encode(query)
        # Request top_k candidates (not just the free slots): some hits may
        # duplicate exact matches and are dropped below, and at most len(docs)
        # of them can be duplicates.
        res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)

        seen = set(docs)
        for hit in res["matches"]:
            content = hit["metadata"]['content']
            if content in seen:
                continue
            seen.add(content)
            docs.append(content)
            if len(docs) >= top_k:
                break

    return docs[:top_k]

def get_response(query: str, docs: list[str], *, model: str = "llama3-70b-8192") -> str:
    """Ask the Groq chat model to answer ``query`` using ``docs`` as context.

    Args:
        query: The user's question, sent as the user message.
        docs: Context passages, joined with ``---`` separators and appended to
            the system prompt after ``CONTEXT:``.
        model: Groq model name (keyword-only; defaults to the original model).

    Returns:
        The assistant's reply text, or ``""`` if the API returned no content.
    """
    instructions = (
        "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastic but still sweet.\n"
        "Provide the answer to class-related queries using\n"
        "context provided below.\n"
        "If you don't know the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
        "Anji's phone number: 08145170886.\n\n"
        "CONTEXT:\n"
    )
    # The explicit ``+`` matters. Adjacent string literals are merged at
    # compile time, so writing ``"CONTEXT:\n" "\n---\n".join(docs)`` would call
    # .join() on the entire instruction text and use it as the separator
    # between documents. That yields no instructions at all for 0 or 1 docs.
    system_message = instructions + "\n---\n".join(docs)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query}
    ]

    chat_response = groq_client.chat.completions.create(
        model=model,
        messages=messages
    )
    # message.content may be None; keep the declared ``str`` return type.
    return chat_response.choices[0].message.content or ""

def handle_query(user_query: str, *, delay: float = 0.05):
    """Answer ``user_query`` and yield the reply in small pieces for streaming.

    Retrieves up to 5 context passages, asks the model for an answer, then
    yields it word by word (each word with its surrounding whitespace),
    pausing ``delay`` seconds after each piece for a typing effect.

    Args:
        user_query: The user's question.
        delay: Seconds to sleep after each yielded piece (keyword-only);
            values ``<= 0`` disable the pause.

    Yields:
        Consecutive pieces of the reply that concatenate to exactly the
        model's answer.
    """
    docs = get_docs(user_query, top_k=5)

    response = get_response(user_query, docs=docs) or ""

    # Keep the whitespace attached to each word instead of using str.split().
    # split() would collapse every newline into a space, flattening the
    # Markdown (lists, paragraphs, code blocks) that the chat UI renders.
    for piece in re.findall(r'\s*\S+\s*', response):
        yield piece
        if delay > 0:
            time.sleep(delay)

def main():
    """Render the Anjibot chat page: replay history, then answer a new prompt.

    History is kept in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts. The assistant reply is streamed
    from ``handle_query`` and stored once streaming finishes.
    """
    st.title("Ask Anjibot 2.0")

    if "messages" not in st.session_state:
        st.session_state.messages = []
    history = st.session_state.messages

    # Re-draw every earlier turn; the page is rebuilt from scratch each rerun.
    for turn in history:
        with st.chat_message(turn["role"]):
            st.markdown(turn["content"])

    prompt = st.chat_input("Ask me anything")
    if not prompt:
        return

    history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        reply = st.write_stream(handle_query(prompt))
    history.append({"role": "assistant", "content": reply})


# Entry point: run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()