Upload 27 files
- policy_rag/__init__.py +0 -0
- policy_rag/__pycache__/__init__.cpython-311.pyc +0 -0
- policy_rag/__pycache__/app_utils.cpython-311.pyc +0 -0
- policy_rag/__pycache__/chains.cpython-311.pyc +0 -0
- policy_rag/__pycache__/data_models.cpython-311.pyc +0 -0
- policy_rag/__pycache__/eval_utils.cpython-311.pyc +0 -0
- policy_rag/__pycache__/ragas_utils.cpython-311.pyc +0 -0
- policy_rag/__pycache__/sdg_utils.cpython-311.pyc +0 -0
- policy_rag/__pycache__/text_utils.cpython-311.pyc +0 -0
- policy_rag/__pycache__/vectorstore_utils.cpython-311.pyc +0 -0
- policy_rag/app_utils.py +55 -0
- policy_rag/chains.py +47 -0
- policy_rag/data_models.py +50 -0
- policy_rag/eval_utils.py +68 -0
- policy_rag/metrics/__init__.py +4 -0
- policy_rag/metrics/__pycache__/__init__.cpython-311.pyc +0 -0
- policy_rag/metrics/__pycache__/_answer_relevancy.cpython-311.pyc +0 -0
- policy_rag/metrics/__pycache__/_context_precision.cpython-311.pyc +0 -0
- policy_rag/metrics/__pycache__/_context_recall.cpython-311.pyc +0 -0
- policy_rag/metrics/__pycache__/_faithfulness.cpython-311.pyc +0 -0
- policy_rag/metrics/_answer_relevancy.py +126 -0
- policy_rag/metrics/_context_precision.py +75 -0
- policy_rag/metrics/_context_recall.py +95 -0
- policy_rag/metrics/_faithfulness.py +98 -0
- policy_rag/sdg_utils.py +68 -0
- policy_rag/text_utils.py +60 -0
- policy_rag/vectorstore_utils.py +115 -0
policy_rag/__init__.py
ADDED
File without changes

policy_rag/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (164 Bytes)

policy_rag/__pycache__/app_utils.cpython-311.pyc
ADDED
Binary file (2.17 kB)

policy_rag/__pycache__/chains.cpython-311.pyc
ADDED
Binary file (1.9 kB)

policy_rag/__pycache__/data_models.cpython-311.pyc
ADDED
Binary file (3.09 kB)

policy_rag/__pycache__/eval_utils.cpython-311.pyc
ADDED
Binary file (2.83 kB)

policy_rag/__pycache__/ragas_utils.cpython-311.pyc
ADDED
Binary file (1.52 kB)

policy_rag/__pycache__/sdg_utils.cpython-311.pyc
ADDED
Binary file (2.98 kB)

policy_rag/__pycache__/text_utils.cpython-311.pyc
ADDED
Binary file (3.21 kB)

policy_rag/__pycache__/vectorstore_utils.cpython-311.pyc
ADDED
Binary file (6.4 kB)
policy_rag/app_utils.py
ADDED
@@ -0,0 +1,55 @@
import os
from dotenv import load_dotenv
load_dotenv()

from typing import Dict, Tuple
from collections.abc import Callable

from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

from policy_rag.text_utils import get_recursive_token_chunks, get_semantic_chunks


# Config Options
CHUNK_METHOD = {
    'token-overlap': get_recursive_token_chunks,
    'semantic': get_semantic_chunks
}

EMBEDDING_MODEL_SOURCE = {
    'openai': OpenAIEmbeddings,
    'huggingface': HuggingFaceInferenceAPIEmbeddings
}


# Helpers
def get_chunk_func(chunk_method: Dict) -> Tuple[Callable, Dict]:
    chunk_func = CHUNK_METHOD[chunk_method['method']]

    if chunk_method['method'] == 'token-overlap':
        chunk_func_args = chunk_method['args']

    if chunk_method['method'] == 'semantic':
        args = chunk_method['args']
        chunk_func_args = {
            'embedding_model': EMBEDDING_MODEL_SOURCE[args['model_source']](model=args['model_name']),
            'breakpoint_type': args['breakpoint_type']
        }

    return chunk_func, chunk_func_args


def get_embedding_model(config) -> OpenAIEmbeddings | HuggingFaceInferenceAPIEmbeddings:
    if config['model_source'] == 'openai':
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](model=config['model_name'])

    if config['model_source'] == 'huggingface':
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](
            api_key=os.getenv('HF_API_KEY'),
            model_name=config['model_name'],
            api_url=config['api_url']
        )

    return model
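
A minimal usage sketch for the helpers above (not part of this commit), assuming a hypothetical chunking config whose 'args' mirror the get_recursive_token_chunks parameters and an OPENAI_API_KEY in the environment:

from policy_rag.app_utils import get_chunk_func, get_embedding_model

# Hypothetical config dicts; key names follow what the helpers above read.
chunk_method = {
    'method': 'token-overlap',
    'args': {'model_name': 'gpt-4', 'chunk_size': 150, 'chunk_overlap': 0}
}
embedding_config = {'model_source': 'openai', 'model_name': 'text-embedding-3-large'}

chunk_func, chunk_func_args = get_chunk_func(chunk_method)
# chunks = chunk_func(docs, **chunk_func_args)   # docs: List[Document] loaded elsewhere

embedding_model = get_embedding_model(embedding_config)
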
policy_rag/chains.py
ADDED
@@ -0,0 +1,47 @@
from operator import itemgetter

from langchain_openai import ChatOpenAI
from langchain.chains.base import Chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain_core.vectorstores import VectorStoreRetriever


def get_qa_chain(
    retriever: VectorStoreRetriever,
    streaming: bool = False
) -> Chain:
    template = """
    Answer any questions based solely on the context below. If the context
    doesn't provide the answer, still do your best to answer the question
    factually, but indicate there isn't a clear answer in the context
    and that you're giving a best-effort response.

    Question:
    {question}

    Context:
    {context}
    """
    primary_qa_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=streaming)

    prompt = ChatPromptTemplate.from_template(template)

    retrieval_augmented_qa_chain = (
        # INVOKE CHAIN WITH: {"question" : "<<SOME USER QUESTION>>"}
        # "question" : populated by getting the value of the "question" key
        # "context" : populated by getting the value of the "question" key and chaining it into the base_retriever
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        # "context" : is assigned to a RunnablePassthrough object (will not be called or considered in the next step)
        # by getting the value of the "context" key from the previous step
        | RunnablePassthrough.assign(context=itemgetter("context"))
        # "answer" : the "context" and "question" values are used to format our prompt object and then piped
        # into the LLM and stored in a key called "answer": NOTE: Key MUST be "answer" for LangSmith.
        # "contexts" : populated by getting the value of the "context" key from the previous step.
        # NOTE: Key must be "contexts" for LangSmith
        | {"answer": prompt | primary_qa_llm, "contexts": itemgetter("context")}
    )

    return retrieval_augmented_qa_chain
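
A short invocation sketch for get_qa_chain (not part of this commit); `retriever` is assumed to come from QdrantVectorstoreHelper.get_retriever or any other VectorStoreRetriever, with OPENAI_API_KEY set:

from policy_rag.chains import get_qa_chain

qa_chain = get_qa_chain(retriever=retriever, streaming=False)

result = qa_chain.invoke({"question": "What does the policy say about data retention?"})
print(result["answer"].content)   # AIMessage returned by the LLM
print(len(result["contexts"]))    # Documents the retriever supplied as context
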
policy_rag/data_models.py
ADDED
@@ -0,0 +1,50 @@
from pydantic import BaseModel, RootModel, field_validator
from langchain_core.documents.base import Document
from typing import List, Dict
from uuid import UUID


class DocList(RootModel[List[Document]]):
    model_config = {'validate_assignment': True}


class QuestionObject(RootModel[Dict[str, str]]):
    model_config = {'validate_assignment': True}

    @field_validator('root')
    def validate_key_is_uuid(cls, value):
        for key in value.keys():
            try:
                u = UUID(key)
                if u.version != 4:
                    raise ValueError(f"{key} is not UUID v4")
            except ValueError:
                raise ValueError(f"{key} is not UUID v4")
        return value


class ContextObject(RootModel[Dict[str, List[str]]]):
    model_config = {'validate_assignment': True}

    @field_validator('root')
    def validate_key_is_uuid(cls, value):
        for key in value.keys():
            try:
                u = UUID(key)
                if u.version != 4:
                    raise ValueError(f"{key} is not UUID v4")
            except ValueError:
                raise ValueError(f"{key} is not UUID v4")
        return value

    @field_validator('root')
    def validate_values_are_uuid(cls, value):
        for key, val in value.items():
            for v in val:
                try:
                    u = UUID(v)
                    if u.version != 4:
                        raise ValueError(f"{v} is not UUID v4")
                except ValueError:
                    raise ValueError(f"{v} is not UUID v4")
        return value
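
A small illustration of how these root models validate UUID v4 keys (not part of this commit):

from uuid import uuid4
from policy_rag.data_models import QuestionObject, ContextObject

q = QuestionObject({str(uuid4()): "What is the refund policy?"})    # passes validation
ctx = ContextObject({str(uuid4()): [str(uuid4()), str(uuid4())]})   # keys and values must be UUID v4
# QuestionObject({"not-a-uuid": "..."}) raises a pydantic ValidationError
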
policy_rag/eval_utils.py
ADDED
@@ -0,0 +1,68 @@
import os
from dotenv import load_dotenv
load_dotenv()

from typing import List, Any
from langsmith import Client
from langsmith.evaluation import evaluate

from langchain_core.vectorstores import VectorStoreRetriever
import pandas as pd
import uuid

from policy_rag.chains import get_qa_chain
from policy_rag.metrics import (
    faithfulness,
    answer_relevancy,
    context_precision,
    context_recall
)

METRICS = {
    'faithfulness': faithfulness,
    'answer_relevancy': answer_relevancy,
    'context_precision': context_precision,
    'context_recall': context_recall
}


def get_ls_dataset(ls_dataset_name: str) -> pd.DataFrame:
    client = Client()
    examples = client.list_examples(dataset_name=ls_dataset_name)
    rows = [row.outputs | row.inputs | {'id': str(row.id)} for row in examples]
    return pd.DataFrame(rows)


# Get RAG QA Chain
def eval_on_ls_dataset(
    metrics: List[str],
    retriever: VectorStoreRetriever,
    ls_dataset_name: str,
    ls_project_name: str,
    ls_experiment_name: str
):
    os.environ['LANGCHAIN_PROJECT'] = ls_project_name

    print('Getting RAG QA Chain')
    rag_qa_chain = get_qa_chain(retriever=retriever)

    # Get LS Dataset and Eval Dataset
    #print('Getting Test Set from LangSmith')
    #test_df = get_ls_dataset(ls_dataset_name)
    #test_questions = test_df['question'].to_list()
    #test_groundtruths = test_df['ground_truth'].to_list()

    # Evaluate
    print('Running Experiment in LangSmith')
    print(f'Evaluating {metrics}')

    client = Client(auto_batch_tracing=False)
    results = evaluate(
        rag_qa_chain.invoke,
        data=ls_dataset_name,
        evaluators=[METRICS[metric] for metric in metrics],
        experiment_prefix=ls_experiment_name,
        client=client
    )

    return results
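
A hedged sketch of launching an evaluation run with the function above (not part of this commit); it assumes OPENAI_API_KEY and LANGCHAIN_API_KEY are set and that a LangSmith dataset with 'question' inputs and 'ground_truth' outputs already exists. The dataset, project, and experiment names below are placeholders:

from policy_rag.eval_utils import eval_on_ls_dataset

# retriever: VectorStoreRetriever built with the vectorstore helpers
results = eval_on_ls_dataset(
    metrics=['faithfulness', 'answer_relevancy', 'context_precision', 'context_recall'],
    retriever=retriever,
    ls_dataset_name='policy-qa-testset',
    ls_project_name='policy-rag-eval',
    ls_experiment_name='baseline-token-chunks'
)
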
policy_rag/metrics/__init__.py
ADDED
@@ -0,0 +1,4 @@
from ._faithfulness import faithfulness
from ._answer_relevancy import answer_relevancy
from ._context_precision import context_precision
from ._context_recall import context_recall

policy_rag/metrics/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (456 Bytes)

policy_rag/metrics/__pycache__/_answer_relevancy.cpython-311.pyc
ADDED
Binary file (6.25 kB)

policy_rag/metrics/__pycache__/_context_precision.cpython-311.pyc
ADDED
Binary file (3.54 kB)

policy_rag/metrics/__pycache__/_context_recall.cpython-311.pyc
ADDED
Binary file (4.74 kB)

policy_rag/metrics/__pycache__/_faithfulness.cpython-311.pyc
ADDED
Binary file (4.74 kB)
policy_rag/metrics/_answer_relevancy.py
ADDED
@@ -0,0 +1,126 @@
from dotenv import load_dotenv
load_dotenv()
import json
from typing import List, Tuple
import numpy as np

from langsmith.schemas import Example, Run
from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticToolsParser
from langchain_openai import OpenAIEmbeddings


class VariantQuestionAnswerCommittal(BaseModel):
    """Use to generate a question based on the given answer
    and determine if the answer is noncommittal."""

    question: str = Field(description="The generated question based on the given answer.")
    noncommittal: bool = Field(description="The judgement of if the answer is noncommittal.")


def cosine_similarity_np(embedding_a, embedding_b):
    """
    Calculate the cosine similarity between two vectors using numpy.

    Args:
    - embedding_a (np.array): First embedding vector.
    - embedding_b (np.array): Second embedding vector.

    Returns:
    - float: Cosine similarity value.
    """
    # Normalize the embeddings to avoid division by zero
    norm_a = np.linalg.norm(embedding_a)
    norm_b = np.linalg.norm(embedding_b)

    # Compute cosine similarity
    cosine_sim = np.dot(embedding_a, embedding_b) / (norm_a * norm_b)
    return cosine_sim


def mean_cosine_similarity(embeddings_list, reference_embedding):
    """
    Calculate the mean cosine similarity of a list of embeddings to a reference embedding.

    Args:
    - embeddings_list (list of np.array): A list of embeddings.
    - reference_embedding (np.array): The reference embedding to which the cosine similarity is calculated.

    Returns:
    - float: The mean cosine similarity value.
    """
    similarities = []

    for embedding in embeddings_list:
        # Calculate cosine similarity using numpy
        sim = cosine_similarity_np(reference_embedding, embedding)
        similarities.append(sim)

    # Return the mean of the cosine similarities
    return np.mean(similarities)


def calculate_similarity(question: str, generated_questions: list[str]) -> float:
    embeddings = OpenAIEmbeddings(model='text-embedding-3-large')
    question_vec = np.asarray(embeddings.embed_query(question)).reshape(1, -1)
    gen_question_vec = np.asarray(
        embeddings.embed_documents(generated_questions)
    ).reshape(len(generated_questions), -1)
    norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(
        question_vec, axis=1
    )

    return np.mean((np.dot(gen_question_vec, question_vec.T).reshape(-1,) / norm))


def generate_questions(answer: str) -> Tuple[str, bool]:
    template = """
    Generate a question for the given answer and identify if the answer is noncommittal.
    Give noncommittal as True if the answer is noncommittal and False if the answer is committal.
    A noncommittal answer is one that is evasive, vague, or ambiguous.
    For example, "I don't know" or "I'm not sure" are noncommittal answers.

    Answer:
    {answer}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [VariantQuestionAnswerCommittal]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    res = chain.invoke({'answer': answer})[0]
    question = res.question
    noncommittal = res.noncommittal

    return question, noncommittal


def answer_relevancy(run: Run, example: Example) -> dict:
    # Assumes your RAG app includes the prediction in the "answer" key in its response
    answer: str = run.outputs["answer"].content
    o_question: str = example.inputs['question']

    # Get generated question variants based on chain answer
    questions, noncommittals = [], []
    for _ in range(3):
        question, noncommittal = generate_questions(answer)

        if noncommittal:
            return {"key": "Answer Relevancy", "score": 0}

        questions.append(question)
        noncommittals.append(noncommittal)

    relevancy_score = calculate_similarity(o_question, questions)

    return {"key": "Answer Relevancy", "score": relevancy_score}

policy_rag/metrics/_context_precision.py
ADDED
@@ -0,0 +1,75 @@
from dotenv import load_dotenv
load_dotenv()
import json
from typing import List, Tuple
import numpy as np

from langsmith.schemas import Example, Run
from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticToolsParser


class ContextPrecisionVerification(BaseModel):
    """Answer for the verification task whether the context was useful."""

    verdict: int = Field(..., description="Binary (0/1) verdict of verification")


def verify_context_precision(
    question: str,
    answer: str,
    context: str
) -> int:
    template = """
    Given Question, Answer, and Context below, verify if the Context was useful in arriving at the given Answer.

    Question:
    {question}

    Answer:
    {answer}

    Context:
    {context}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [ContextPrecisionVerification]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    res = chain.invoke({'question': question, 'answer': answer, 'context': context})[0]

    return res.verdict


def context_precision(run: Run, example: Example) -> dict:
    question: str = example.inputs['question']
    ground_truth: str = example.outputs["ground_truth"]
    contexts: List[str] = [context.page_content for context in run.outputs['contexts']]

    # Verify if the context was relevant / useful to the generated answer.
    verdicts = []
    for context in contexts:
        verdict = verify_context_precision(question, ground_truth, context)
        verdicts.append(verdict)

    # Calculate Precision@k for each context chunk
    precisions_at_k = []
    for idx, verdict in enumerate(verdicts):
        k = idx + 1
        precision_at_k = verdict / k
        precisions_at_k.append(precision_at_k)

    context_precision_score = sum(precisions_at_k) / (sum(verdicts) + 1e-10)

    return {"key": "Context Precision", "score": context_precision_score}
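
For intuition, a tiny worked example of the aggregation in context_precision (illustrative only, not part of the module): with verdicts [1, 0, 1] for three retrieved chunks, the per-rank terms are 1/1, 0/2, and 1/3, so the score is (1.0 + 0.0 + 0.333) / 2, roughly 0.67; useful chunks ranked earlier push the score up.

verdicts = [1, 0, 1]                                             # chunk 1 useful, chunk 2 not, chunk 3 useful
precisions_at_k = [v / (i + 1) for i, v in enumerate(verdicts)]  # [1.0, 0.0, 0.333...]
score = sum(precisions_at_k) / (sum(verdicts) + 1e-10)           # ~0.67
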
policy_rag/metrics/_context_recall.py
ADDED
@@ -0,0 +1,95 @@
from dotenv import load_dotenv
load_dotenv()
import json
from typing import List

from langsmith.schemas import Example, Run
from pydantic import BaseModel, Field

from langchain_core.documents.base import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticToolsParser


class Statements(BaseModel):
    """Use to record each statement in the answer."""

    statements: List[str] = Field(description="The statements found in the text.")


class ContextRecallAttribution(BaseModel):
    """Use to determine if a statement can be attributed to the context."""

    attributed: int = Field(..., description="Binary (0/1) verdict of whether statement can be attributed to context.")


def extract_statements(ground_truth: str) -> List[str]:
    template = """
    Extract all statements from the Text below. Record each statement as
    a self-contained logical sentence that can be used to verify attribution
    later.

    Text:
    {ground_truth}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [Statements]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    return chain.invoke({'ground_truth': ground_truth})[0].statements


def get_statement_attribution(statement: str, formatted_docs: str) -> int:
    template = """
    Given a Statement and a Context, classify if the Statement can be attributed
    to the Context or not. Use only (1) or (0) as a binary classification.

    Statement: {statement}

    Context:
    {formatted_docs}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [ContextRecallAttribution]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    res = chain.invoke({'statement': statement, 'formatted_docs': formatted_docs})
    attributed = res[0].attributed

    return attributed


def context_recall(run: Run, example: Example) -> dict:
    ground_truth: str = example.outputs["ground_truth"]
    retrieved_docs: List[Document] = run.outputs["contexts"]
    formatted_docs: str = "\n".join([doc.page_content for doc in retrieved_docs])

    statements = extract_statements(ground_truth)

    attributions = []
    for statement in statements:
        attribution = get_statement_attribution(statement, formatted_docs)
        attributions.append(attribution)

    context_recall_score = sum(attributions) / len(attributions) if attributions else None

    return {"key": "Context Recall", "score": context_recall_score}

policy_rag/metrics/_faithfulness.py
ADDED
@@ -0,0 +1,98 @@
from dotenv import load_dotenv
load_dotenv()
import json
from typing import List, Tuple

from langsmith.schemas import Example, Run
from pydantic import BaseModel, Field

from langchain_core.documents.base import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticToolsParser


class Propositions(BaseModel):
    """Use to record each factual assertion."""

    propositions: List[str] = Field(description="The factual propositions generated by the model")


class FaithfulnessScore(BaseModel):
    """Use to score how faithful the propositions are to the docs."""

    reasoning: str = Field(description="The reasoning for the faithfulness score")
    score: bool


def extract_propositions(text: str) -> List[str]:
    template = """
    Extract all factual statements from the following Text:

    Text:
    {text}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [Propositions]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    return chain.invoke({'text': text})[0].propositions


def get_faithfulness_score(proposition: str, formatted_docs: str) -> Tuple[bool, str]:
    template = """
    Grade whether the Proposition can be logically concluded
    from the Docs:

    Proposition: {proposition}

    Docs:
    {formatted_docs}
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    prompt = ChatPromptTemplate.from_template(template)

    tools = [FaithfulnessScore]

    chain = (
        prompt
        | llm.bind_tools(tools)
        | PydanticToolsParser(tools=tools)
    )

    res = chain.invoke({'proposition': proposition, 'formatted_docs': formatted_docs})
    score = res[0].score
    reasoning = res[0].reasoning

    return score, reasoning


def faithfulness(run: Run, example: Example) -> dict:
    # Assumes your RAG app includes the prediction in the "answer" key in its response
    response: str = run.outputs["answer"].content
    # Assumes your RAG app includes the retrieved docs as a "contexts" key in the outputs
    # If not, you can fetch from the child_runs of the run object
    retrieved_docs: List[Document] = run.outputs["contexts"]
    formatted_docs = "\n".join([doc.page_content for doc in retrieved_docs])

    propositions = extract_propositions(response)

    scores, reasoning = [], []
    for proposition in propositions:
        score, reason = get_faithfulness_score(proposition, formatted_docs)
        scores.append(score)
        reasoning.append(reason)

    average_score = sum(scores) / len(scores) if scores else None
    comment = "\n".join(reasoning)

    return {"key": "faithfulness", "score": average_score, "comment": comment}

policy_rag/sdg_utils.py
ADDED
@@ -0,0 +1,68 @@
import os

from ragas.testset.generator import TestsetGenerator
from ragas.testset.generator import TestDataset
from ragas.testset.evolutions import simple, reasoning, multi_context
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents.base import Document
from typing import List

from langsmith import Client
from pandas import DataFrame
import asyncio


async def ragas_sdg(
    context_docs: List[Document],
    n_qa_pairs: int = 20,
    embedding_model: OpenAIEmbeddings = OpenAIEmbeddings(model='text-embedding-3-large')
) -> TestDataset:
    generator_llm = ChatOpenAI(model="gpt-4o")
    critic_llm = ChatOpenAI(model="gpt-4o-mini")
    embeddings = embedding_model

    generator = TestsetGenerator.from_langchain(
        generator_llm,
        critic_llm,
        embeddings
    )

    distributions = {
        simple: 0.5,
        multi_context: 0.25,
        reasoning: 0.25
    }

    test_set = generator.generate_with_langchain_docs(context_docs, n_qa_pairs, distributions)

    return test_set


def upload_dataset_langsmith(
    dataset: TestDataset | DataFrame,
    dataset_name: str,
    description: str
) -> None:
    client = Client()

    ls_dataset = client.create_dataset(
        dataset_name=dataset_name, description=description
    )

    # TODO: implement a Pydantic model to validate input dataset
    if type(dataset) == TestDataset:
        dataset_df = dataset.to_pandas()
    elif type(dataset) == DataFrame:
        dataset_df = dataset
    else:
        raise TypeError('Dataset must be ragas TestDataset or pandas DataFrame')

    for idx, row in dataset_df.iterrows():
        client.create_example(
            inputs={"question" : row["question"], "context": row["contexts"]},
            outputs={"ground_truth" : row["ground_truth"]},
            metadata={'metadata': row['metadata'][0], "evolution_type": row['evolution_type']},
            dataset_id=ls_dataset.id
        )
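
A hedged sketch of generating and uploading a synthetic test set with these helpers (not part of this commit); ragas_sdg is a coroutine, so it is driven with asyncio here, and the dataset name and description are placeholders:

import asyncio
from policy_rag.sdg_utils import ragas_sdg, upload_dataset_langsmith

# context_docs: List[Document] loaded and chunked elsewhere
test_set = asyncio.run(ragas_sdg(context_docs, n_qa_pairs=20))

upload_dataset_langsmith(
    dataset=test_set,
    dataset_name='policy-qa-testset',
    description='Synthetic QA pairs generated with ragas'
)
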
policy_rag/text_utils.py
ADDED
@@ -0,0 +1,60 @@
import os
from typing import List
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents.base import Document
from policy_rag.data_models import DocList

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings


# Text Loading
class DocLoader:
    docs: DocList = DocList([]).root

    def load(self, path: str) -> List[Document]:
        if path.endswith('.pdf'):
            loader = PyMuPDFLoader(path)
            self.docs.extend(loader.load())
        else:
            print(f'Skipping {path} - not PDF')

        return self.docs


    def load_dir(self, dir_path: str) -> List[Document]:
        for doc_name in os.listdir(dir_path):
            doc_path = os.path.join(dir_path, doc_name)
            self.load(doc_path)

        return self.docs


# Text Splitting
def get_recursive_token_chunks(
    docs: List[Document],
    model_name: str = 'gpt-4',
    chunk_size: int = 150,
    chunk_overlap: int = 0
) -> List[Document]:
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        model_name=model_name,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )

    return text_splitter.split_documents(docs)


def get_semantic_chunks(
    docs: List[Document],
    embedding_model: OpenAIEmbeddings,
    breakpoint_type: str = 'gradient'
) -> List[Document]:
    text_splitter = SemanticChunker(
        embeddings=embedding_model,
        breakpoint_threshold_type=breakpoint_type
    )

    return text_splitter.split_documents(docs)
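
A short usage sketch for the loader and splitters above (not part of this commit), assuming a local data/ directory of PDFs:

from policy_rag.text_utils import DocLoader, get_recursive_token_chunks

loader = DocLoader()
docs = loader.load_dir('data/')   # hypothetical directory of PDF files

# Token-based chunking; get_semantic_chunks would instead take an embedding model and breakpoint type.
chunks = get_recursive_token_chunks(docs, model_name='gpt-4', chunk_size=150, chunk_overlap=0)
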
policy_rag/vectorstore_utils.py
ADDED
@@ -0,0 +1,115 @@
import os

from langchain_core.documents.base import Document
from langchain_qdrant import QdrantVectorStore
from langchain_community.vectorstores import Qdrant
from langchain_core.vectorstores import VectorStoreRetriever
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from typing import Literal, Optional, List, Any
from uuid import UUID


class QdrantVectorstoreHelper:
    def __init__(self) -> None:
        self.client = None

        if os.getenv('QDRANT_API_KEY') and os.getenv('QDRANT_URL'):
            self.client = QdrantClient(
                url=os.getenv('QDRANT_URL'),
                api_key=os.getenv('QDRANT_API_KEY')
            )
        else:
            print("Qdrant API Key and URL not present.")


    def create_collection(self, name: str, vector_size: int) -> None:
        if self.client:
            self.client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
        else:
            print('No Qdrant Client')


    def create_local_vectorstore(
        self,
        chunks: List[Document],
        embedding_model: OpenAIEmbeddings | HuggingFaceInferenceAPIEmbeddings = OpenAIEmbeddings(model='text-embedding-3-large'),
        vector_size: int = 3072
    ) -> None:
        self.local_vectorstore = Qdrant.from_documents(
            documents=chunks,
            vector_params={'size': vector_size, 'distance': Distance.COSINE},
            embedding=embedding_model,
            batch_size=32 if type(embedding_model) == HuggingFaceInferenceAPIEmbeddings else 64,
            location=":memory:"
        )


    def create_cloud_vectorstore(
        self,
        chunks: List[Document],
        collection_name: str,
        embedding_model: OpenAIEmbeddings | HuggingFaceInferenceAPIEmbeddings = OpenAIEmbeddings(model='text-embedding-3-large'),
        vector_size: int = 3072
    ) -> None:
        try:
            self.cloud_vectorstore = QdrantVectorStore.from_existing_collection(
                embedding=embedding_model,
                collection_name=collection_name,
                url=os.getenv('QDRANT_URL'),
                api_key=os.getenv('QDRANT_API_KEY')
            )
        except Exception:
            self.cloud_vectorstore = QdrantVectorStore.from_documents(
                documents=chunks,
                embedding=embedding_model,
                vector_params={'size': vector_size, 'distance': Distance.COSINE},
                collection_name=collection_name,
                batch_size=4 if type(embedding_model) == HuggingFaceInferenceAPIEmbeddings else 64,
                prefer_grpc=True,
                url=os.getenv('QDRANT_URL'),
                api_key=os.getenv('QDRANT_API_KEY')
            )


    def add_docs_to_vectorstore(
        self,
        collection_name: Literal['memory'] | str,
        chunks: List[Document],
        uuids: UUID
    ) -> None:
        str_uuids = [str(uuid) for uuid in uuids]
        if collection_name == 'memory':
            self.local_vectorstore.add_documents(documents=chunks, ids=str_uuids)
        else:
            self.cloud_vectorstore = QdrantVectorStore.from_existing_collection(
                collection_name=collection_name,
                url=os.getenv('QDRANT_URL'),
                api_key=os.getenv('QDRANT_API_KEY')
            )

            self.cloud_vectorstore.add_documents(documents=chunks, ids=str_uuids)


    def get_retriever(
        self,
        collection_name: Literal['memory'] | str,
        k: int = 3,
        embedding_model: OpenAIEmbeddings = OpenAIEmbeddings(model='text-embedding-3-large')
    ) -> VectorStoreRetriever:
        if collection_name == 'memory':
            return self.local_vectorstore.as_retriever(search_kwargs={'k': k})
        else:
            self.cloud_vectorstore = QdrantVectorStore.from_existing_collection(
                collection_name=collection_name,
                embedding=embedding_model,
                url=os.getenv('QDRANT_URL'),
                api_key=os.getenv('QDRANT_API_KEY')
            )

            return self.cloud_vectorstore.as_retriever(search_kwargs={'k': k})
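
Finally, a hedged end-to-end sketch tying the pieces together with the in-memory ('memory') path rather than Qdrant Cloud (not part of this commit; assumes OPENAI_API_KEY is set and `chunks` was produced by the text_utils splitters):

from policy_rag.vectorstore_utils import QdrantVectorstoreHelper
from policy_rag.chains import get_qa_chain

helper = QdrantVectorstoreHelper()
helper.create_local_vectorstore(chunks=chunks)                   # embeds chunks into an in-memory Qdrant store
retriever = helper.get_retriever(collection_name='memory', k=3)

qa_chain = get_qa_chain(retriever=retriever)
result = qa_chain.invoke({"question": "How long are records retained?"})
print(result["answer"].content)
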