Spaces:
Runtime error
Runtime error
import tqdm | |
import yaml | |
import numpy as np | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
# Default number of texts handed to the encoder per forward pass. Kept very
# small, presumably to bound memory on constrained hardware — TODO confirm.
BATCH_SIZE = 2


class Vectorizer:
    """Embed text with a sentence-transformers model.

    Attributes:
        model_name: Identifier/path the model was loaded from.
        model: The loaded ``SentenceTransformer`` instance.
        batch_size: Number of texts encoded per forward pass.
    """

    def __init__(self, model_name: str, batch_size: int = BATCH_SIZE):
        """Load ``model_name`` with sentence-transformers.

        Args:
            model_name: Model identifier or local path passed to
                ``SentenceTransformer``.
            batch_size: Texts per encoder batch; defaults to ``BATCH_SIZE``.
        """
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.batch_size = batch_size

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Return the embedding produced by ``model.encode`` for one query string."""
        return self.model.encode(query)

    def get_embeddings(self, df: pd.DataFrame, data_col: str) -> list:
        """Embed every value of ``df[data_col]``, preserving row order.

        Batching and the progress bar are delegated to ``model.encode`` with
        ``batch_size=self.batch_size``. This replaces a manual loop that issued
        one encode call per batch.

        Args:
            df: Frame holding the texts.
            data_col: Name of the column to embed. Values are passed to the
                encoder unchanged. NOTE(review): missing CSV cells would arrive
                as NaN floats, not strings — confirm inputs never contain them.

        Returns:
            A list with one embedding (a list of floats) per row of ``df``,
            in row order; an empty list for an empty frame.

        Raises:
            RuntimeError: If the encoder returns a different number of vectors
                than there are rows.
        """
        # ``to_list`` reads the column positionally, independent of index
        # labels. Integer slicing of a Series can otherwise be label/position
        # ambiguous.
        docs = df[data_col].to_list()
        num_docs = len(docs)
        if num_docs == 0:
            return []
        vectors = self.model.encode(
            docs, batch_size=self.batch_size, show_progress_bar=True
        )
        # Plain Python lists so each row's cell holds one vector object.
        embeddings = vectors.tolist()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(embeddings) != num_docs:
            raise RuntimeError(
                f"expected {num_docs} embeddings, got {len(embeddings)}"
            )
        return embeddings

    def embed_docs(self, df: pd.DataFrame, data_col: str) -> pd.DataFrame:
        """Add an ``embeddings`` column with one vector per row of ``df``.

        ``df`` is modified in place and also returned for convenience.
        """
        df['embeddings'] = self.get_embeddings(df, data_col)
        return df
def run_vectorizer(configFilePath="config.yml"):
    """Embed a project's text chunks and write them to a new CSV.

    The YAML config at ``configFilePath`` must provide:
      * ``paths.data_path`` and ``paths.project``. These are concatenated
        verbatim into ``<data_path><project>.csv`` (input) and
        ``<data_path><project>_embedding.csv`` (output). ``data_path`` must
        therefore already end with a path separator if it names a directory.
      * ``sentence-transformers.model-name``, the encoder to load.

    The input CSV must contain a ``chunks`` column. Every chunk is embedded,
    and the frame is written with an added ``embeddings`` column; the list
    values are stored as their text representation. The output is then read
    back as a sanity check.

    Args:
        configFilePath: Path to the YAML config file.

    Raises:
        RuntimeError: If the re-read output has a different row count than the
            frame that was written.
    """
    with open(configFilePath, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    print("Config File Loaded ...")
    print(config)

    data_path = config['paths']['data_path']
    project = config['paths']['project']
    extension = '.csv'  # not named ``format``: that would shadow the builtin
    data_col_name = 'chunks'

    df = pd.read_csv(data_path + project + extension)
    vectorizer = Vectorizer(config['sentence-transformers']['model-name'])
    df_embeddings = vectorizer.embed_docs(df, data_col_name)
    print("Creation of embedding completed ...")
    print(df_embeddings.head())

    file_path_embedding = data_path + project + '_embedding' + extension
    # The index is written as the first column and restored with index_col=0.
    df_embeddings.to_csv(file_path_embedding)
    df_read = pd.read_csv(file_path_embedding, index_col=0)
    # Explicit round-trip check; an ``assert`` would be skipped under -O.
    if len(df_read) != len(df_embeddings):
        raise RuntimeError(
            f"{file_path_embedding}: read back {len(df_read)} rows, "
            f"expected {len(df_embeddings)}"
        )
    print(file_path_embedding + " created ...")
if __name__ == "__main__":
    # Script entry point: run the pipeline with the default config file.
    run_vectorizer(configFilePath="config.yml")