import os

import PyPDF2
import streamlit as st
import tiktoken
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

HUMAN_TEMPLATE = """
You are a helpful assistant who answers questions based on the provided context.
You must only use the provided context and cannot use your own knowledge.
If you do not know the answer, or it is not contained in the provided context, respond with "I don't know".

#Question:
{query}

#CONTEXT:
{context}
"""

def check_api_key():
    """Load .env, verify the OpenAI API key is set, and return it."""
    load_dotenv()
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not found in environment variables")
    return api_key

# Read PDF data
def read_pdf_data(pdf_file):
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() can return None for pages without a text layer
            text += page.extract_text() or ""
        if not text.strip():
            raise ValueError("No text extracted from PDF")
        return text
    except Exception as e:
        raise Exception(f"Error reading PDF: {str(e)}")

def tiktoken_len(text):
    """Token count of `text` under the gpt-4 encoding; used as the splitter's length function."""
    try:
        tokens = tiktoken.encoding_for_model("gpt-4").encode(text)
        return len(tokens)
    except Exception as e:
        raise Exception(f"Error in token calculation: {str(e)}")

# Split data into chunks
def split_data(text):
    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,    # chunk size in tokens, measured by tiktoken_len
            chunk_overlap=200,  # overlap between chunks for continuity
            length_function=tiktoken_len,
        )
        chunks = text_splitter.split_text(text)
        if not chunks:
            raise ValueError("Text splitting produced no chunks")
        return chunks
    except Exception as e:
        raise Exception(f"Error splitting text: {str(e)}")

# Create embeddings instance
def create_embeddings():
    try:
        api_key = check_api_key()
        embedding_model = OpenAIEmbeddings(
            model="text-embedding-3-small",
            openai_api_key=api_key
        )
        return embedding_model
    except Exception as e:
        raise Exception(f"Error creating embeddings model: {str(e)}")


# Create a vector database using Qdrant
def create_vector_store(embedding_model, chunks):
    try:
        embedding_dim = 1536  # output dimension of text-embedding-3-small
        client = QdrantClient(":memory:")  # Consider using persistent storage for production
        
        # Create collection with error handling
        try:
            client.create_collection(
                collection_name="lcel_doc_v2",
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
        except Exception as e:
            raise Exception(f"Error creating Qdrant collection: {str(e)}")

        vector_store = QdrantVectorStore(
            client=client,
            collection_name="lcel_doc_v2",
            embedding=embedding_model,
        )
        
        # Embed the chunks and add them to the collection
        try:
            _ = vector_store.add_texts(texts=chunks)
        except Exception as e:
            raise Exception(f"Error adding texts to vector store: {str(e)}")
            
        return vector_store
    except Exception as e:
        raise Exception(f"Error in vector store creation: {str(e)}")

# Create the RAG chain
def create_rag():
    try:
        api_key = check_api_key()
        openai_chat_model = ChatOpenAI(
            model="gpt-4o-mini",
            openai_api_key=api_key
        )
        
        chat_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful assistant that answers questions based on the provided context."),
            ("human", HUMAN_TEMPLATE)
        ])
        if 'vector_store' in st.session_state:
            vector_store = st.session_state.vector_store        
        else:
            raise ValueError("Vector store not found in session state")
        
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        
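        # LCEL wiring: the dict runs both branches on the same input, so the
        # retriever receives the raw query string (its Documents are rendered
        # into {context}) while RunnablePassthrough() forwards the query
        # itself into {query}.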
        simple_rag = (
            {"context": retriever, "query": RunnablePassthrough()}
            | chat_prompt
            | openai_chat_model
            | StrOutputParser() 
        ) 
        
        return simple_rag
    except Exception as e:
        raise Exception(f"Error creating RAG chain: {str(e)}")

# Invoke RAG
def invoke_rag(query):
    try:
        rag_chain = create_rag()
        response = rag_chain.invoke(query)
        return response
    except Exception as e:
        raise Exception(f"Error invoking RAG chain: {str(e)}")