import os
import zipfile
import pickle
from glob import glob
from pathlib import Path

import pandas as pd
import gradio as gr

from indexrl.training import (
    DynamicBuffer,
    create_model,
    save_model,
    explore,
    train_iter,
)
from indexrl.environment import IndexRLEnv
from indexrl.utils import get_n_channels, state_to_expression
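
# Global configuration. max_exp_len caps the token length of a generated
# expression; data_dir is the root under which datasets, logs, models, and
# caches live. global_logs_dir is re-pointed at the selected dataset's log
# directory inside find_expression().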
max_exp_len = 12
data_dir = "data/"
global_logs_dir = os.path.join(data_dir, "logs")

os.makedirs(data_dir, exist_ok=True)
meta_data_file = os.path.join(data_dir, "metadata.csv")
if not os.path.exists(meta_data_file):
    with open(meta_data_file, "w") as fp:
        fp.write("Name,Channels,Path\n")
def save_dataset(name, zip_file):
    with zipfile.ZipFile(zip_file.name, "r") as zip_ref:
        data_path = os.path.join(data_dir, name)
        zip_ref.extractall(data_path)
    # Infer the channel count from the first image array in the dataset.
    img_path = glob(os.path.join(data_path, "images", "*.npy"))[0]
    n_channels = get_n_channels(img_path)
    with open(meta_data_file, "a") as fp:
        fp.write(f"{name},{n_channels},{data_path}\n")
    meta_data_df = pd.read_csv(meta_data_file)
    return meta_data_df, gr.Dropdown.update(choices=meta_data_df["Name"].to_list())
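

# Read the expression-tree log for a given exploration iteration and tree
# index. The tree_<iteration>_<n>.txt files are assumed to be written by
# indexrl.training.explore() into the active dataset's logs directory
# (explore() receives logs_dir and a matching tree_prefix below).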
def get_tree(exp_num: int = 1, tree_num: int = 1):
    tree_num = max(tree_num, 1)
    tree_path = os.path.join(
        global_logs_dir, f"tree_{int(exp_num)}_{int(tree_num)}.txt"
    )
    if os.path.exists(tree_path):
        with open(tree_path, "r", encoding="utf-8") as fp:
            return fp.read()
    print(f"Tree at {tree_path} not found!")
    return ""
def change_expression(exp_num: int = 1, tree_num: int = 1):
    try:
        paths = glob(os.path.join(global_logs_dir, f"tree_{int(exp_num)}_*.txt"))
    except TypeError:
        return "", gr.Slider.update()
    tree_num = max(min(len(paths), tree_num), 1)
    tree = get_tree(exp_num, tree_num)
    return tree, gr.Slider.update(value=tree_num, maximum=len(paths), interactive=True)
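

# Main search loop, written as a generator so Gradio can stream intermediate
# results. For the chosen dataset this builds an IndexRLEnv over its channel
# tokens, then alternates explore/train iterations indefinitely (until the
# Stop button cancels the event), yielding the current top-5 expressions and
# an updated iteration slider after every pass.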
def find_expression(dataset_name: str):
    if dataset_name == "":
        # This function is a generator (it yields below), so the early exit
        # must yield its outputs; a bare `return value` in a generator never
        # reaches the UI.
        yield "", gr.Slider.update(value=1, interactive=False)
        return
    global global_logs_dir
    meta_data_df = pd.read_csv(meta_data_file, index_col="Name")
    n_channels = meta_data_df["Channels"][dataset_name]
    dataset_dir = meta_data_df["Path"][dataset_name]
    image_dir = os.path.join(dataset_dir, "images")
    mask_dir = os.path.join(dataset_dir, "masks")
    cache_dir = os.path.join(dataset_dir, "cache")
    global_logs_dir = logs_dir = os.path.join(dataset_dir, "logs")
    models_dir = os.path.join(dataset_dir, "models")
    for dir_name in (cache_dir, logs_dir, models_dir):
        Path(dir_name).mkdir(parents=True, exist_ok=True)

    # Token vocabulary: parentheses, arithmetic operators, square and square
    # root, plus one token per image channel (c0, c1, ...).
    action_list = (
        list("()+-*/=") + ["sq", "sqrt"] + [f"c{c}" for c in range(n_channels)]
    )
    env = IndexRLEnv(action_list, max_exp_len)
    agent, optimizer = create_model(len(action_list))
    seen_path = os.path.join(cache_dir, "seen.pkl") if cache_dir else ""
    env.save_seen(seen_path)
    data_buffer = DynamicBuffer()

    i = 0
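    # Each iteration: explore with the current agent to collect scored
    # expression states, grow the replay buffer, run one training step, and
    # checkpoint the model and buffer before yielding the current best
    # expressions.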
    while True:
        i += 1
        print(f"----------------\nIteration {i}")

        print("Collecting data...")
        data = explore(
            env.copy(),
            agent,
            image_dir,
            mask_dir,
            1,
            logs_dir,
            seen_path,
            tree_prefix=f"tree_{i}",
            n_iters=1000,
        )
        print(
            f"Data collection done. Collected {len(data)} examples. "
            f"Buffer size = {len(data_buffer)}."
        )
        data_buffer.add_data(data)
        print(f"New buffer size = {len(data_buffer)}.")

        agent, optimizer, loss = train_iter(agent, optimizer, data_buffer)
        print("Loss:", loss)

        i_str = str(i).rjust(3, "0")
        if models_dir:
            save_model(agent, f"{models_dir}/model_{i_str}_loss-{loss}.pt")
        if cache_dir:
            with open(f"{cache_dir}/data_buffer_{i_str}.pkl", "wb") as fp:
                pickle.dump(data_buffer, fp)

        top_5 = data_buffer.get_top_n(5)
        top_5_str = "\n".join(
            " ".join(state_to_expression(entry[0], action_list)) + " " + str(entry[1])
            for entry in top_5
        )
        yield top_5_str, gr.Slider.update(value=i, maximum=i, interactive=True)
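

# Gradio UI: a "Find Expressions" tab to run and monitor the search, and a
# "Datasets" tab to upload and register new datasets.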
with gr.Blocks(title="IndexRL") as demo:
    gr.Markdown("# IndexRL")
    meta_data_df = pd.read_csv(meta_data_file)

    with gr.Tab("Find Expressions"):
        with gr.Row():
            with gr.Column():
                select_dataset = gr.Dropdown(
                    label="Select Dataset",
                    choices=meta_data_df["Name"].to_list(),
                )
                find_exp_btn = gr.Button("Find Expressions", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                best_exps = gr.Textbox(label="Best Expressions", interactive=False)
            with gr.Column():
                select_exp = gr.Slider(
                    value=1, label="Iteration", interactive=False, minimum=1, step=1
                )
                select_tree = gr.Slider(
                    value=1, label="Tree Number", interactive=False, minimum=1, step=1
                )
                out_exp_tree = gr.Textbox(
                    label="Latest Expression Tree", interactive=False
                )

    with gr.Tab("Datasets"):
        dataset_upload = gr.File(label="Upload Data ZIP file")
        dataset_name = gr.Textbox(label="Dataset Name")
        dataset_upload_btn = gr.Button("Upload")
        dataset_table = gr.Dataframe(meta_data_df, label="Dataset Table")
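
    # Event wiring. The search runs as a cancellable generator event that the
    # Stop button cancels; slider changes re-render the selected expression tree.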
    find_exp_event = find_exp_btn.click(
        find_expression,
        inputs=[select_dataset],
        outputs=[best_exps, select_exp],
    )
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[find_exp_event])
    select_exp.change(
        change_expression,
        inputs=[select_exp, select_tree],
        outputs=[out_exp_tree, select_tree],
    )
    select_tree.change(
        get_tree,
        inputs=[select_exp, select_tree],
        outputs=out_exp_tree,
    )
    # Pre-fill the dataset name from the uploaded file's name, minus its extension.
    dataset_upload.upload(
        lambda x: ".".join(os.path.basename(x.orig_name).split(".")[:-1]),
        inputs=dataset_upload,
        outputs=dataset_name,
    )
    dataset_upload_btn.click(
        save_dataset,
        inputs=[dataset_name, dataset_upload],
        outputs=[dataset_table, select_dataset],
    )

demo.queue(concurrency_count=10).launch(debug=True)