Spaces:
Paused
Paused
File size: 1,353 Bytes
8c3a73e b4c442a 8c3a73e b4c442a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
# File: retrieval.py
from langchain_qdrant import Qdrant
from langchain_groq import ChatGroq
from langchain_openai import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from config import *
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Embedding model used to vectorize incoming queries.
# NOTE(review): this must be the same model that produced the vectors stored in
# COLLECTION_NAME, otherwise similarity search is meaningless — confirm against
# the ingestion code.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Groq-hosted chat model that writes the answer from the retrieved context.
# Groq's Llama 3 70B model ID is "llama3-70b-8192" (8192-token context); the
# previous "llama3-70b-4096" is not a Groq model ID, so requests were rejected.
# NOTE(review): check Groq's current model list; older Llama 3 IDs may be
# deprecated there in favour of newer Llama releases.
llm = ChatGroq(model="llama3-70b-8192", temperature=0.3)
# RetrievalQA chain shared by all rag_query() calls; built on first use.
_qa_chain = None


def _get_qa_chain():
    """Return the RetrievalQA chain, building it on first use.

    Construction connects to the Qdrant collection, so it is done once and the
    result reused rather than repeated for every query. If construction
    raises, nothing is cached and the next call tries again.
    """
    global _qa_chain
    if _qa_chain is None:
        # Connection settings are not defined in this module; presumably they
        # come from the `config` star import — verify there.
        qdrant = Qdrant.from_existing_collection(
            embedding=embeddings,
            collection_name=COLLECTION_NAME,
            url=QDRANT_API_URL,
            api_key=QDRANT_API_KEY,
            prefer_grpc=True,
        )
        # Fetch the 5 most similar chunks; the "stuff" chain type places all of
        # them into a single prompt for the LLM.
        retriever = qdrant.as_retriever(search_kwargs={"k": 5})
        _qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )
    return _qa_chain


def rag_query(query: str) -> str:
    """Answer *query* using retrieval-augmented generation over Qdrant.

    Args:
        query: The user's natural-language question.

    Returns:
        The generated answer text. On any failure, the string
        ``"Error processing query: <exception text>"`` is returned instead of
        raising, so callers always receive displayable text.
    """
    logging.info("Processing query: %s", query)
    try:
        # .invoke() is the Runnable entry point; calling the chain directly
        # (Chain.__call__) is deprecated in LangChain.
        result = _get_qa_chain().invoke({"query": query})
        answer = result["result"]
    except Exception as e:
        error_message = f"Error processing query: {str(e)}"
        # logging.exception logs at ERROR level *and* records the traceback.
        logging.exception(error_message)
        return error_message
    logging.info("Query processed successfully")
    return answer