# FoldMark / process_data.py
# (web-viewer artifact preserved as a comment: author "Zaixi",
#  commit message "Add large file", revision 89c0b51)
import argparse
import csv
from pathlib import Path
from typing import Optional
import logging
import gradio as gr
import os
import uuid
from datetime import datetime
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from protenix.data.data_pipeline import DataPipeline
from protenix.utils.file_io import dump_gzip_pickle
from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from protenix.config import parse_configs
from protenix.data.dataloader import KeySumBalancedSampler
from protenix.data.dataset import BaseSingleDataset
from runner.inference import download_infercence_cache, update_inference_configs, infer_detect, InferenceRunner
from scripts.prepare_training_data import run_gen_data
from torch.utils.data import DataLoader
def process_data(path):
    """Generate training data under ``path``, falling back to distillation mode.

    First runs :func:`run_gen_data` in normal (non-distillation) mode; if that
    raises for any reason, retries the same run with ``distillation=True``.

    Args:
        path: Directory holding the input data. It is also used as the
            bioassembly output directory, and the indices CSV is written to
            ``<path>/output.csv``.

    Returns:
        bool: ``False`` if the normal run succeeded, ``True`` if the
        distillation fallback was used.

    Raises:
        Exception: Whatever the distillation-mode ``run_gen_data`` raises if
            the fallback itself fails (the first failure is not re-raised).
    """
    # Both attempts write the same indices CSV; compute the path once.
    output_csv = os.path.join(path, 'output.csv')
    try:
        run_gen_data(
            input_path=path,
            output_indices_csv=output_csv,
            bioassembly_output_dir=path,
            cluster_file=None,
            distillation=False,
            num_workers=1,
        )
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; `except Exception` keeps the fallback behavior while
        # letting those propagate.
        print('Use Distillation')
        run_gen_data(
            input_path=path,
            output_indices_csv=output_csv,
            bioassembly_output_dir=path,
            cluster_file=None,
            distillation=True,
            num_workers=1,
        )
        return True
    else:
        return False
# logger = logging.getLogger(__name__)
# LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
# logging.basicConfig(
# format=LOG_FORMAT,
# level=logging.INFO,
# datefmt="%Y-%m-%d %H:%M:%S",
# filemode="w",
# )
# configs_base["use_deepspeed_evo_attention"] = (
# os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
# )
# arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
# configs = {**configs_base, **{"data": data_configs}, **inference_configs}
# configs = parse_configs(
# configs=configs,
# arg_str=arg_str,
# fill_required_with_null=True,
# )
# configs.load_checkpoint_path = '/n/netscratch/mzitnik_lab/Lab/zzx/output/protenix_new_finetune_20250202_032321/checkpoints/599.pt'
# download_infercence_cache(configs, model_version="v0.2.0")
# configs.process_success = process_data('./dataset')
# configs.subdir = './dataset'
# runner = InferenceRunner(configs)
# result = infer_detect(runner, configs)
# if result==False:
# print("Not Watermarked")
# else:
# print("Watermarked")
# print('Completed')