from typing import Optional

import numpy as np
import pandas as pd
import tqdm
import yaml
from sentence_transformers import SentenceTransformer

# Default number of documents Vectorizer sends to the model in one encode() call.
BATCH_SIZE = 2


class Vectorizer:
    """Compute sentence-transformer embeddings for queries and DataFrame columns.

    Attributes are plain and public: ``model_name``, ``model`` (the loaded
    ``SentenceTransformer``) and ``batch_size`` (documents per encode call).
    """

    def __init__(self, model_name: str, batch_size: Optional[int] = None):
        """Load the model and configure the batch size.

        Args:
            model_name: Identifier passed straight to ``SentenceTransformer``.
            batch_size: Number of documents encoded per model call. When
                omitted, the module-level ``BATCH_SIZE`` is used.

        Raises:
            ValueError: If the resulting batch size is smaller than 1.
        """
        resolved_batch_size = BATCH_SIZE if batch_size is None else batch_size
        # Validate before loading the model so a bad argument fails fast; a
        # step of 0 would otherwise only surface later inside range().
        if resolved_batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {resolved_batch_size!r}"
            )

        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.batch_size = resolved_batch_size

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Return the model's encoding of a single query string."""
        return self.model.encode(query)

    def get_embeddings(self, df: pd.DataFrame, data_col: str) -> list:
        """Encode every value of ``df[data_col]`` in batches.

        Args:
            df: Frame holding the documents.
            data_col: Name of the column whose values are encoded.

        Returns:
            A flat list with one entry per row, in row order. Each entry is
            the ``.tolist()`` form of that row's encoding (presumably a list
            of floats; depends on what ``model.encode`` returns).

        Raises:
            RuntimeError: If the model yields a different number of
                embeddings than there are rows.
        """
        docs = df[data_col]
        num_docs = len(docs)
        embeddings = []
        for start in tqdm.tqdm(range(0, num_docs, self.batch_size)):
            # .iloc keeps the slice positional regardless of the frame's
            # index (e.g. after filtering, or with a non-integer index).
            docs_batch = docs.iloc[start:start + self.batch_size].to_list()
            embeddings.extend(self.model.encode(docs_batch).tolist())

        # Explicit check rather than assert so it still runs under `python -O`.
        if len(embeddings) != num_docs:
            raise RuntimeError(
                f"expected {num_docs} embeddings, got {len(embeddings)}"
            )
        return embeddings

    def embed_docs(self, df: pd.DataFrame, data_col: str) -> pd.DataFrame:
        """Add an ``'embeddings'`` column holding the encoding of ``data_col``.

        Note that ``df`` is modified in place; the same object is returned
        for convenience.
        """
        df['embeddings'] = self.get_embeddings(df, data_col)
        return df


def run_vectorizer(configFilePath="config.yml"):
    """Embed the 'chunks' column of the project CSV and save the result.

    Reads the YAML config at ``configFilePath``, loads
    ``<data_path><project>.csv``, adds an ``'embeddings'`` column using the
    configured sentence-transformers model, and writes the frame to
    ``<data_path><project>_embedding.csv``. The written file is read back to
    confirm that its row count matches.

    Args:
        configFilePath: Path to a YAML file providing ``paths.data_path``,
            ``paths.project`` and ``sentence-transformers.model-name``.

    Raises:
        RuntimeError: If the file read back has a different number of rows
            than the frame that was written.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(configFilePath, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    print("Config File Loaded ...")
    print(config)

    # Paths are built by plain concatenation, so data_path is expected to
    # carry its own trailing separator.
    data_path = config['paths']['data_path']
    project = config['paths']['project']
    file_ext = '.csv'

    data_col_name = 'chunks'
    df = pd.read_csv(data_path + project + file_ext)

    vectorizer = Vectorizer(config['sentence-transformers']['model-name'])
    df_embeddings = vectorizer.embed_docs(df, data_col_name)
    print("Creation of embedding completed ...")
    print(df_embeddings.head())

    file_path_embedding = data_path + project + '_embedding' + file_ext
    df_embeddings.to_csv(file_path_embedding)

    # Round-trip check; index_col=0 matches the index column to_csv wrote.
    df_read = pd.read_csv(file_path_embedding, index_col=0)
    if len(df_read) != len(df_embeddings):
        raise RuntimeError(
            f"{file_path_embedding}: wrote {len(df_embeddings)} rows "
            f"but read back {len(df_read)}"
        )
    print(file_path_embedding + " created ...")


# When executed as a script, run the pipeline with the default config path.
if __name__ == "__main__":
    run_vectorizer()