import faiss
import numpy as np

def generate_faiss_index(embeddings, dim=None):
    """Build an exact (brute-force) L2 FAISS index over ``embeddings``.

    Args:
        embeddings: Array-like of shape ``(n, d)``. It is converted to a
            C-contiguous ``float32`` array, the layout FAISS expects.
        dim: Optional expected vector dimension. When ``None`` (the default)
            the dimension is taken from ``embeddings``. When given, the data
            must match it, and it also lets an empty input build an empty
            index.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D, or its width differs from
            ``dim``.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    # An empty list arrives as shape (0,); give it the right width if known.
    if vectors.size == 0 and dim is not None:
        vectors = vectors.reshape(0, dim)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be a 2-D (n, d) array, got shape {vectors.shape}"
        )
    if dim is None:
        # Infer the width rather than hard-coding it: different embedding
        # models produce different dimensions, and IndexFlatL2 rejects
        # vectors whose width differs from the index's.
        dim = vectors.shape[1]
    elif vectors.shape[1] != dim:
        raise ValueError(
            f"embeddings have dimension {vectors.shape[1]}, expected {dim}"
        )
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    return index

def load_faiss_index_to_gpu(index, device=0, res=None):
    """Copy a CPU FAISS index onto a GPU and return the GPU copy.

    Args:
        index: A CPU FAISS index, e.g. the result of ``generate_faiss_index``.
        device: CUDA device ordinal to place the index on (default 0, the
            previous hard-coded value).
        res: Optional ``faiss.StandardGpuResources`` to use. Pass one shared
            object when moving several indexes, so they reuse a single set of
            GPU resources. By default a new one is created per call, as
            before.

    Returns:
        The GPU-resident index produced by ``faiss.index_cpu_to_gpu``.
    """
    if res is None:
        res = faiss.StandardGpuResources()
    # NOTE(review): the GPU index depends on `res`; FAISS's Python wrapper is
    # expected to keep it referenced from the returned index. Confirm this
    # for the installed FAISS version before relying on it.
    return faiss.index_cpu_to_gpu(res, device, index)

def query_faiss_index(query_embedding, gpu_index, k=1):
    """Search ``gpu_index`` for the ``k`` nearest neighbours of the query.

    Any object exposing a FAISS-style ``search(x, k)`` works, whether a CPU
    or a GPU index.

    Args:
        query_embedding: A single vector of shape ``(d,)``, or a 2-D batch of
            shape ``(m, d)`` whose rows are searched independently. Inputs of
            any other rank are flattened into one query row, as before.
        gpu_index: The index to search.
        k: Number of neighbours to return per query (default 1).

    Returns:
        ``(indices, distances)``, each of shape ``(m, k)``, where ``m == 1``
        for a single vector. This is the reverse of the ``(distances,
        indices)`` order returned by ``search`` itself.
    """
    # FAISS needs a C-contiguous float32 matrix.
    queries = np.ascontiguousarray(query_embedding, dtype=np.float32)
    # Keep real (m, d) batches intact. Flattening them with reshape(1, -1)
    # would produce one (1, m*d) row that the index rejects.
    if queries.ndim != 2:
        queries = queries.reshape(1, -1)
    distances, indices = gpu_index.search(queries, k)
    return indices, distances