"""Watermark-detection driver.

Prepares a dataset directory with ``run_gen_data`` (falling back to the
distillation pipeline on failure) and — in the currently commented-out main
section — runs Protenix inference over it to detect watermarking.
"""

import argparse
import csv
import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from tqdm import tqdm

from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from protenix.config import parse_configs
from protenix.data.data_pipeline import DataPipeline
from protenix.data.dataloader import KeySumBalancedSampler
from protenix.data.dataset import BaseSingleDataset
from protenix.utils.file_io import dump_gzip_pickle
from runner.inference import (
    InferenceRunner,
    download_infercence_cache,
    infer_detect,
    update_inference_configs,
)
from scripts.prepare_training_data import run_gen_data


def process_data(path: str) -> bool:
    """Generate indices CSV and bioassembly data for the dataset at *path*.

    First attempts the standard (non-distillation) pipeline; if that raises,
    retries with ``distillation=True``.

    Args:
        path: Directory containing the raw inputs. Outputs (including the
            ``output.csv`` indices file) are written into the same directory.

    Returns:
        ``False`` if the standard pipeline succeeded, ``True`` if the
        distillation fallback was used.
    """
    try:
        run_gen_data(
            input_path=path,
            output_indices_csv=os.path.join(path, 'output.csv'),
            bioassembly_output_dir=path,
            cluster_file=None,
            distillation=False,
            num_workers=1,
        )
        return False
    except Exception:
        # Fix: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt and hid the actual failure. Narrow to Exception
        # and record the traceback before falling back.
        logging.exception(
            "Standard data generation failed; retrying with distillation"
        )
        print('Use Distillation')
        run_gen_data(
            input_path=path,
            output_indices_csv=os.path.join(path, 'output.csv'),
            bioassembly_output_dir=path,
            cluster_file=None,
            distillation=True,
            num_workers=1,
        )
        return True


# --- Disabled main routine (kept for reference; enable by uncommenting) ---
# logger = logging.getLogger(__name__)
# LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
# logging.basicConfig(
#     format=LOG_FORMAT,
#     level=logging.INFO,
#     datefmt="%Y-%m-%d %H:%M:%S",
#     filemode="w",
# )
# configs_base["use_deepspeed_evo_attention"] = (
#     os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
# )
# arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
# configs = {**configs_base, **{"data": data_configs}, **inference_configs}
# configs = parse_configs(
#     configs=configs,
#     arg_str=arg_str,
#     fill_required_with_null=True,
# )
# configs.load_checkpoint_path = '/n/netscratch/mzitnik_lab/Lab/zzx/output/protenix_new_finetune_20250202_032321/checkpoints/599.pt'
# download_infercence_cache(configs, model_version="v0.2.0")
# configs.process_success = process_data('./dataset')
# configs.subdir = './dataset'
# runner = InferenceRunner(configs)
# result = infer_detect(runner, configs)
# if result == False:
#     print("Not Watermarked")
# else:
#     print("Watermarked")
# print('Completed')