# See README for more info on how the DataCollectionPipeline works
# The retrieval pipeline is part of the DataCollectionPipeline
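"""Inference (retrieval) pipeline for the RAG system.

Steps: query expansion -> self-querying -> filtered vector search ->
result collection -> reranking -> prompt building -> answer generation.
"""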
import os
import sys
from operator import itemgetter

from clearml import PipelineDecorator
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from qdrant_client import QdrantClient

# Setup ClearML
try:
    load_dotenv(override=True)
except Exception:
    load_dotenv(sys.path[1] + "/.env", override=True)
CLEARML_WEB_HOST = os.getenv("CLEARML_WEB_HOST")
CLEARML_API_HOST = os.getenv("CLEARML_API_HOST")
CLEARML_FILES_HOST = os.getenv("CLEARML_FILES_HOST")
CLEARML_API_ACCESS_KEY = os.getenv("CLEARML_API_ACCESS_KEY")
CLEARML_API_SECRET_KEY = os.getenv("CLEARML_API_SECRET_KEY")


# Query expansion (I only generate one additional prompt for simplicity)
@PipelineDecorator.component(cache=False, return_values=["newQuery"])
def queryExpansion(query):
    # Setup the model
    MODEL = "llama3.2"
    try:
        load_dotenv(override=True)
    except Exception:
        load_dotenv(sys.path[1] + "/.env", override=True)
    USE_DOCKER = os.getenv("USE_DOCKER")
    if USE_DOCKER == "True":
        model = Ollama(model=MODEL, base_url="http://host.docker.internal:11434")
    else:
        model = Ollama(model=MODEL)

    template = """
    Rewrite the prompt. The new prompt must offer a different perspective.
    Do not change the meaning. Output only the rewritten prompt with no introduction.
        Prompt: {prompt}
    """
    prompt = PromptTemplate.from_template(template)
    chain = {"prompt": itemgetter("prompt")} | prompt | model
    return chain.invoke({"prompt": query})


# Self-querying (the metadata I generate determines whether to search the Qdrant collection containing GitHub code)
@PipelineDecorator.component(cache=False, return_values=["codingQuestion"])
def selfQuerying(query):
    # Setup the model
    MODEL = "llama3.2"
    try:
        load_dotenv(override=True)
    except Exception:
        load_dotenv(sys.path[1] + "/.env", override=True)
    USE_DOCKER = os.getenv("USE_DOCKER")
    if USE_DOCKER == "True":
        model = Ollama(model=MODEL, base_url="http://host.docker.internal:11434")
    else:
        model = Ollama(model=MODEL)

    template = """
    You are an AI assistant. You must determine whether the prompt requires code as the answer.
    Output 1 if it does or 0 if it does not, and nothing else.
        Prompt: {prompt}
    """
    prompt = PromptTemplate.from_template(template)
    chain = {"prompt": itemgetter("prompt")} | prompt | model
    return chain.invoke({"prompt": query})


# Filtered vector search for each of the N=2 queries after expansion
@PipelineDecorator.component(cache=False, return_values=["results1, results2"])
def filteredVectorSearch(query, newQuery, codingQuestion):
    # Create a qdrant connection
    try:
        load_dotenv(override=True)
    except Exception:
        load_dotenv(sys.path[1] + "/.env", override=True)
    USE_QDRANT_CLOUD = os.getenv("USE_QDRANT_CLOUD")
    QDRANT_CLOUD_URL = os.getenv("QDRANT_CLOUD_URL")
    QDRANT_APIKEY = os.getenv("QDRANT_APIKEY")
    if USE_QDRANT_CLOUD == "True":
        qClient = QdrantClient(url=QDRANT_CLOUD_URL, api_key=QDRANT_APIKEY)
    else:
        qClient = QdrantClient(url=QDRANT_CLOUD_URL)

    # Setup the text embedder
    MODEL = "llama3.2"
    try:
        load_dotenv(override=True)
    except Exception:
        load_dotenv(sys.path[1] + "/.env", override=True)
    USE_DOCKER = os.getenv("USE_DOCKER")
    if USE_DOCKER == "True":
        embeddingsModel = OllamaEmbeddings(model=MODEL, base_url="http://host.docker.internal:11434")
    else:
        embeddingsModel = OllamaEmbeddings(model=MODEL)

    # Search the related collection
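    # codingQuestion is "1" when the query needs code, so the GitHub collection is searched instead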
    relatedCollection = "Document"
    if codingQuestion == "1":
        relatedCollection = "Github"
    results1 = qClient.search(
        collection_name=relatedCollection,
        query_vector=embeddingsModel.embed_query(query),
        limit=10,
    )
    results2 = qClient.search(
        collection_name=relatedCollection,
        query_vector=embeddingsModel.embed_query(newQuery),
        limit=10,
    )
    return results1, results2


# Collecting results
@PipelineDecorator.component(cache=False, return_values=["results"])
def collectingResults(results1, results2):
    return results1 + results2


# Reranking (instead of using a CrossEncoder, I manually compare the Qdrant similarity scores)
@PipelineDecorator.component(cache=False, return_values=["topTexts"])
def reranking(results):
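    # Manual rerank: pick the 3 highest-scoring results across both searches,
    # skipping duplicate point ids so the same chunk is not selected twice.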
    ids = [result.id for result in results]
    scores = [result.score for result in results]
    topIds = []
    topIndexes = []
    for x in range(3):
        maxScore = 0
        maxIndex = 0
        for i in range(len(ids)):
            if ids[i] not in topIds and scores[i] > maxScore:
                maxScore = scores[i]
                maxIndex = i
        topIds.append(ids[maxIndex])
        topIndexes.append(maxIndex)
    texts = [result.payload["text"] for result in results]
    topTexts = ""
    for index in topIndexes:
        topTexts += texts[index] + "\n"
    return topTexts


# Building prompt
@PipelineDecorator.component(cache=False, return_values=["prompt"])
def buildingPrompt(codingQuestion):
    if codingQuestion == "1":
        template = """
        Write code for the following question given the related coding document below.

        Document: {document}
        Question: {question}
        """
        return PromptTemplate.from_template(template)
    else:
        template = """
        Answer the question based on the document below. If you can't answer the question, reply "I don't know"

        Document: {document}
        Question: {question}
        """
        return PromptTemplate.from_template(template)


# Obtaining answer
@PipelineDecorator.component(cache=False, return_values=["answer"])
def obtainingAnswer(query, prompt, topTexts):
    # Setup the model
    MODEL = "llama3.2"
    try:
        load_dotenv(override=True)
    except Exception:
        load_dotenv(sys.path[1] + "/.env", override=True)
    USE_DOCKER = os.getenv("USE_DOCKER")
    if USE_DOCKER == "True":
        model = Ollama(model=MODEL, base_url="http://host.docker.internal:11434")
    else:
        model = Ollama(model=MODEL)

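    # Feed the reranked context and the original question through the prompt and the LLM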
    chain = (
        {"document": itemgetter("document"), "question": itemgetter("question")}
        | prompt
        | model
    )
    chain.invoke({"document": topTexts, "question": query})


# Inference Pipeline
@PipelineDecorator.pipeline(
    name="Inference Pipeline",
    project="RAG LLM",
    version="0.1",
)
def main():
    # User query
    query = "What operating system was ROS written for?"
    newQuery = queryExpansion(query)
    codingQuestion = selfQuerying(query)
    results1, results2 = filteredVectorSearch(query, newQuery, codingQuestion)
    results = collectingResults(results1, results2)
    topTexts = reranking(results)
    prompt = buildingPrompt(codingQuestion)
    return obtainingAnswer(query, prompt, topTexts)


if __name__ == "__main__":
    PipelineDecorator.run_locally()
    main()