FoldMark / scripts /colabfold_msa.py
Zaixi's picture
Add large file
89c0b51
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
@dataclass
class LocalColabFoldConfig:
    """Configuration for a local ColabFold search run.

    The first four fields become the positional arguments of the command
    assembled by ``run_colabfold_search``; the remaining fields feed its
    optional ``--db1``/``--db2``/``--db3``/``--mmseqs``/``--use-env``/
    ``--filter``/``--db-load-mode`` flags.
    """
    # Executable that performs the search (the CLI default is "colabfold_search").
    colabsearch: str
    # Query FASTA file handed to the search.
    query_fpath: str
    # Directory containing the search databases.
    db_dir: str
    # Directory the search writes into; it is globbed for "*.a3m" afterwards.
    results_dir: str
    # MMseqs2 binary; when None, the literal "mmseqs" is passed instead.
    mmseqs_path: Optional[str] = None
    # Database name forwarded as --db1.
    db1: str = "uniref30_2302_db"
    # Forwarded as --db2 when set (the CLI help describes it as the templates database).
    db2: Optional[str] = None
    # Environmental database name forwarded as --db3.
    db3: Optional[str] = "colabfold_envdb_202108_db"
    # Integer switches forwarded as --use-env, --filter and --db-load-mode.
    use_env: int = 1
    filter: int = 1
    db_load_mode: int = 0
class A3MProcessor:
    """Split an A3M search result into per-chain MSA files.

    Every chain gets a directory ``<out_dir>/msa/<index>/`` holding
    ``non_pairing.a3m`` and ``pairing.a3m``; in both files the query header is
    rewritten to ``>query``.

    Two input layouts are recognised by ``_parse_header``:

    * The first line starts with ``#`` and has the form
      ``#<len_1>,<len_2>,...<TAB><n_1>,<n_2>,...``. Chains are named
      ``"101"``, ``"102"``, ... and ``chain_info`` is a
      ``(chain_names, seq_ranges)`` tuple; call :meth:`split_sequences` to
      write the files.
    * Anything else is treated as a single chain: the constructor writes
      ``msa/0/`` immediately and ``chain_info`` is ``[None]``.
    """

    def __init__(self, a3m_file: str, out_dir: str):
        self.out_dir = out_dir
        self.a3m_file = Path(a3m_file)
        self.a3m_content = self._read_a3m_file()
        # NOTE: for single-chain input this call also writes the output files.
        self.chain_info = self._parse_header()

    def _read_a3m_file(self) -> str:
        """Return the whole A3M file as text."""
        return self.a3m_file.read_text()

    def _parse_header(
        self,
    ) -> Union[Tuple[List[str], Dict[str, Tuple[int, int]]], List[None]]:
        """Derive chain names and column ranges from the first A3M line.

        Returns:
            ``(chain_names, seq_ranges)`` for a ``#``-header file, where
            ``seq_ranges[name]`` is the half-open ``(start, end)`` range of
            that chain's match columns in the concatenated query. For any
            other file the single-chain outputs are written as a side effect
            and ``[None]`` is returned.

        Raises:
            ValueError: if the file is empty, has no query sequence, or the
                ``#`` header lists a different number of lengths and
                oligomeric-state entries.
        """
        if not self.a3m_content.strip():
            raise ValueError(f"A3M file {self.a3m_file} is empty")
        lines = self.a3m_content.split("\n")
        first_line = lines[0]
        if not first_line.startswith("#"):
            self._write_single_chain(lines)
            return [None]

        lengths, oligomeric_state = first_line.split("\t")
        chain_lengths = [int(x) for x in lengths[1:].split(",")]
        n_chains = len(oligomeric_state.split(","))
        if n_chains != len(chain_lengths):
            raise ValueError(
                f"Malformed A3M header in {self.a3m_file}: "
                f"{len(chain_lengths)} lengths but {n_chains} oligomeric-state entries"
            )
        # Labels 101, 102, ... match the per-chain query headers in the body.
        # Computed arithmetically so the tenth chain is "110"; the previous
        # f"10{x+1}" form produced "1010".
        chain_names = [str(101 + i) for i in range(n_chains)]

        # Consecutive, non-overlapping column ranges, one per chain.
        seq_ranges = {}
        start = 0
        for name, length in zip(chain_names, chain_lengths):
            seq_ranges[name] = (start, start + length)
            start += length
        return chain_names, seq_ranges

    def _write_single_chain(self, lines: List[str]) -> None:
        """Write ``msa/0/{non_pairing,pairing}.a3m`` for a header-less A3M.

        The first line (the query header) is replaced by ``>query``; the
        pairing file contains the query sequence only.
        """
        if len(lines) < 2:
            raise ValueError(f"A3M file {self.a3m_file} has no query sequence")
        query_seq = lines[1]
        chain_dir = Path(self.out_dir) / "msa" / "0"
        chain_dir.mkdir(parents=True, exist_ok=True)
        with open(chain_dir / "non_pairing.a3m", "w") as f:
            f.write(">query\n" + "\n".join(lines[1:]))
        with open(chain_dir / "pairing.a3m", "w") as f:
            f.write(f">query\n{query_seq}")

    def _extract_sequence(self, line: str, range_tuple: Tuple[int, int]) -> str:
        """Return the slice of an aligned row that belongs to one chain.

        Uppercase letters and ``-`` are match columns and advance the column
        counter; other characters (lowercase insertions) do not. A character
        is kept while the counter lies in ``(start, end]``. Insertions that
        follow the chain's last match column are therefore kept. Insertions
        before its first match column are not kept.
        """
        seq = []
        no_insert_count = 0
        start, end = range_tuple
        for char in line:
            if char.isupper() or char == "-":
                no_insert_count += 1
            # we keep insertions
            if start < no_insert_count <= end:
                seq.append(char)
            elif no_insert_count > end:
                break
        return "".join(seq)

    def split_sequences(self) -> None:
        """Split a ``#``-header A3M into per-chain pairing / non-pairing MSAs.

        Rows are routed by the most recent header:

        * a header equal to a chain name starts that chain's unpaired block;
        * the tab-joined header of all chain names starts the paired block;
        * any other header stays in the current block.

        Paired rows are sliced for every chain; unpaired rows only for the
        current chain. Only valid when ``chain_info`` is a tuple.
        """
        out_dir = Path(self.out_dir) / "msa"
        chain_names, seq_ranges = self.chain_info
        pairing_a3ms = {name: [] for name in chain_names}
        nonpairing_a3ms = {name: [] for name in chain_names}
        current_query = None  # None means "inside the paired block"
        for line in self.a3m_content.split("\n"):
            if line.startswith("#"):
                continue
            if line.startswith(">"):
                name = line[1:]
                if name in chain_names:
                    current_query = name
                elif name == "\t".join(chain_names):
                    current_query = None
                # Add header line to appropriate dictionary
                if current_query is not None:
                    nonpairing_a3ms[current_query].append(line)
                else:
                    for chain in chain_names:
                        pairing_a3ms[chain].append(line)
                continue
            if not line:
                continue
            if current_query is not None:
                seq = self._extract_sequence(line, seq_ranges[current_query])
                nonpairing_a3ms[current_query].append(seq)
            else:
                for chain in chain_names:
                    seq = self._extract_sequence(line, seq_ranges[chain])
                    pairing_a3ms[chain].append(seq)
        self._write_output_files(out_dir, nonpairing_a3ms, pairing_a3ms)

    @staticmethod
    def _collect_paired_hits(lines: List[str], chain_index: int) -> Dict[str, str]:
        """Map renamed paired-hit headers to this chain's sequence slice.

        Each header keeps its ``chain_index``-th whitespace-separated token,
        prefixed with ``UniRef100_`` and suffixed with the row position so
        names stay unique. Hits whose name contains ``DUMMY`` keep an empty
        sequence, so the caller skips them.
        """
        sequences: Dict[str, str] = {}
        current_name: Optional[str] = None
        for j, line in enumerate(lines):
            if line.startswith(">"):
                current_name = f"UniRef100_{line[1:].split()[chain_index]}_{j}"
                sequences[current_name] = ""
            elif line and current_name is not None and "DUMMY" not in current_name:
                sequences[current_name] = line
        return sequences

    def _write_output_files(
        self,
        out_dir: Path,
        nonpairing_a3ms: Dict[str, List[str]],
        pairing_a3ms: Dict[str, List[str]],
    ) -> None:
        """Write ``<out_dir>/<i>/{non_pairing,pairing}.a3m`` for each chain.

        Both lists per chain are expected to be ``[query_header, query_seq,
        hit_header, hit_seq, ...]``. If a chain has no paired block (for
        example a single unique chain), its pairing file holds only the query
        taken from the non-pairing block.

        Raises:
            ValueError: if a chain has no non-pairing query sequence.
        """
        out_dir.mkdir(parents=True, exist_ok=True)

        query_by_chain: Dict[str, str] = {}
        for i, (name, lines) in enumerate(nonpairing_a3ms.items()):
            if len(lines) < 2:
                raise ValueError(f"No unpaired query sequence found for chain {name}")
            query_seq = lines[1]
            query_by_chain[name] = query_seq
            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)
            with open(chain_dir / "non_pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")
                f.write("\n".join(lines[2:]))

        for i, (name, lines) in enumerate(pairing_a3ms.items()):
            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)
            # Without a paired block the list is empty; fall back to the query.
            query_seq = lines[1] if len(lines) >= 2 else query_by_chain[name]
            hits = self._collect_paired_hits(lines[2:], i)
            with open(chain_dir / "pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")
                for seq_name, seq in hits.items():
                    if seq:  # Only write non-empty sequences
                        f.write(f">{seq_name}\n{seq}\n")
def run_colabfold_search(config: LocalColabFoldConfig) -> str:
    """Run the ColabFold search described by *config* and return one A3M path.

    The command is executed directly as an argument list (no shell), so paths
    containing spaces or shell metacharacters are passed through verbatim.

    Args:
        config: Search configuration. ``colabsearch`` is the executable;
            ``query_fpath``, ``db_dir`` and ``results_dir`` are its positional
            arguments. Database names are forwarded only when set. The integer
            switches are forwarded whenever they are not ``None``.

    Returns:
        Path (as ``str``) of the lexicographically first ``*.a3m`` file in
        ``config.results_dir``. Sorting makes the choice deterministic when
        several result files exist.

    Raises:
        subprocess.CalledProcessError: if the search exits with a non-zero
            status. The results directory is not inspected in that case, so a
            stale ``.a3m`` from an earlier run is never returned.
        FileNotFoundError: if the search produced no ``.a3m`` file.
    """
    cmd = [config.colabsearch, config.query_fpath, config.db_dir, config.results_dir]

    for flag, db_name in (
        ("--db1", config.db1),
        ("--db2", config.db2),
        ("--db3", config.db3),
    ):
        if db_name:
            cmd.extend([flag, db_name])

    cmd.extend(["--mmseqs", config.mmseqs_path or "mmseqs"])

    # Compare against None rather than truthiness: the previous check dropped
    # use_env=0 / filter=0 entirely, so a 0 could never reach the search tool.
    for flag, switch in (
        ("--use-env", config.use_env),
        ("--filter", config.filter),
        ("--db-load-mode", config.db_load_mode),
    ):
        if switch is not None:
            cmd.extend([flag, str(switch)])

    subprocess.run(cmd, check=True)

    result_files = sorted(Path(config.results_dir).glob("*.a3m"))
    if not result_files:
        raise FileNotFoundError(f"No .a3m files found in {config.results_dir}")
    return str(result_files[0])
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace. Defaults mirror ``LocalColabFoldConfig``, so
        building the config from it does not override the dataclass defaults.
    """
    parser = argparse.ArgumentParser(
        description="ColabFold search and A3M processing tool",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required arguments
    parser.add_argument("query_fpath", help="Path to the query FASTA file")
    parser.add_argument("db_dir", help="Directory containing the databases")
    parser.add_argument("results_dir", help="Directory for storing results")

    # Optional arguments
    parser.add_argument(
        "--colabsearch", help="Path to colabfold_search", default="colabfold_search"
    )
    parser.add_argument(
        "--mmseqs_path", help="Path to MMseqs2 binary", default="mmseqs"
    )
    parser.add_argument("--db1", help="First database name", default="uniref30_2302_db")
    parser.add_argument("--db2", help="Templates database")
    # The default used to be None while the help text claimed otherwise (and the
    # formatter then printed "(default: None)"); keep it in sync with the
    # dataclass default instead.
    parser.add_argument(
        "--db3",
        help="Environmental database",
        default="colabfold_envdb_202108_db",
    )
    parser.add_argument(
        "--use_env", help="Use environment settings", type=int, default=1
    )
    parser.add_argument("--filter", help="Apply filtering", type=int, default=1)
    parser.add_argument(
        "--db_load_mode", help="Database load mode", type=int, default=0
    )
    parser.add_argument(
        "--output_split", help="Directory for split A3M files", default=None
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()

    # Create configuration from arguments
    config = LocalColabFoldConfig(
        colabsearch=args.colabsearch,
        query_fpath=args.query_fpath,
        db_dir=args.db_dir,
        results_dir=args.results_dir,
        mmseqs_path=args.mmseqs_path,
        db1=args.db1,
        db2=args.db2,
        db3=args.db3,
        use_env=args.use_env,
        filter=args.filter,
        db_load_mode=args.db_load_mode,
    )

    # Run search; run_colabfold_search returns a single .a3m path and only
    # that file is post-processed.
    results_a3m = run_colabfold_search(config)

    # --output_split used to be parsed but ignored. Honour it, and fall back
    # to the results directory (the previous behaviour) when it is not given.
    split_dir = args.output_split or args.results_dir
    Path(split_dir).mkdir(parents=True, exist_ok=True)

    processor = A3MProcessor(results_a3m, split_dir)
    # A '#'-header (multi-chain) file yields a (chain_names, seq_ranges)
    # tuple and needs an explicit split. Single-chain files are written by the
    # constructor and yield [None].
    if isinstance(processor.chain_info, tuple):
        processor.split_sequences()