from statistics import mean
import sys
import os
import json
from datetime import datetime
import warnings
from pprint import pprint
from langchain.text_splitter import RecursiveCharacterTextSplitter
warnings.filterwarnings("ignore")
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'financial_dataset')))
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)
from load_test_data import get_labels_df, get_texts
from app import (
    summarize,
    read_and_split_file,
    get_label_prediction,
)
from config import (
    labels, headers_inference_api, headers_inference_endpoint,
    # summarization_prompt_template,
    prompt_template,
    # task_explain_for_predictor_model,
    summarizers, predictors, summary_scores_template,
    summarization_system_msg, summarization_user_prompt,
    prediction_user_prompt, prediction_system_msg,
    # prediction_prompt,
    chat_prompt, instruction_prompt,
)
def split_text(text, chunk_size=1200, chunk_overlap=200):
    """Split raw text into overlapping chunks sized for the downstream models."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=[" ", ",", "\n"],
    )
    return text_splitter.create_documents([text])
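# Minimal usage sketch (illustrative only; the sample string is an assumption):
# chunks = split_text("Revenue grew 12% year over year. " * 50, chunk_size=200)
# create_documents returns langchain Documents, so the text lives in .page_content:
# print(len(chunks), chunks[0].page_content[:80])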
predictions = {
    # structure: {method: {model_name: [predicted_labels]}}
    'summarization+classification': {
        'bart-pegasus+gpt': [],  # predicted labels, in dataset order
        'gpt+gpt': [],
    },
    'chunk_classification': {},
    'embedding_classification': {},
    'zero-shot_classification': {},
    'full_text_classification': {},
    'QA_classification': {},
}
# if __name__ == '__main__':
labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)
texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)
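# Sanity check (assumption: the zip-based loops below pair labels rows with
# texts one-to-one, so both collections should have equal length):
# assert len(df) == len(texts), 'labels/texts length mismatch'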
# print(len(df), len(texts))
# print(mean(list(map(len, texts))))
# summarization+classification
# for selected_summarizer in summarizers:
# print(selected_summarizer)
# # for selected_predictor in predictors:
# # predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor] = []
# for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
# print(year, label, company)
# # summary_filename = f'./texts/{year}_{company}_{selected_summarizer}_summary.txt'
# summary_filename = f'./texts/{company}_{year}_{selected_summarizer}_summary.txt'
# if os.path.isfile(summary_filename):
# print('Loading summary from the cache')
# with open(summary_filename, 'r') as f:
# summary = f.read()
# else:
# print(f'Making request to {selected_summarizer} to summarize {company}, {year}')
# text_chunks = split_text(text,
# chunk_size=summarizers[selected_summarizer]['chunk_size'],
# chunk_overlap=100)
# summary, summary_score = summarize(selected_summarizer, text_chunks)
# with open(summary_filename, 'w') as f:
# f.write(summary)
# print('-' * 50)
# # break
# # summary_chunks = split_text(summary, chunk_size=3_600)
# # predicted_label = get_label_prediction(selected_predictor, summary_chunks)
# # if predicted_label in labels:
# # predictions['summarization+classification'][selected_summarizer + '+' + selected_predictor].append(predicted_label)
# print()
# break
# chunk_classification
# for selected_predictor in predictors:
# predictions['chunk_classification'][selected_predictor] = []
# for text, (idx, (year, label, company)) in zip(texts, df.iterrows()):
# print(year, label, company)
# text_chunks = split_text(text, chunk_size=3600)
# predicted_label = get_label_prediction(selected_predictor, text_chunks)
# if predicted_label in labels:
# predictions['chunk_classification'][selected_predictor].append(predicted_label)
# print('-' * 50)
# with open(f'predictions/predictions_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.json', 'w') as json_file:
# json.dump(predictions, json_file, indent=4)
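# Sketch of scoring the collected predictions (an assumption, not part of the
# original pipeline; presumes df has a 'label' column aligned with each
# predictions list, which is how the loops above pair rows with texts):
# actual = list(df['label'])
# for method, models in predictions.items():
#     for model_name, predicted in models.items():
#         if predicted:
#             accuracy = mean(p == a for p, a in zip(predicted, actual))
#             print(f'{method} / {model_name}: accuracy = {accuracy:.2%}')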