import openai
import numpy as np
import faiss
from typing import List

class EmbeddingsManager:
    """Embed text chunks with the OpenAI embeddings API and search them via FAISS.

    Typical use: call ``generate_embeddings`` once with the corpus, then call
    ``find_relevant_chunks`` for each query. Both steps use the same embedding
    model, so query and document vectors live in the same space.
    """

    def __init__(self, api_key: str, model: str = "text-embedding-ada-002"):
        """Create a manager.

        Args:
            api_key: OpenAI API key used for every embeddings request. An empty
                value makes the SDK fall back to the ``OPENAI_API_KEY``
                environment variable.
            model: Embedding model name, used for both documents and queries.
        """
        self.api_key = api_key
        self.model = model
        # A dedicated client makes requests authenticate with ``api_key``
        # instead of whatever global/env configuration the module-level
        # ``openai`` client happens to have.
        self.client = openai.OpenAI(api_key=api_key or None)
        self.index = None  # FAISS index over chunk embeddings; None until built
        self.chunks: List[str] = []  # chunks[i] is the text for index row i

    def _embed(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Embed ``texts`` in batches and return a float32 array, one row per text."""
        vectors: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            response = self.client.embeddings.create(
                input=texts[start:start + batch_size],
                model=self.model,
            )
            # Each result carries the position of its input within the batch.
            # Sort on it, rather than trusting arrival order, so rows line up
            # with ``texts``.
            ordered = sorted(response.data, key=lambda item: item.index)
            vectors.extend(item.embedding for item in ordered)
        return np.asarray(vectors, dtype=np.float32)

    def generate_embeddings(self, text_chunks: List[str], batch_size: int = 10) -> None:
        """Embed ``text_chunks`` and (re)build the FAISS index over them.

        Any previously built index is replaced. An empty input clears the
        index so that subsequent searches return no results.

        Args:
            text_chunks: Texts to embed. Row ``i`` of the index corresponds to
                ``text_chunks[i]``.
            batch_size: Number of texts sent per embeddings request.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        # Copy so that later mutation of the caller's list cannot desync
        # self.chunks from the index rows.
        chunks = list(text_chunks)
        if not chunks:
            self.index = None
            self.chunks = []
            return

        vectors = self._embed(chunks, batch_size)
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)

        # Publish both together, only after every request succeeded, so a
        # failure leaves the previous index/chunks pair intact.
        self.index = index
        self.chunks = chunks

    def find_relevant_chunks(self, query: str, k: int = 3) -> List[str]:
        """Return up to ``k`` indexed chunks closest (by L2 distance) to ``query``.

        Results are ordered as FAISS returns them (nearest first). An empty list
        is returned, without making any API request, when ``k`` is not positive
        or when no non-empty index has been built yet.
        """
        if k <= 0 or self.index is None or self.index.ntotal == 0:
            return []

        query_vector = self._embed([query], batch_size=1)
        # Asking for more neighbours than exist only yields -1 padding.
        _distances, ids = self.index.search(query_vector, min(k, self.index.ntotal))
        return [self.chunks[i] for i in ids[0] if i != -1]