# Sutra_AI/utils/embeddings_utils.py
# Author: Inferno-721
# Commit: "Initial" (0753d2e)
import openai
import numpy as np
import faiss
from typing import List
class EmbeddingsManager:
    """Embed text chunks with the OpenAI API and search them with FAISS.

    Typical use: call :meth:`generate_embeddings` once with the chunks to
    index, then call :meth:`find_relevant_chunks` for each query.

    Attributes:
        api_key: OpenAI API key used for embedding requests. When empty or
            ``None``, the ``openai`` module's default configuration is used.
        index: The FAISS index holding the chunk vectors, or ``None`` when
            nothing has been indexed yet.
        chunks: The indexed text chunks; position ``i`` corresponds to
            row ``i`` of ``index``.
    """

    # The same model must embed both the chunks and the query so that their
    # vectors are comparable.
    EMBEDDING_MODEL = "text-embedding-ada-002"
    # Number of chunks sent per embeddings request.
    BATCH_SIZE = 10

    def __init__(self, api_key: str):
        """Store the API key and start with an empty index."""
        self.api_key = api_key
        self.index = None
        self.chunks = []
        # Created lazily on the first request so construction does no I/O.
        self._client = None

    def _embeddings_api(self):
        """Return the object whose ``create`` method issues embedding requests."""
        if not self.api_key:
            # No explicit key: fall back to the module-level default client.
            return openai.embeddings
        if self._client is None:
            self._client = openai.OpenAI(api_key=self.api_key)
        return self._client.embeddings

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Request one embedding per entry of ``texts`` in a single API call."""
        response = self._embeddings_api().create(
            input=texts,
            model=self.EMBEDDING_MODEL,
        )
        return [item.embedding for item in response.data]

    def generate_embeddings(self, text_chunks: List[str]):
        """Embed ``text_chunks`` and (re)build the FAISS index over them.

        Any previously built index is replaced. Passing an empty list clears
        the index instead of failing.

        Args:
            text_chunks: The texts to index.
        """
        # Copy so later mutation of the caller's list cannot desynchronise
        # the stored chunks from the index rows.
        chunks = list(text_chunks)
        if not chunks:
            self.index = None
            self.chunks = []
            return

        embeddings = []
        for start in range(0, len(chunks), self.BATCH_SIZE):
            embeddings.extend(self._embed(chunks[start:start + self.BATCH_SIZE]))

        # Build the index fully before publishing it, so a failure above
        # leaves the previous index/chunks pair untouched.
        vectors = np.array(embeddings).astype('float32')
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)

        self.index = index
        self.chunks = chunks

    def find_relevant_chunks(self, query: str, k: int = 3) -> List[str]:
        """Find most relevant text chunks for a given query.

        Args:
            query: The text to search for.
            k: Maximum number of chunks to return.

        Returns:
            Up to ``k`` chunks in the order FAISS reports them. An empty list
            is returned when nothing has been indexed or ``k`` is not
            positive; no API request is made in those cases.
        """
        if self.index is None or not self.chunks or k <= 0:
            return []

        # Asking for more neighbours than there are vectors only yields
        # padding entries, so cap the request.
        k = min(k, len(self.chunks))

        query_vector = np.array(self._embed([query])).astype('float32')
        _, neighbor_ids = self.index.search(query_vector, k)

        # FAISS marks missing neighbours with -1; skip those.
        return [self.chunks[i] for i in neighbor_ids[0] if i != -1]