Zaixi committed on
Commit
287a06f
·
1 Parent(s): c6eeef3
configs/configs_data.py CHANGED
@@ -60,7 +60,7 @@ default_weighted_pdb_configs = {
60
  "shuffle_sym_ids": GlobalConfigValue("train_shuffle_sym_ids"),
61
  }
62
 
63
- DATA_ROOT_DIR = "./"
64
 
65
  # Use CCD cache created by scripts/gen_ccd_cache.py priority. (without date in filename)
66
  # See: docs/prepare_data.md
 
60
  "shuffle_sym_ids": GlobalConfigValue("train_shuffle_sym_ids"),
61
  }
62
 
63
+ DATA_ROOT_DIR = "./release_data/ccd_cache"
64
 
65
  # Use CCD cache created by scripts/gen_ccd_cache.py priority. (without date in filename)
66
  # See: docs/prepare_data.md
protenix/data/data_pipeline.py CHANGED
@@ -57,41 +57,41 @@ class DataPipeline(object):
57
  sample_indices_list (list[dict[str, Any]]): The sample indices list (each one is a chain or an interface).
58
  bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array, and token_array.
59
  """
60
- try:
61
- if dataset == "WeightedPDB":
62
- parser = MMCIFParser(mmcif_file=mmcif)
63
- bioassembly_dict = parser.get_bioassembly()
64
- elif dataset == "Distillation":
65
- parser = DistillationMMCIFParser(mmcif_file=mmcif)
66
- bioassembly_dict = parser.get_structure_dict()
67
- else:
68
- raise NotImplementedError(
69
- 'Unsupported "dataset", please input either "WeightedPDB" or "Distillation".'
70
- )
71
-
72
- sample_indices_list = parser.make_indices(
73
- bioassembly_dict=bioassembly_dict, pdb_cluster_file=pdb_cluster_file
74
  )
75
- if len(sample_indices_list) == 0:
76
- # empty indices and AtomArray
77
- return [], bioassembly_dict
78
 
79
- atom_array = bioassembly_dict["atom_array"]
80
- atom_array.set_annotation(
81
- "resolution", [parser.resolution] * len(atom_array)
82
- )
 
 
 
 
 
 
 
83
 
84
- tokenizer = AtomArrayTokenizer(atom_array)
85
- token_array = tokenizer.get_token_array()
86
- bioassembly_dict["msa_features"] = None
87
- bioassembly_dict["template_features"] = None
88
 
89
- bioassembly_dict["token_array"] = token_array
90
- return sample_indices_list, bioassembly_dict
91
 
92
- except Exception as e:
93
- logging.warning("Gen data failed for %s due to %s", mmcif, e)
94
- return [], {}
95
 
96
  @staticmethod
97
  def get_label_entity_id_to_asym_id_int(atom_array: AtomArray) -> dict[str, int]:
 
57
  sample_indices_list (list[dict[str, Any]]): The sample indices list (each one is a chain or an interface).
58
  bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array, and token_array.
59
  """
60
+ #try:
61
+ if dataset == "WeightedPDB":
62
+ parser = MMCIFParser(mmcif_file=mmcif)
63
+ bioassembly_dict = parser.get_bioassembly()
64
+ elif dataset == "Distillation":
65
+ parser = DistillationMMCIFParser(mmcif_file=mmcif)
66
+ bioassembly_dict = parser.get_structure_dict()
67
+ else:
68
+ raise NotImplementedError(
69
+ 'Unsupported "dataset", please input either "WeightedPDB" or "Distillation".'
 
 
 
 
70
  )
 
 
 
71
 
72
+ sample_indices_list = parser.make_indices(
73
+ bioassembly_dict=bioassembly_dict, pdb_cluster_file=pdb_cluster_file
74
+ )
75
+ if len(sample_indices_list) == 0:
76
+ # empty indices and AtomArray
77
+ return [], bioassembly_dict
78
+
79
+ atom_array = bioassembly_dict["atom_array"]
80
+ atom_array.set_annotation(
81
+ "resolution", [parser.resolution] * len(atom_array)
82
+ )
83
 
84
+ tokenizer = AtomArrayTokenizer(atom_array)
85
+ token_array = tokenizer.get_token_array()
86
+ bioassembly_dict["msa_features"] = None
87
+ bioassembly_dict["template_features"] = None
88
 
89
+ bioassembly_dict["token_array"] = token_array
90
+ return sample_indices_list, bioassembly_dict
91
 
92
+ # except Exception as e:
93
+ # logging.warning("Gen data failed for %s due to %s", mmcif, e)
94
+ # return [], {}
95
 
96
  @staticmethod
97
  def get_label_entity_id_to_asym_id_int(atom_array: AtomArray) -> dict[str, int]:
runner/inference.py CHANGED
@@ -208,12 +208,19 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
208
  os.makedirs(data_cache_dir, exist_ok=True)
209
  for cache_name, fname in [
210
  ("ccd_components_file", "components.v20240608.cif"),
211
- ("ccd_components_rdkit_mol_file", "components.v20240608.cif.rdkit_mol.pkl"),
212
  ]:
213
  if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))):
214
  tos_url = URL[cache_name]
215
  logger.info(f"Downloading data cache from\n {tos_url}...")
216
  urllib.request.urlretrieve(tos_url, cache_path)
 
 
 
 
 
 
 
 
217
 
218
  if not os.path.exists('./checkpoint.pt'):
219
  # Google Drive file ID
 
208
  os.makedirs(data_cache_dir, exist_ok=True)
209
  for cache_name, fname in [
210
  ("ccd_components_file", "components.v20240608.cif"),
 
211
  ]:
212
  if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))):
213
  tos_url = URL[cache_name]
214
  logger.info(f"Downloading data cache from\n {tos_url}...")
215
  urllib.request.urlretrieve(tos_url, cache_path)
216
+
217
+ if not os.path.exists('./release_data/ccd_cache/components.v20240608.cif.rdkit_mol.pkl'):
218
+ file_id = '1R9d678aBfQwTd0Rh15doRmW-fETNdeWf'
219
+ # Construct the download URL
220
+ download_url = f'https://drive.google.com/uc?id={file_id}'
221
+ # Specify the output file name
222
+ output_file = './release_data/ccd_cache/components.v20240608.cif.rdkit_mol.pkl'
223
+ gdown.download(download_url, output_file, quiet=False)
224
 
225
  if not os.path.exists('./checkpoint.pt'):
226
  # Google Drive file ID