gt-policy-bot / vectorise.py
import tqdm
import yaml
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

BATCH_SIZE = 2


class Vectorizer:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.batch_size = BATCH_SIZE

    def get_query_embedding(self, query: str) -> np.ndarray:
        # Encode a single query string into a dense vector.
        return self.model.encode(query)

    def get_embeddings(self, df: pd.DataFrame, data_col: str) -> list:
        # Encode the documents in `data_col` in batches and return one
        # embedding (a plain list of floats) per document.
        docs = df[data_col]
        num_docs = len(docs)

        embeddings = []
        for i in tqdm.tqdm(range(0, num_docs, self.batch_size)):
            docs_batch = docs[i:i + self.batch_size].to_list()
            vectors_batch = self.model.encode(docs_batch).tolist()
            embeddings.append(vectors_batch)

        # Flatten the per-batch lists into a single list of embeddings.
        embeddings_flattened = [embedding for batch in embeddings for embedding in batch]
        assert len(embeddings_flattened) == num_docs

        return embeddings_flattened

    def embed_docs(self, df: pd.DataFrame, data_col: str) -> pd.DataFrame:
        # Attach one embedding per row to the DataFrame in an 'embeddings' column.
        embeddings = self.get_embeddings(df, data_col)
        df['embeddings'] = embeddings
        return df
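
# Minimal usage sketch of the class above (the model name and query string are
# illustrative placeholders, not values taken from this repository):
#
#   vectorizer = Vectorizer("all-MiniLM-L6-v2")
#   query_vector = vectorizer.get_query_embedding("What is the withdrawal policy?")
#   df_with_vectors = vectorizer.embed_docs(df, "chunks")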


def run_vectorizer(config_file_path: str = "config.yml"):
    with open(config_file_path, 'r') as file:
        config = yaml.safe_load(file)
        print("Config File Loaded ...")
        print(config)

    data_path = config['paths']['data_path']
    project = config['paths']['project']
    file_ext = '.csv'
    data_col_name = 'chunks'

    # Read the chunked documents and embed the 'chunks' column.
    df = pd.read_csv(data_path + project + file_ext)

    vectorizer = Vectorizer(config['sentence-transformers']['model-name'])
    df_embeddings = vectorizer.embed_docs(df, data_col_name)
    print("Creation of embeddings completed ...")
    print(df_embeddings.head())

    # Write the embedded DataFrame next to the input file and verify the
    # round trip by reading it back.
    file_path_embedding = data_path + project + '_embedding' + file_ext
    df_embeddings.to_csv(file_path_embedding)

    df_read = pd.read_csv(file_path_embedding, index_col=0)
    assert len(df_read) == len(df_embeddings)
    print(file_path_embedding + " created ...")


if __name__ == "__main__":
    run_vectorizer()
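
# For reference, the config.yml this script reads is assumed to look roughly
# like the following (the path, project, and model name are illustrative
# placeholders; only the key names are taken from the code above):
#
#   paths:
#     data_path: data/
#     project: gt-policy
#   sentence-transformers:
#     model-name: all-MiniLM-L6-v2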