|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
@dataclass
class LocalColabFoldConfig:
    """Settings for one local ColabFold search invocation.

    The first four fields are required and are placed, in this order, at
    the front of the search command line; the remaining fields are
    optional and are forwarded as command-line flags.

    Attributes:
        colabsearch: Command used to launch the search program.
        query_fpath: Path of the query file handed to the search program.
        db_dir: Directory handed to the search program as database root.
        results_dir: Directory in which the resulting ``.a3m`` files are
            looked up after the search.
        mmseqs_path: Value for ``--mmseqs``; ``"mmseqs"`` is used when unset.
        db1: Value for ``--db1``.
        db2: Value for ``--db2`` (omitted when unset).
        db3: Value for ``--db3`` (omitted when unset).
        use_env: Integer value for ``--use-env``.
        filter: Integer value for ``--filter``.
        db_load_mode: Integer value for ``--db-load-mode``.
    """

    # Required positional settings.
    colabsearch: str
    query_fpath: str
    db_dir: str
    results_dir: str

    # Optional settings with their defaults.
    mmseqs_path: Optional[str] = None
    db1: str = "uniref30_2302_db"
    db2: Optional[str] = None
    db3: Optional[str] = "colabfold_envdb_202108_db"
    use_env: int = 1
    filter: int = 1
    db_load_mode: int = 0
|
|
|
|
|
class A3MProcessor:
    """Split an A3M search result into per-chain MSA files.

    On construction the file is read and its first line is inspected:

    * If it starts with ``#`` it is treated as a multi-chain header of the
      form ``#<len>,<len>,...<TAB><count>,<count>,...``.  ``chain_info``
      then holds ``(chain_names, seq_ranges)`` and :meth:`split_sequences`
      must be called to write the per-chain files.
    * Otherwise the file is treated as a single-chain MSA; the output
      files are written immediately and ``chain_info`` is ``[None]``.

    Output layout: ``<out_dir>/msa/<chain index>/non_pairing.a3m`` and
    ``<out_dir>/msa/<chain index>/pairing.a3m``.
    """

    def __init__(self, a3m_file: str, out_dir: str):
        """Read ``a3m_file`` and parse its header.

        Args:
            a3m_file: Path of the A3M file to process.
            out_dir: Directory under which the ``msa`` folder is created.
        """
        self.out_dir = out_dir
        self.a3m_file = Path(a3m_file)
        self.a3m_content = self._read_a3m_file()
        self.chain_info = self._parse_header()

    def _read_a3m_file(self) -> str:
        """Return the full text of the A3M file."""
        return self.a3m_file.read_text()

    def _parse_header(self) -> Tuple[List[str], Dict[str, Tuple[int, int]]]:
        """Parse the first line of the A3M content.

        Returns:
            ``(chain_names, seq_ranges)`` for a ``#`` header, where
            ``seq_ranges[name]`` is the half-open column range
            ``(start, end)`` of that chain in the concatenated alignment.
            For a file without a ``#`` header the single-chain output
            files are written as a side effect and ``[None]`` is returned
            (callers distinguish the two cases by the length of the
            result).

        Raises:
            ValueError: If a header-less file has no query sequence line.
        """
        lines = self.a3m_content.split("\n")
        first_line = lines[0]

        if first_line.startswith("#"):
            lengths, oligomeric_state = first_line.split("\t")
            chain_lengths = [int(x) for x in lengths[1:].split(",")]
            # Chains are numbered consecutively from 101 (101, 102, ...,
            # 109, 110, ...).  Building the name arithmetically keeps the
            # 10th and later chains correct ("110", not "1010").
            chain_names = [
                str(101 + x) for x in range(len(oligomeric_state.split(",")))
            ]

            seq_ranges = {}
            for i, name in enumerate(chain_names):
                seq_ranges[name] = (
                    sum(chain_lengths[:i]),
                    sum(chain_lengths[: i + 1]),
                )
            return chain_names, seq_ranges

        # Single-chain file: line 0 is the query header, line 1 the query.
        if len(lines) < 2:
            raise ValueError(
                f"{self.a3m_file} does not contain a query sequence"
            )
        non_pairing = ">query\n" + "\n".join(lines[1:])
        pairing = f">query\n{lines[1]}"

        msa_path = Path(self.out_dir) / "msa" / "0"
        # parents=True so a not-yet-existing out_dir does not abort the run.
        msa_path.mkdir(parents=True, exist_ok=True)
        with open(msa_path / "non_pairing.a3m", "w") as f:
            f.write(non_pairing)
        with open(msa_path / "pairing.a3m", "w") as f:
            f.write(pairing)

        return [None]

    def _extract_sequence(self, line: str, range_tuple: Tuple[int, int]) -> str:
        """Cut one chain's columns out of a concatenated alignment row.

        Upper-case letters and ``-`` advance the column counter; any other
        character (e.g. a lower-case insertion) does not, and is kept only
        while the counter is inside the requested range, i.e. it stays
        attached to the column that precedes it.

        Args:
            line: One alignment row spanning all chains.
            range_tuple: Half-open column range ``(start, end)``.

        Returns:
            The characters of ``line`` belonging to that range.
        """
        seq = []
        no_insert_count = 0
        start, end = range_tuple

        for char in line:
            if char.isupper() or char == "-":
                no_insert_count += 1

            if start < no_insert_count <= end:
                seq.append(char)
            elif no_insert_count > end:
                # Past the chain's last column: nothing more to collect.
                break

        return "".join(seq)

    def split_sequences(self) -> None:
        """Split the A3M content into pairing / non-pairing files per chain.

        Rows following a header that names all chains (names joined by
        tabs) go to every chain's pairing list; rows following a header
        that equals a single chain name go to that chain's non-pairing
        list.  Other headers keep the current section.  Does nothing when
        the file had no ``#`` header, because the output was already
        written while parsing it.
        """
        if len(self.chain_info) != 2:
            return

        out_dir = Path(self.out_dir) / "msa"
        chain_names, seq_ranges = self.chain_info
        paired_header = "\t".join(chain_names)

        pairing_a3ms = {name: [] for name in chain_names}
        nonpairing_a3ms = {name: [] for name in chain_names}

        current_query = None  # None => rows belong to the paired section
        for line in self.a3m_content.split("\n"):
            if not line or line.startswith("#"):
                continue

            is_header = line.startswith(">")
            if is_header:
                name = line[1:]
                if name in chain_names:
                    current_query = name
                elif name == paired_header:
                    current_query = None

            if current_query is not None:
                targets = {current_query: nonpairing_a3ms[current_query]}
            else:
                targets = pairing_a3ms

            for name, bucket in targets.items():
                if is_header:
                    bucket.append(line)
                else:
                    bucket.append(
                        self._extract_sequence(line, seq_ranges[name])
                    )

        self._write_output_files(out_dir, nonpairing_a3ms, pairing_a3ms)

    @staticmethod
    def _query_sequence(name: str, primary: List[str], fallback: List[str]) -> str:
        """Return the query row (index 1) from ``primary``, else ``fallback``."""
        for lines in (primary, fallback):
            if len(lines) > 1:
                return lines[1]
        raise ValueError(f"No query sequence found for chain {name}")

    def _write_output_files(
        self,
        out_dir: Path,
        nonpairing_a3ms: Dict[str, List[str]],
        pairing_a3ms: Dict[str, List[str]],
    ) -> None:
        """Write ``non_pairing.a3m`` and ``pairing.a3m`` for every chain.

        Each per-chain list is expected to be ``[header, query, header,
        row, ...]``.  When one of the two lists of a chain is empty (for
        instance a file whose header declares a single chain has no paired
        section) the query is taken from the other list so that a
        query-only file is still produced.

        Args:
            out_dir: The ``msa`` directory; chain ``i`` goes to ``out_dir/i``.
            nonpairing_a3ms: Per-chain lines of the unpaired sections.
            pairing_a3ms: Per-chain lines of the paired section.

        Raises:
            ValueError: If a chain has no query sequence in either list.
        """
        out_dir.mkdir(parents=True, exist_ok=True)

        for i, (name, lines) in enumerate(nonpairing_a3ms.items()):
            query_seq = self._query_sequence(
                name, lines, pairing_a3ms.get(name, [])
            )
            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)

            with open(chain_dir / "non_pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")
                f.write("\n".join(lines[2:]))

        for i, (name, lines) in enumerate(pairing_a3ms.items()):
            query_seq = self._query_sequence(
                name, lines, nonpairing_a3ms.get(name, [])
            )

            # Collect the hits before touching the file so that a malformed
            # header cannot leave a half-written output behind.
            sequences = {}
            current_name = None
            for j, line in enumerate(lines[2:]):
                if line.startswith(">"):
                    # The i-th whitespace-separated token of a paired header
                    # is this chain's hit; j keeps the names unique.
                    current_name = f"UniRef100_{line[1:].split()[i]}_{j}"
                    sequences[current_name] = ""
                elif (
                    line
                    and current_name is not None
                    and "DUMMY" not in current_name
                ):
                    sequences[current_name] = line

            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)

            with open(chain_dir / "pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")
                for seq_name, seq in sequences.items():
                    if seq:
                        f.write(f">{seq_name}\n{seq}\n")
|
|
|
|
|
def run_colabfold_search(config: "LocalColabFoldConfig") -> str:
    """Run the ColabFold search command and return the resulting A3M path.

    The command line is ``<colabsearch> <query_fpath> <db_dir>
    <results_dir>`` followed by the optional flags.  It is executed
    without a shell, so paths containing spaces or shell metacharacters
    are passed through verbatim.  ``config.colabsearch`` itself is split
    with :func:`shlex.split`, so a multi-word launcher still works.

    Args:
        config: Search settings; only its attributes are read.

    Returns:
        Path (as ``str``) of the first ``*.a3m`` file, in sorted order,
        found in ``config.results_dir`` after the search.

    Raises:
        subprocess.CalledProcessError: If the search exits with a non-zero
            status (previously ignored, which could silently pick up a
            stale result file).
        FileNotFoundError: If the executable is missing or no ``.a3m``
            file exists in ``config.results_dir`` afterwards.
    """
    # Local imports keep this function self-contained.
    import shlex
    import subprocess

    cmd = [
        *shlex.split(config.colabsearch, posix=os.name != "nt"),
        config.query_fpath,
        config.db_dir,
        config.results_dir,
    ]

    # Database names are only forwarded when set.
    for flag, value in (
        ("--db1", config.db1),
        ("--db2", config.db2),
        ("--db3", config.db3),
    ):
        if value:
            cmd.extend([flag, value])

    cmd.extend(["--mmseqs", config.mmseqs_path or "mmseqs"])

    # These switches must be forwarded even when 0: omitting the flag would
    # make the search tool fall back to its own default and the caller
    # could never turn the feature off.
    for flag, value in (
        ("--use-env", config.use_env),
        ("--filter", config.filter),
    ):
        if value is not None:
            cmd.extend([flag, str(value)])

    if config.db_load_mode:
        cmd.extend(["--db-load-mode", str(config.db_load_mode)])

    subprocess.run(cmd, check=True)

    # Sorted so the choice is deterministic when several files are present.
    result_files = sorted(Path(config.results_dir).glob("*.a3m"))
    if not result_files:
        raise FileNotFoundError(f"No .a3m files found in {config.results_dir}")
    return str(result_files[0])
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: Argument list to parse.  ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour of a
            plain ``parse_args()`` call; passing a list allows the parser
            to be driven programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="ColabFold search and A3M processing tool",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required positional arguments.
    parser.add_argument("query_fpath", help="Path to the query FASTA file")
    parser.add_argument("db_dir", help="Directory containing the databases")
    parser.add_argument("results_dir", help="Directory for storing results")

    # Optional arguments.
    parser.add_argument(
        "--colabsearch", help="Path to colabfold_search", default="colabfold_search"
    )
    parser.add_argument(
        "--mmseqs_path", help="Path to MMseqs2 binary", default="mmseqs"
    )
    parser.add_argument("--db1", help="First database name", default="uniref30_2302_db")
    parser.add_argument("--db2", help="Templates database")
    parser.add_argument(
        "--db3", help="Environmental database (default: colabfold_envdb_202108_db)"
    )
    parser.add_argument(
        "--use_env", help="Use environment settings", type=int, default=1
    )
    parser.add_argument("--filter", help="Apply filtering", type=int, default=1)
    parser.add_argument(
        "--db_load_mode", help="Database load mode", type=int, default=0
    )
    parser.add_argument(
        "--output_split", help="Directory for split A3M files", default=None
    )
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the search, then split the resulting A3M
    # file into per-chain MSA files.
    args = parse_args()

    config = LocalColabFoldConfig(
        colabsearch=args.colabsearch,
        query_fpath=args.query_fpath,
        db_dir=args.db_dir,
        results_dir=args.results_dir,
        mmseqs_path=args.mmseqs_path,
        db1=args.db1,
        db2=args.db2,
        db3=args.db3,
        use_env=args.use_env,
        filter=args.filter,
        db_load_mode=args.db_load_mode,
    )

    results_a3m = run_colabfold_search(config)

    # --output_split selects where the split files go; it used to be parsed
    # but ignored.  Without it the results directory is used, as before.
    split_dir = args.output_split or args.results_dir
    # Create the target up front so a fresh --output_split directory works.
    Path(split_dir).mkdir(parents=True, exist_ok=True)

    processor = A3MProcessor(results_a3m, split_dir)
    # A two-element chain_info means a multi-chain header was parsed; a
    # header-less file was already written out by the constructor.
    if len(processor.chain_info) == 2:
        processor.split_sequences()
|
|