diff --git a/README.md b/README.md
index 5e14e8a646b186988df1851508a8858108df06fb..f9fc06717dff90ee78abc20986860edd0d90c806 100644
--- a/README.md
+++ b/README.md
@@ -11,21 +11,4 @@ ### Engineering
 1. Implement multi-process prediction, e.g. while one model is running inference, load other models' checkpoints into memory concurrently
-   The idea is to maintain 4 independently running uvicorn services or MOS services
-2. Package everything into a docker image; on each startup, download the env and checkpoint from AOSS to the local machine
-
-## Draft
-### Configuration
-#### Chai-1
-```bash
-mkdir -p ~/.cache/
-cd ~/.cache
-ln -s /ai/share/workspace/zhangcb/.cache/huggingface huggingface
-
-source /ai/share/workspace/zhangcb/conda_env/chai/bin/activate
-chai fold --use-msa-server $input_path $output_dir
-```
-
-#### ColabFold MSA
-```bash
-export PATH=$PATH:/ai/share/workspace/zhangcb/LocalBins
-```
\ No newline at end of file
+2. Package everything into a docker image; on each startup, download the env and checkpoint from AOSS to the local machine
\ No newline at end of file
diff --git a/protenix/configs/configs_base.py b/protenix/configs/configs_base.py
index 4a6e219ba19237316ef251731f9f83449f982cb7..327b2eccd81a9a5b8be3918cf3aeb5a418df042b 100644
--- a/protenix/configs/configs_base.py
+++ b/protenix/configs/configs_base.py
@@ -31,9 +31,9 @@ basic_configs = {
     "eval_first": False,  # run evaluate() before training steps
     "iters_to_accumulate": 1,
     "eval_only": False,
-    "load_checkpoint_path": "",
+    "load_checkpoint_path": "/ai/share/workspace/zhangcb/ABCPFold/src/protenix/release_data/checkpoint/model_v0.2.0.pt",
     "load_ema_checkpoint_path": "",
-    "load_strict": False,
+    "load_strict": True,
     "load_params_only": True,
     "skip_load_step": False,
     "skip_load_optimizer": False,
diff --git a/protenix/configs/configs_data.py b/protenix/configs/configs_data.py
index ecf50725ae27e9f0f9a39bc3a690700ed01e3a0d..99f3c7a09781cb06d99fdcfd94db22a0eae1b7d9 100644
--- a/protenix/configs/configs_data.py
+++ b/protenix/configs/configs_data.py
@@ -60,12 +60,24 @@ default_weighted_pdb_configs = {
     "shuffle_sym_ids": GlobalConfigValue("train_shuffle_sym_ids"),
 }
 
-DATA_ROOT_DIR = "/af3-dev/release_data/"
-CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.v20240608.cif")
+DATA_ROOT_DIR = os.environ.get("PROTENIX_DATA_ROOT_DIR", "/af3-dev/release_data/")
+
+# Prefer the CCD cache created by scripts/gen_ccd_cache.py (filenames without a date).
+# See: docs/prepare_data.md
+CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.cif")
 CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
-    DATA_ROOT_DIR, "components.v20240608.cif.rdkit_mol.pkl"
+    DATA_ROOT_DIR, "components.cif.rdkit_mol.pkl"
 )
+if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or (
+    not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH)
+):
+    CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.v20240608.cif")
+    CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
+        DATA_ROOT_DIR, "components.v20240608.cif.rdkit_mol.pkl"
+    )
+
+# This is a patch at the inference stage for users who do not have root permission.
# If you run # ``` @@ -80,17 +92,25 @@ if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or ( not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH) ): print("Try to find the ccd cache data in the code directory for inference.") - # current_file_path = os.path.abspath(__file__) - # current_directory = os.path.dirname(current_file_path) - # code_directory = os.path.dirname(current_directory) - code_directory = '/ai/share/workspace/zhangcb/Protenix/' + current_file_path = os.path.abspath(__file__) + current_directory = os.path.dirname(current_file_path) + code_directory = os.path.dirname(current_directory) data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache") - CCD_COMPONENTS_FILE_PATH = os.path.join(data_cache_dir, "components.v20240608.cif") + CCD_COMPONENTS_FILE_PATH = os.path.join(data_cache_dir, "components.cif") CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join( - data_cache_dir, "components.v20240608.cif.rdkit_mol.pkl" + data_cache_dir, "components.cif.rdkit_mol.pkl" ) + if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or ( + not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH) + ): + CCD_COMPONENTS_FILE_PATH = os.path.join( + data_cache_dir, "components.v20240608.cif" + ) + CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join( + data_cache_dir, "components.v20240608.cif.rdkit_mol.pkl" + ) data_configs = { "num_dl_workers": 16, diff --git a/protenix/configs/configs_inference.py b/protenix/configs/configs_inference.py index b52f7a4333d5c38c40ae066d809d8d4d8f2e7026..2155148694d4561aaf860dc35b8018c05554cbcf 100644 --- a/protenix/configs/configs_inference.py +++ b/protenix/configs/configs_inference.py @@ -24,10 +24,9 @@ code_directory = os.path.dirname(current_directory) # "./release_data/checkpoint/model_v0.2.0.pt" inference_configs = { "seeds": ListValue([101]), - # "seeds": ListValue(range(101, 111)), - # "seeds": ListValue(range(1, 101)), "dump_dir": "./output", "need_atom_confidence": False, + "sorted_by_ranking_score": True, "input_json_path": RequiredValue(str), "load_checkpoint_path": os.path.join( code_directory, "./release_data/checkpoint/model_v0.2.0.pt" diff --git a/protenix/inference.py b/protenix/inference.py index 0d12c365b561455cbbb215f544addd8e9ec3c2bf..5b3e1b172126c4a26d9581cbef31e9f24005f119 100644 --- a/protenix/inference.py +++ b/protenix/inference.py @@ -23,10 +23,11 @@ from typing import Any, Mapping import torch import torch.distributed as dist -import sys from configs.configs_base import configs as configs_base from configs.configs_data import data_configs from configs.configs_inference import inference_configs +from runner.dumper import DataDumper + from protenix.config import parse_configs, parse_sys_args from protenix.data.infer_data_pipeline import get_inference_dataloader from protenix.model.protenix import Protenix @@ -34,8 +35,6 @@ from protenix.utils.distributed import DIST_WRAPPER from protenix.utils.seed import seed_everything from protenix.utils.torch_utils import to_device from protenix.web_service.dependency_url import URL -from runner.dumper import DataDumper -from runner.msa_search import contain_msa_res, msa_search_update logger = logging.getLogger(__name__) @@ -47,7 +46,10 @@ class InferenceRunner(object): self.init_basics() self.init_model() self.load_checkpoint() - self.init_dumper(need_atom_confidence=configs.need_atom_confidence) + self.init_dumper( + need_atom_confidence=configs.need_atom_confidence, + sorted_by_ranking_score=configs.sorted_by_ranking_score, + ) def init_env(self) -> None: self.print( @@ -112,14 +114,18 @@ 
class InferenceRunner(object): } self.model.load_state_dict( state_dict=checkpoint["model"], - strict=True, + strict=self.configs.load_strict, ) self.model.eval() self.print(f"Finish loading checkpoint.") - def init_dumper(self, need_atom_confidence: bool = False): + def init_dumper( + self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True + ): self.dumper = DataDumper( - base_dir=self.dump_dir, need_atom_confidence=need_atom_confidence + base_dir=self.dump_dir, + need_atom_confidence=need_atom_confidence, + sorted_by_ranking_score=sorted_by_ranking_score, ) # Adapted from runner.train.Trainer.evaluate @@ -157,32 +163,29 @@ class InferenceRunner(object): def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> None: - # current_file_path = os.path.abspath(__file__) - # current_directory = os.path.dirname(current_file_path) - # code_directory = os.path.dirname(current_directory) - code_directory = os.path.dirname(os.path.abspath(__file__)) - # code_directory = '/ai/share/workspace/zhangcb/Protenix/' - - data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache") - os.makedirs(data_cache_dir, exist_ok=True) - for cache_name, fname in [ - ("ccd_components_file", "components.v20240608.cif"), - ("ccd_components_rdkit_mol_file", "components.v20240608.cif.rdkit_mol.pkl"), - ]: - if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))): + + for cache_name in ("ccd_components_file", "ccd_components_rdkit_mol_file"): + cur_cache_fpath = configs["data"][cache_name] + if not opexists(cur_cache_fpath): + os.makedirs(os.path.dirname(cur_cache_fpath), exist_ok=True) tos_url = URL[cache_name] - logger.info(f"Downloading data cache from\n {tos_url} to {cache_path}") - urllib.request.urlretrieve(tos_url, cache_path) + assert os.path.basename(tos_url) == os.path.basename(cur_cache_fpath), ( + f"{cache_name} file name is incorrect, `{tos_url}` and " + f"`{cur_cache_fpath}`. Please check and try again." + ) + logger.info( + f"Downloading data cache from\n {tos_url}... to {cur_cache_fpath}" + ) + urllib.request.urlretrieve(tos_url, cur_cache_fpath) checkpoint_path = configs.load_checkpoint_path if not opexists(checkpoint_path): - checkpoint_path = os.path.join( - code_directory, f"release_data/checkpoint/model_{model_version}.pt" - ) os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) tos_url = URL[f"model_{model_version}"] - logger.info(f"Downloading model checkpoint from\n {tos_url} to {checkpoint_path}") + logger.info( + f"Downloading model checkpoint from\n {tos_url}... to {checkpoint_path}" + ) urllib.request.urlretrieve(tos_url, checkpoint_path) try: ckpt = torch.load(checkpoint_path) @@ -193,7 +196,6 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No "Download model checkpoint failed, please download by yourself with " f"wget {tos_url} -O {checkpoint_path}" ) - configs.load_checkpoint_path = checkpoint_path def update_inference_configs(configs: Any, N_token: int): @@ -213,38 +215,31 @@ def update_inference_configs(configs: Any, N_token: int): def infer_predict(runner: InferenceRunner, configs: Any) -> None: - # update msa result if not contains precomputed msa dir - if not contain_msa_res(configs.input_json_path): - logger.info( - f"{configs.input_json_path} dose not contain precomputed msa dir, now searching it." 
- ) - configs.input_json_path = msa_search_update( - configs.input_json_path, configs.dump_dir - ) - logger.info( - f"msa searching completed, new input json is {configs.input_json_path}" - ) # Data logger.info(f"Loading data from\n{configs.input_json_path}") - dataloader = get_inference_dataloader(configs=configs) + try: + dataloader = get_inference_dataloader(configs=configs) + except Exception as e: + error_message = f"{e}:\n{traceback.format_exc()}" + logger.info(error_message) + with open(opjoin(runner.error_dir, "error.txt"), "a") as f: + f.write(error_message) + return num_data = len(dataloader.dataset) for seed in configs.seeds: - seed_everything(seed=seed, deterministic=False) + seed_everything(seed=seed, deterministic=configs.deterministic) for batch in dataloader: try: data, atom_array, data_error_message = batch[0] + sample_name = data["sample_name"] if len(data_error_message) > 0: logger.info(data_error_message) - with open( - opjoin(runner.error_dir, f"{data['sample_name']}.txt"), - "w", - ) as f: + with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(data_error_message) continue - sample_name = data["sample_name"] logger.info( ( f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: " @@ -273,15 +268,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None: error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}" logger.info(error_message) # Save error info - if opexists( - error_path := opjoin(runner.error_dir, f"{sample_name}.txt") - ): - os.remove(error_path) - with open(error_path, "w") as f: + with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(error_message) if hasattr(torch.cuda, "empty_cache"): torch.cuda.empty_cache() - raise RuntimeError(f"run infer failed: {str(e)}") def main(configs: Any) -> None: @@ -299,7 +289,7 @@ def run() -> None: filemode="w", ) configs_base["use_deepspeed_evo_attention"] = ( - os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true" + os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true" ) configs = {**configs_base, **{"data": data_configs}, **inference_configs} configs = parse_configs( diff --git a/protenix/protenix/data/constants.py b/protenix/protenix/data/constants.py index 4e786a83999d5ab463f24ad7f6922e044bda9398..308adf6c30be434c2a3eeb2ac85fd49ceee2b66a 100644 --- a/protenix/protenix/data/constants.py +++ b/protenix/protenix/data/constants.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from rdkit.Chem import GetPeriodicTable + EvaluationChainInterface = [ "intra_ligand", "intra_dna", @@ -36,10 +38,16 @@ EntityPolyTypeDict = { "ligand": ["cyclic-pseudo-peptide", "other"], } +CRYSTALLIZATION_METHODS = { + "X-RAY DIFFRACTION", + "NEUTRON DIFFRACTION", + "ELECTRON CRYSTALLOGRAPHY", + "POWDER CRYSTALLOGRAPHY", + "FIBER DIFFRACTION", +} ### Protein Constants ### # https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v40.dic/Items/_entity_poly.pdbx_seq_one_letter_code_can.html -from rdkit.Chem import GetPeriodicTable mmcif_restype_1to3 = { "A": "ALA", diff --git a/protenix/protenix/data/data_pipeline.py b/protenix/protenix/data/data_pipeline.py index 8dad0b7ec8f45c35d7850d996e9e0054891fad52..e254655dfcde2f1aa7ba6d0996cb5877a21e4f96 100644 --- a/protenix/protenix/data/data_pipeline.py +++ b/protenix/protenix/data/data_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from collections import defaultdict from pathlib import Path @@ -24,7 +25,8 @@ import torch from biotite.structure import AtomArray from protenix.data.msa_featurizer import MSAFeaturizer -from protenix.data.tokenizer import TokenArray +from protenix.data.parser import DistillationMMCIFParser, MMCIFParser +from protenix.data.tokenizer import AtomArrayTokenizer, TokenArray from protenix.utils.cropping import CropData from protenix.utils.file_io import load_gzip_pickle @@ -32,6 +34,64 @@ torch.multiprocessing.set_sharing_strategy("file_system") class DataPipeline(object): + """ + DataPipeline class provides static methods to handle various data processing tasks related to bioassembly structures. + """ + + @staticmethod + def get_data_from_mmcif( + mmcif: Union[str, Path], + pdb_cluster_file: Union[str, Path, None] = None, + dataset: str = "WeightedPDB", + ) -> tuple[list[dict[str, Any]], dict[str, Any]]: + """ + Get raw data from mmcif with tokenizer and a list of chains and interfaces for sampling. + + Args: + mmcif (Union[str, Path]): The raw mmcif file. + pdb_cluster_file (Union[str, Path, None], optional): Cluster info txt file. Defaults to None. + dataset (str, optional): The dataset type, either "WeightedPDB" or "Distillation". Defaults to "WeightedPDB". + + Returns: + tuple[list[dict[str, Any]], dict[str, Any]]: + sample_indices_list (list[dict[str, Any]]): The sample indices list (each one is a chain or an interface). + bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array, and token_array. + """ + try: + if dataset == "WeightedPDB": + parser = MMCIFParser(mmcif_file=mmcif) + bioassembly_dict = parser.get_bioassembly() + elif dataset == "Distillation": + parser = DistillationMMCIFParser(mmcif_file=mmcif) + bioassembly_dict = parser.get_structure_dict() + else: + raise NotImplementedError( + 'Unsupported "dataset", please input either "WeightedPDB" or "Distillation".' 
+ ) + + sample_indices_list = parser.make_indices( + bioassembly_dict=bioassembly_dict, pdb_cluster_file=pdb_cluster_file + ) + if len(sample_indices_list) == 0: + # empty indices and AtomArray + return [], bioassembly_dict + + atom_array = bioassembly_dict["atom_array"] + atom_array.set_annotation( + "resolution", [parser.resolution] * len(atom_array) + ) + + tokenizer = AtomArrayTokenizer(atom_array) + token_array = tokenizer.get_token_array() + bioassembly_dict["msa_features"] = None + bioassembly_dict["template_features"] = None + + bioassembly_dict["token_array"] = token_array + return sample_indices_list, bioassembly_dict + + except Exception as e: + logging.warning("Gen data failed for %s due to %s", mmcif, e) + return [], {} @staticmethod def get_label_entity_id_to_asym_id_int(atom_array: AtomArray) -> dict[str, int]: diff --git a/protenix/protenix/data/dataset.py b/protenix/protenix/data/dataset.py index 62878cbf072d4cd71ee15f89b110342fa6c45136..fcaf3164de54281232cdaa012efdbdff1eeb6d97 100644 --- a/protenix/protenix/data/dataset.py +++ b/protenix/protenix/data/dataset.py @@ -894,25 +894,34 @@ def get_weighted_pdb_weight( cluster_size: int, chain_count: dict, eps: float = 1e-9, - beta_dict: dict = { - "chain": 0.5, - "interface": 1, - }, - alpha_dict: dict = { - "prot": 3, - "nuc": 3, - "ligand": 1, - }, + beta_dict: Optional[dict] = None, + alpha_dict: Optional[dict] = None, ) -> float: """ - Get sample weight for each examples in weighted pdb dataset. AF3-SI (1) - Args: - data_type: chain or interface - cluster_size: cluster size of this chain/interface - chain_count: count of each kinds of chains, {"prot": int, "nuc": int, "ligand": int} + Get sample weight for each example in a weighted PDB dataset. + + data_type (str): Type of data, either 'chain' or 'interface'. + cluster_size (int): Cluster size of this chain/interface. + chain_count (dict): Count of each kind of chains, e.g., {"prot": int, "nuc": int, "ligand": int}. + eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-9. + beta_dict (Optional[dict], optional): Dictionary containing beta values for 'chain' and 'interface'. + alpha_dict (Optional[dict], optional): Dictionary containing alpha values for different chain types. + Returns: - weights: float + float: Calculated weight for the given chain/interface. 
""" + if not beta_dict: + beta_dict = { + "chain": 0.5, + "interface": 1, + } + if not alpha_dict: + alpha_dict = { + "prot": 3, + "nuc": 3, + "ligand": 1, + } + assert cluster_size > 0 assert data_type in ["chain", "interface"] beta = beta_dict[data_type] diff --git a/protenix/protenix/data/filter.py b/protenix/protenix/data/filter.py index d9d05b43007dea83d425e039ed4064ae66426f81..1066b8f4a9cc11aeee7e5356336f73eb11de3b0d 100644 --- a/protenix/protenix/data/filter.py +++ b/protenix/protenix/data/filter.py @@ -14,7 +14,8 @@ import biotite.structure as struc import numpy as np -from biotite.structure import AtomArray +from biotite.structure import AtomArray, get_molecule_indices +from scipy.spatial.distance import cdist from protenix.data.constants import CRYSTALLIZATION_AIDS @@ -80,3 +81,557 @@ class Filter(object): non_aids_mask = ~np.isin(atom_array.res_name, CRYSTALLIZATION_AIDS) poly_mask = np.isin(atom_array.label_entity_id, list(entity_poly_type.keys())) return atom_array[poly_mask | non_aids_mask] + + @staticmethod + def _get_clashing_chains( + atom_array: AtomArray, chain_ids: list[str] + ) -> tuple[np.ndarray, list[int]]: + """ + Calculate the number of atoms clashing with other chains for each chain + and return a matrix that records the count of clashing atoms. + + Note: if two chains are covalent, they are not considered as clashing. + + Args: + atom_array (AtomArray): All atoms, including those not resolved. + chain_ids (list[str]): Unique chain indices of resolved atoms. + + Returns: + tuple: + clash_records (numpy.ndarray): Matrix of clashing atom num. + (i, j) means the ratio of i's atom clashed with j's atoms. + Note: (i, j) != (j, i). + chain_resolved_atom_nums (list[int]): The number of resolved atoms corresponding to each chain ID. + """ + is_resolved_centre_atom = ( + atom_array.centre_atom_mask == 1 + ) & atom_array.is_resolved + cell_list = struc.CellList( + atom_array, cell_size=1.7, selection=is_resolved_centre_atom + ) + + # (i, j) means the ratio of i's atom clashed with j's atoms + clash_records = np.zeros((len(chain_ids), len(chain_ids))) + + # record the number of resolved atoms for each chain + chain_resolved_atom_nums = [] + + # record covalent relationship between chains + chains_covalent_dict = {} + for idx, chain_id_i in enumerate(chain_ids): + for chain_id_j in chain_ids[idx + 1 :]: + mol_indices = get_molecule_indices( + atom_array[np.isin(atom_array.chain_id, [chain_id_i, chain_id_j])] + ) + if len(mol_indices) == 1: + covalent = 1 + else: + covalent = 0 + chains_covalent_dict[(chain_id_i, chain_id_j)] = covalent + chains_covalent_dict[(chain_id_j, chain_id_i)] = covalent + + for i, chain_id in enumerate(chain_ids): + coords = atom_array.coord[ + (atom_array.chain_id == chain_id) & is_resolved_centre_atom + ] + chain_resolved_atom_nums.append(len(coords)) + chain_atom_ids = np.where(atom_array.chain_id == chain_id)[0] + chain_atom_ids_set = set(chain_atom_ids) | {-1} + + # Get atom indices from the current cell and the eight surrounding cells. + neighbors_ids_2d = cell_list.get_atoms_in_cells(coords, cell_radius=1) + neighbors_ids = np.unique(neighbors_ids_2d) + + # Remove the atom indices of the current chain. + other_chain_atom_ids = list(set(neighbors_ids) - chain_atom_ids_set) + + if not other_chain_atom_ids: + continue + else: + # Calculate the distance matrix with neighboring atoms. 
+ other_chain_atom_coords = atom_array.coord[other_chain_atom_ids] + dist_mat = cdist(coords, other_chain_atom_coords, metric="euclidean") + clash_mat = dist_mat < 1.6 # change 1.7 to 1.6 for more compatibility + if np.any(clash_mat): + clashed_other_chain_ids = atom_array.chain_id[other_chain_atom_ids] + + for other_chain_id in set(clashed_other_chain_ids): + + # two chains covalent with each other + if chains_covalent_dict[(chain_id, other_chain_id)]: + continue + + cols = np.where(clashed_other_chain_ids == other_chain_id)[0] + + # how many i's atoms clashed with j + any_atom_clashed = np.any( + clash_mat[:, cols].astype(int), axis=1 + ) + clashed_atom_num = np.sum(any_atom_clashed.astype(int)) + + if clashed_atom_num > 0: + j = chain_ids.index(other_chain_id) + clash_records[i][j] += clashed_atom_num + return clash_records, chain_resolved_atom_nums + + @staticmethod + def _get_removed_clash_chain_ids( + clash_records: np.ndarray, + chain_ids: list[str], + chain_resolved_atom_nums: list[int], + core_chain_id: np.ndarray = [], + ) -> list[str]: + """ + Perform pairwise comparisons on the chains, and select the chain IDs + to be deleted according to the clahsing chain rules. + + Args: + clash_records (numpy.ndarray): Matrix of clashing atom num. + (i, j) means the ratio of i's atom clashed with j's atoms. + Note: (i, j) != (j, i). + chain_ids (list[str]): Unique chain indices of resolved atoms. + chain_resolved_atom_nums (list[int]): The number of resolved atoms corresponding to each chain ID. + core_chain_id (np.ndarray): The chain ID of the core chain. + + Returns: + list[str]: A list of chain IDs that have been determined for deletion. + """ + removed_chain_ids = [] + for i in range(len(chain_ids)): + atom_num_i = chain_resolved_atom_nums[i] + chain_idx_i = chain_ids[i] + + if chain_idx_i in removed_chain_ids: + continue + + for j in range(i + 1, len(chain_ids)): + atom_num_j = chain_resolved_atom_nums[j] + chain_idx_j = chain_ids[j] + + if chain_idx_j in removed_chain_ids: + continue + + clash_num_ij, clash_num_ji = ( + clash_records[i][j], + clash_records[j][i], + ) + + clash_ratio_ij = clash_num_ij / atom_num_i + clash_ratio_ji = clash_num_ji / atom_num_j + + if clash_ratio_ij <= 0.3 and clash_ratio_ji <= 0.3: + # not reaches the threshold + continue + else: + # clashing chains + if ( + chain_idx_i in core_chain_id + and chain_idx_j not in core_chain_id + ): + removed_chain_idx = chain_idx_j + elif ( + chain_idx_i not in core_chain_id + and chain_idx_j in core_chain_id + ): + removed_chain_idx = chain_idx_i + + elif clash_ratio_ij > clash_ratio_ji: + removed_chain_idx = chain_idx_i + elif clash_ratio_ij < clash_ratio_ji: + removed_chain_idx = chain_idx_j + else: + if atom_num_i < atom_num_j: + removed_chain_idx = chain_idx_i + elif atom_num_i > atom_num_j: + removed_chain_idx = chain_idx_j + else: + removed_chain_idx = sorted([chain_idx_i, chain_idx_j])[1] + + removed_chain_ids.append(removed_chain_idx) + + if removed_chain_idx == chain_idx_i: + # chain i already removed + break + return removed_chain_ids + + @staticmethod + def remove_polymer_chains_all_residues_unknown( + atom_array: AtomArray, + entity_poly_type: dict, + ) -> AtomArray: + """remove chains with all residues unknown""" + chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True) + invalid_chains = [] # list of [start, end) + for index in range(len(chain_starts) - 1): + start, end = chain_starts[index], chain_starts[index + 1] + entity_id = atom_array[start].label_entity_id + if ( + 
entity_poly_type.get(entity_id, "non-poly") == "polypeptide(L)" + and np.all(atom_array.res_name[start:end] == "UNK") + ) or ( + entity_poly_type.get(entity_id, "non-poly") + in ( + "polyribonucleotide", + "polydeoxyribonucleotide", + ) + and np.all(atom_array.res_name[start:end] == "N") + ): + invalid_chains.append((start, end)) + mask = np.ones(len(atom_array), dtype=bool) + for start, end in invalid_chains: + mask[start:end] = False + atom_array = atom_array[mask] + return atom_array + + @staticmethod + def remove_polymer_chains_too_short( + atom_array: AtomArray, entity_poly_type: dict + ) -> AtomArray: + chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True) + invalid_chains = [] # list of [start, end) + for index in range(len(chain_starts) - 1): + start, end = chain_starts[index], chain_starts[index + 1] + entity_id = atom_array[start].label_entity_id + num_residue_ids = len(set(atom_array.label_seq_id[start:end])) + if ( + entity_poly_type.get(entity_id, "non-poly") + in ( + "polypeptide(L)", # TODO: how to handle polypeptide(D)? + "polyribonucleotide", + "polydeoxyribonucleotide", + ) + and num_residue_ids < 4 + ): + invalid_chains.append((start, end)) + mask = np.ones(len(atom_array), dtype=bool) + for start, end in invalid_chains: + mask[start:end] = False + atom_array = atom_array[mask] + return atom_array + + @staticmethod + def remove_polymer_chains_with_consecutive_c_alpha_too_far_away( + atom_array: AtomArray, entity_poly_type: dict, max_distance: float = 10.0 + ) -> AtomArray: + chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True) + invalid_chains = [] # list of [start, end) + for index in range(len(chain_starts) - 1): + start, end = chain_starts[index], chain_starts[index + 1] + entity_id = atom_array.label_entity_id[start] + if entity_poly_type.get(entity_id, "non-poly") == "polypeptide(L)": + peptide_atoms = atom_array[start:end] + ca_atoms = peptide_atoms[peptide_atoms.atom_name == "CA"] + seq_ids = ca_atoms.label_seq_id + seq_ids[seq_ids == "."] = "-100" + seq_ids = seq_ids.astype(np.int64) + dist_square = np.sum( + (ca_atoms[:-1].coord - ca_atoms[1:].coord) ** 2, axis=-1 + ) + invalid_neighbor_mask = (dist_square > max_distance**2) & ( + seq_ids[:-1] + 1 == seq_ids[1:] + ) + if np.any(invalid_neighbor_mask): + invalid_chains.append((start, end)) + mask = np.ones(len(atom_array), dtype=bool) + for start, end in invalid_chains: + mask[start:end] = False + atom_array = atom_array[mask] + return atom_array + + @staticmethod + def too_many_chains_filter( + atom_array: AtomArray, + interface_radius: int = 15, + max_chains_num: int = 20, + core_indices: list[int] = None, + max_tokens_num: int = None, + ) -> tuple[AtomArray, int]: + """ + Ref: AlphaFold3 SI Chapter 2.5.4 + + For bioassemblies with greater than 20 chains, we select a random interface token + (with a centre atom <15 Ã… to the centre atom of a token in another chain) + and select the closest 20 chains to this token based on + minimum distance between any tokens centre atom. + + Note: due to the presence of covalent small molecules, + treat the covalent small molecule and the polymer it is attached to + as a single chain to avoid inadvertently removing the covalent small molecules. + Use the mol_id added to the AtomArray to differentiate between the various + parts of the structure composed of covalent bonds. + + Args: + atom_array (AtomArray): Biotite AtomArray Object of a Bioassembly. 
+ interface_radius (int, optional): Atoms within this distance of the central atom are considered interface atoms. + Defaults to 15. + max_chains_num (int, optional): The maximum number of chains permitted in a bioassembly. + Filtration will be applied if exceeds this value. Defaults to 20. + core_indices (list[int], optional): A list of indices to be used as chose the central atom. + And corresponding chains in the list will be selected proriority. + If None, a random index from whole AtomArray will be selected. Defaults to None. + max_tokens_num (int, optional): The maximum number of tokens permitted in a bioassembly. + If not None, after more than max_chains_num, if the max_tokens_num is not reached, + it will continue to append the chains. + + Returns: + tuple: + - atom_array (AtomArray): An AtomArray that has been processed through this filter. + - input_chains_num (int): The number of chain in the input AtomArray. + This is to log whether the filter has been utilized. + """ + # each mol is a so called "chain" in the context of this filter. + input_chains_num = len(np.unique(atom_array.mol_id)) + if input_chains_num <= max_chains_num: + # no change + return atom_array, input_chains_num + + is_resolved_centre_atom = ( + atom_array.centre_atom_mask == 1 + ) & atom_array.is_resolved + + cell_list = struc.CellList( + atom_array, cell_size=interface_radius, selection=is_resolved_centre_atom + ) + resolved_centre_atom = atom_array[is_resolved_centre_atom] + + assert resolved_centre_atom, "There is no resolved central atom." + + # random pick centre atom + if core_indices is None: + index_shuf = np.random.default_rng(seed=42).permutation( + len(resolved_centre_atom) + ) + else: + index_shuf = np.array(core_indices) + resolved_centre_atom_indices = np.nonzero(is_resolved_centre_atom)[0] + + # get indices of resolved_centre_atom + index_shuf = np.array( + [ + np.where(resolved_centre_atom_indices == idx)[0][0] + for idx in index_shuf + if idx in resolved_centre_atom_indices + ] + ) + np.random.default_rng(seed=42).shuffle(index_shuf) + + chosen_centre_atom = None + for idx in index_shuf: + centre_atom = resolved_centre_atom[idx] + neighbors_indices = cell_list.get_atoms( + centre_atom.coord, radius=interface_radius + ) + neighbors_indices = neighbors_indices[neighbors_indices != -1] + + neighbors_chain_ids = np.unique(atom_array.mol_id[neighbors_indices]) + # neighbors include centre atom itself + if len(neighbors_chain_ids) > 1: + chosen_centre_atom = centre_atom + break + + # The distance between the central atoms in any two chains is greater than 15 angstroms. 
+ if chosen_centre_atom is None: + return None, input_chains_num + + dist_mat = cdist(centre_atom.coord.reshape((1, -1)), resolved_centre_atom.coord) + sorted_chain_id = np.array( + [ + chain_id + for chain_id, _dist in sorted( + zip(resolved_centre_atom.mol_id, dist_mat[0]), + key=lambda pair: pair[1], + ) + ] + ) + + if core_indices is not None: + # select core proriority + core_mol_id = np.unique(atom_array.mol_id[core_indices]) + in_core_mask = np.isin(sorted_chain_id, core_mol_id) + sorted_chain_id = np.concatenate( + (sorted_chain_id[in_core_mask], sorted_chain_id[~in_core_mask]) + ) + + closest_chain_id = set() + chain_ids_to_token_num = {} + if max_tokens_num is None: + max_tokens_num = 0 + + tokens = 0 + for chain_id in sorted_chain_id: + # get token num + if chain_id not in chain_ids_to_token_num: + chain_ids_to_token_num[chain_id] = atom_array.centre_atom_mask[ + atom_array.mol_id == chain_id + ].sum() + chain_token_num = chain_ids_to_token_num[chain_id] + + if len(closest_chain_id) >= max_chains_num: + if tokens + chain_token_num > max_tokens_num: + break + + closest_chain_id.add(chain_id) + tokens += chain_token_num + + atom_array = atom_array[np.isin(atom_array.mol_id, list(closest_chain_id))] + output_chains_num = len(np.unique(atom_array.mol_id)) + assert ( + output_chains_num == max_chains_num + or atom_array.centre_atom_mask.sum() <= max_tokens_num + ) + return atom_array, input_chains_num + + @staticmethod + def remove_clashing_chains( + atom_array: AtomArray, + core_indices: list[int] = None, + ) -> AtomArray: + """ + Ref: AlphaFold3 SI Chapter 2.5.4 + + Clashing chains are removed. + Clashing chains are defined as those with >30% of atoms within 1.7 Ã… of an atom in another chain. + If two chains are clashing with each other, the chain with the greater percentage of clashing atoms will be removed. + If the same fraction of atoms are clashing, the chain with fewer total atoms is removed. + If the chains have the same number of atoms, then the chain with the larger chain id is removed. + + Note: if two chains are covalent, they are not considered as clashing. + + Args: + atom_array (AtomArray): Biotite AtomArray Object of a Bioassembly. + core_indices (list[int]): A list of indices for core structures, + where these indices correspond to structures that will be preferentially + retained when pairwise clash chain assessments are performed. + + Returns: + atom_array (AtomArray): An AtomArray that has been processed through this filter. + removed_chain_ids (list[str]): A list of chain IDs that have been determined for deletion. + This is to log whether the filter has been utilized. + """ + chain_ids = np.unique(atom_array.chain_id[atom_array.is_resolved]).tolist() + + if core_indices is not None: + core_chain_id = np.unique(atom_array.chain_id[core_indices]) + else: + core_chain_id = np.array([]) + + clash_records, chain_resolved_atom_nums = Filter._get_clashing_chains( + atom_array, chain_ids + ) + removed_chain_ids = Filter._get_removed_clash_chain_ids( + clash_records, + chain_ids, + chain_resolved_atom_nums, + core_chain_id=core_chain_id, + ) + + atom_array = atom_array[~np.isin(atom_array.chain_id, removed_chain_ids)] + return atom_array, removed_chain_ids + + @staticmethod + def remove_unresolved_mols(atom_array: AtomArray) -> AtomArray: + """ + Remove molecules from a bioassembly object which all atoms are not resolved. + + Args: + atom_array (AtomArray): Biotite AtomArray Object of a bioassembly. 
+ + Returns: + AtomArray: An AtomArray object with unresolved molecules removed. + """ + valid_mol_id = [] + for mol_id in np.unique(atom_array.mol_id): + resolved = atom_array.is_resolved[atom_array.mol_id == mol_id] + if np.any(resolved): + valid_mol_id.append(mol_id) + + atom_array = atom_array[np.isin(atom_array.mol_id, valid_mol_id)] + return atom_array + + @staticmethod + def remove_asymmetric_polymer_ligand_bonds( + atom_array: AtomArray, entity_poly_type: dict + ) -> AtomArray: + """remove asymmetric polymer ligand bonds (including protein-protein bond, like disulfide bond). + + AF3 SI 5.1 Structure filters + Bonds for structures with homomeric subcomplexes lacking the corresponding homomeric symmetry are also removed + - e.g. if a certain bonded ligand only exists for some of the symmetric copies, but not for all, + we remove the corresponding bond information from the input. + In consequence the model has to learn to infer these bonds by itself. + + Args: + atom_array (AtomArray): input atom array + + Returns: + AtomArray: output atom array with asymmetric polymer ligand bonds removed. + """ + # get inter chain bonds + inter_chain_bonds = set() + for i, j, b in atom_array.bonds.as_array(): + if atom_array.chain_id[i] != atom_array.chain_id[j]: + inter_chain_bonds.add((i, j)) + + # get asymmetric polymer ligand bonds + asymmetric_bonds = set() + chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=False) + for bond in inter_chain_bonds: + + if bond in asymmetric_bonds: + continue + + i, j = bond + atom_i = atom_array[i] + atom_j = atom_array[j] + i_is_polymer = atom_i.label_entity_id in entity_poly_type + j_is_polymer = atom_j.label_entity_id in entity_poly_type + if i_is_polymer: + pass + elif j_is_polymer: + i, j = j, i + atom_i, atom_j = atom_j, atom_i + i_is_polymer, j_is_polymer = j_is_polymer, i_is_polymer + else: + # both entity is not polymer + continue + + # get atom i mask from all entity i copies + entity_mask_i = atom_array.label_entity_id == atom_i.label_entity_id + num_copies = np.isin(chain_starts, np.flatnonzero(entity_mask_i)).sum() + mask_i = ( + entity_mask_i + & (atom_array.res_id == atom_i.res_id) + & (atom_array.atom_name == atom_i.atom_name) + ) + indices_i = np.flatnonzero(mask_i) + + if len(indices_i) != num_copies: + # not every copy of entity i has atom i. + asymmetric_bonds.add(bond) + continue + + # check all atom i in entity i bond to an atom j in entity j. 
+ target_bonds = [] + for ii in indices_i: + ii_bonds = [b for b in inter_chain_bonds if ii in b] + for bond in ii_bonds: + jj = bond[1] if ii == bond[0] else bond[0] + atom_jj = atom_array[jj] + if atom_jj.label_entity_id != atom_j.label_entity_id: + continue + if atom_jj.res_name != atom_j.res_name: + continue + if atom_jj.atom_name != atom_j.atom_name: + continue + if j_is_polymer and atom_jj.res_id != atom_j.res_id: + # only for polymer, check res_id + continue + # found bond (ii, jj) with same enity_id, res_name, atom_name to bond (i,j) + target_bonds.append((min(ii, jj), max(ii, jj))) + break + if len(target_bonds) != num_copies: + asymmetric_bonds |= set(target_bonds) + + for bond in asymmetric_bonds: + atom_array.bonds.remove_bond(bond[0], bond[1]) + return atom_array diff --git a/protenix/protenix/data/json_maker.py b/protenix/protenix/data/json_maker.py index 99855fe39663abbfcf22a4347413285d095fd369..39b1213320e61ad7ecdce1889fe2c7bd21e8bc11 100644 --- a/protenix/protenix/data/json_maker.py +++ b/protenix/protenix/data/json_maker.py @@ -46,12 +46,12 @@ def merge_covalent_bonds( for bond_dict in covalent_bonds: bond_unique_string = [] entity_counts = ( - all_entity_counts[str(bond_dict["left_entity"])], - all_entity_counts[str(bond_dict["right_entity"])], + all_entity_counts[str(bond_dict["entity1"])], + all_entity_counts[str(bond_dict["entity2"])], ) - for i in ["left", "right"]: + for i in range(2): for j in ["entity", "position", "atom"]: - k = f"{i}_{j}" + k = f"{j}{i+1}" bond_unique_string.append(str(bond_dict[k])) bond_unique_string = "_".join(bond_unique_string) bonds_recorder[bond_unique_string].append(bond_dict) @@ -59,12 +59,12 @@ def merge_covalent_bonds( merged_covalent_bonds = [] for k, v in bonds_recorder.items(): - left_counts = bonds_entity_counts[k][0] - right_counts = bonds_entity_counts[k][1] - if left_counts == right_counts == len(v): + counts1 = bonds_entity_counts[k][0] + counts2 = bonds_entity_counts[k][1] + if counts1 == counts2 == len(v): bond_dict_copy = copy.deepcopy(v[0]) - del bond_dict_copy["left_copy"] - del bond_dict_copy["right_copy"] + del bond_dict_copy["copy1"] + del bond_dict_copy["copy2"] merged_covalent_bonds.append(bond_dict_copy) else: merged_covalent_bonds.extend(v) @@ -217,15 +217,15 @@ def atom_array_to_input_json( covalent_bonds = [] for atoms in inter_entity_bonds[:, :2]: bond_dict = {} - for idx, i in enumerate(["left", "right"]): - atom = atom_array[atoms[idx]] + for i in range(2): + atom = atom_array[atoms[i]] positon = atom.res_id - bond_dict[f"{i}_entity"] = int( + bond_dict[f"entity{i+1}"] = int( label_entity_id_to_entity_id_in_json[atom.label_entity_id] ) - bond_dict[f"{i}_position"] = int(positon) - bond_dict[f"{i}_atom"] = atom.atom_name - bond_dict[f"{i}_copy"] = int(atom.copy_id) + bond_dict[f"position{i+1}"] = int(positon) + bond_dict[f"atom{i+1}"] = atom.atom_name + bond_dict[f"copy{i+1}"] = int(atom.copy_id) covalent_bonds.append(bond_dict) @@ -299,8 +299,15 @@ def cif_to_input_json( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--cif_file", type=str, required=True, help="The cif file to parse") - parser.add_argument("--json_file", type=str, required=False, default=None, help="The json file path to generate") + parser.add_argument( + "--cif_file", type=str, required=True, help="The cif file to parse" + ) + parser.add_argument( + "--json_file", + type=str, + required=False, + default=None, + help="The json file path to generate", + ) args = parser.parse_args() 
print(cif_to_input_json(args.cif_file, output_json=args.json_file)) - diff --git a/protenix/protenix/data/json_parser.py b/protenix/protenix/data/json_parser.py index a40bf42f5764844d4fa0b3a620ccd2cc09dee638..20fd3163943f31c47472d479c69ab23875aaec1b 100644 --- a/protenix/protenix/data/json_parser.py +++ b/protenix/protenix/data/json_parser.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures import copy import logging import random @@ -500,7 +501,22 @@ def smiles_to_atom_info(smiles: str) -> dict: atom_info = {} mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) - ret_code = AllChem.EmbedMolecule(mol) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(AllChem.EmbedMolecule, mol) + + try: + ret_code = future.result(timeout=90) + except concurrent.futures.TimeoutError as exc: + raise TimeoutError( + 'Conformer generation timed out. \ + Please change the "ligand" input format to "CCD_" or "FILE_".' + ) from exc + + if ret_code != 0: + # retry with random coords + ret_code = AllChem.EmbedMolecule(mol, useRandomCoords=True) + assert ret_code == 0, f"Conformer generation failed for input SMILES: {smiles}" atom_info = rdkit_mol_to_atom_info(mol) return atom_info diff --git a/protenix/protenix/data/json_to_feature.py b/protenix/protenix/data/json_to_feature.py index 92cb2814860f6406fdb21768266d944a45127a06..e1f46a1fe8dc16d438e12d2df3d956dc1d19b2bd 100644 --- a/protenix/protenix/data/json_to_feature.py +++ b/protenix/protenix/data/json_to_feature.py @@ -167,11 +167,26 @@ class SampleDictToFeatures: bond_count = {} for bond_info_dict in self.input_dict["covalent_bonds"]: bond_atoms = [] - for i in ["left", "right"]: - entity_id = int(bond_info_dict[f"{i}_entity"]) - copy_id = int(bond_info_dict.get(f"{i}_copy")) - position = int(bond_info_dict[f"{i}_position"]) - atom_name = bond_info_dict[f"{i}_atom"] + for idx, i in enumerate(["left", "right"]): + entity_id = int( + bond_info_dict.get( + f"{i}_entity", bond_info_dict.get(f"entity{idx+1}") + ) + ) + copy_id = bond_info_dict.get( + f"{i}_copy", bond_info_dict.get(f"copy{idx+1}") + ) + position = int( + bond_info_dict.get( + f"{i}_position", bond_info_dict.get(f"position{idx+1}") + ) + ) + atom_name = bond_info_dict.get( + f"{i}_atom", bond_info_dict.get(f"atom{idx+1}") + ) + + if copy_id is not None: + copy_id = int(copy_id) if isinstance(atom_name, str): if atom_name.isdigit(): @@ -194,11 +209,11 @@ class SampleDictToFeatures: atom_indices.size > 0 ), f"No atom found for {atom_name} in entity {entity_id} at position {position}." bond_atoms.append(atom_indices) - - assert len(bond_atoms[0]) == len( - bond_atoms[1] - ), f'Can not create bonds because the "count" of entity {bond_info_dict["left_entity"]} \ - and {bond_info_dict["right_entity"]} are not equal. ' + assert len(bond_atoms[0]) == len(bond_atoms[1]), ( + 'Can not create bonds because the "count" of entity1 ' + f'({bond_info_dict.get("left_entity", bond_info_dict.get("entity1"))}) ' + f'and entity2 ({bond_info_dict.get("right_entity", bond_info_dict.get("entity2"))}) are not equal. 
' + ) # Create bond between each asym chain pair for atom_idx1, atom_idx2 in zip(bond_atoms[0], bond_atoms[1]): diff --git a/protenix/protenix/data/msa_featurizer.py b/protenix/protenix/data/msa_featurizer.py index c2df9d3b4b4627ea5104c129c6fcd3a59eacd9b4..284d8758b43b9abc7c32cff80bd61a6544fab118 100644 --- a/protenix/protenix/data/msa_featurizer.py +++ b/protenix/protenix/data/msa_featurizer.py @@ -915,7 +915,6 @@ def merge_all_chain_features( ) if msa_entity_type == "rna": np_example = rna_merge( - is_homomer_or_monomer=is_homomer_or_monomer, all_chain_features=all_chain_features, merge_method=merge_method, msa_crop_size=max_size, diff --git a/protenix/protenix/data/parser.py b/protenix/protenix/data/parser.py index a64047d450c32e08c164b8b70d9a5cf4fd028a23..682b89222efa027c6890c83a926252c954dd8f9c 100644 --- a/protenix/protenix/data/parser.py +++ b/protenix/protenix/data/parser.py @@ -16,9 +16,11 @@ import copy import functools import gzip import logging -from collections import Counter +import random +from collections import Counter, defaultdict +from datetime import datetime from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union import biotite.structure as struc import biotite.structure.io.pdbx as pdbx @@ -31,13 +33,24 @@ from biotite.structure.molecules import get_molecule_indices from protenix.data import ccd from protenix.data.ccd import get_ccd_ref_info from protenix.data.constants import ( + CRYSTALLIZATION_METHODS, DNA_STD_RESIDUES, + GLYCANS, + LIGAND_EXCLUSION, + PRO_STD_RESIDUES, PROT_STD_RESIDUES_ONE_TO_THREE, RES_ATOMS_DICT, RNA_STD_RESIDUES, STD_RESIDUES, ) -from protenix.data.utils import get_starts_by +from protenix.data.filter import Filter +from protenix.data.utils import ( + atom_select, + get_inter_residue_bonds, + get_ligand_polymer_bond_mask, + get_starts_by, + parse_pdb_cluster_file_to_dict, +) logger = logging.getLogger(__name__) @@ -47,7 +60,11 @@ if "metalc" in pdbx_convert.PDBX_COVALENT_TYPES: # for reload class MMCIFParser: - def __init__(self, mmcif_file: Union[str, Path]) -> None: + """ + Parsing and extracting information from mmCIF files. + """ + + def __init__(self, mmcif_file: Union[str, Path]): self.cif = self._parse(mmcif_file=mmcif_file) def _parse(self, mmcif_file: Union[str, Path]) -> pdbx.CIFFile: @@ -61,12 +78,120 @@ class MMCIFParser: return cif_file def get_category_table(self, name: str) -> Union[pd.DataFrame, None]: + """ + Retrieve a category table from the CIF block and return it as a pandas DataFrame. + + Args: + name (str): The name of the category to retrieve from the CIF block. + + Returns: + Union[pd.DataFrame, None]: A pandas DataFrame containing the category data if the category exists, + otherwise None. + """ if name not in self.cif.block: return None category = self.cif.block[name] category_dict = {k: column.as_array() for k, column in category.items()} return pd.DataFrame(category_dict, dtype=str) + @functools.cached_property + def pdb_id(self) -> str: + """ + Extracts and returns the PDB ID from the CIF block. + + Returns: + str: The PDB ID in lowercase if present, otherwise an empty string. + """ + + if "entry" not in self.cif.block: + return "" + else: + return self.cif.block["entry"]["id"].as_item().lower() + + def num_assembly_polymer_chains(self, assembly_id: str = "1") -> int: + """ + Calculate the number of polymer chains in a specified assembly. + + Args: + assembly_id (str): The ID of the assembly to count polymer chains for. + Defaults to "1". 
If "all", counts chains for all assemblies. + + Returns: + int: The total number of polymer chains in the specified assembly. + If the oligomeric count is invalid (e.g., '?'), the function returns None. + """ + chain_count = 0 + for _assembly_id, _chain_count in zip( + self.cif.block["pdbx_struct_assembly"]["id"].as_array(), + self.cif.block["pdbx_struct_assembly"]["oligomeric_count"].as_array(), + ): + if assembly_id == "all" or _assembly_id == assembly_id: + try: + chain_count += int(_chain_count) + except ValueError: + # oligomeric_count == '?'. e.g. 1hya.cif + return + return chain_count + + @functools.cached_property + def resolution(self) -> float: + """ + Get resolution for X-ray and cryoEM. + Some methods don't have resolution, set as -1.0 + + Returns: + float: resolution (set to -1.0 if not found) + """ + block = self.cif.block + resolution_names = [ + "refine.ls_d_res_high", + "em_3d_reconstruction.resolution", + "reflns.d_resolution_high", + ] + for category_item in resolution_names: + category, item = category_item.split(".") + if category in block and item in block[category]: + try: + resolution = block[category][item].as_array(float)[0] + # "." will be converted to 0.0, but it is not a valid resolution. + if resolution == 0.0: + continue + return resolution + except ValueError: + # in some cases, resolution_str is "?" + continue + return -1.0 + + @functools.cached_property + def release_date(self) -> str: + """ + Get first release date. + + Returns: + str: yyyy-mm-dd + """ + + def _is_valid_date_format(date_string): + try: + datetime.strptime(date_string, "%Y-%m-%d") + return True + except ValueError: + return False + + if "pdbx_audit_revision_history" in self.cif.block: + history = self.cif.block["pdbx_audit_revision_history"] + # np.str_ is inherit from str, so return is str + date = history["revision_date"].as_array()[0] + else: + # no release date + date = "9999-12-31" + + valid_date = _is_valid_date_format(date) + assert ( + valid_date + ), f"Invalid date format: {date}, it should be yyyy-mm-dd format" + return date + @staticmethod def mse_to_met(atom_array: AtomArray) -> AtomArray: """ @@ -87,6 +212,42 @@ class MMCIFParser: atom_array.hetero[mse] = False return atom_array + @staticmethod + def fix_arginine(atom_array: AtomArray) -> AtomArray: + """ + Ref: AlphaFold3 SI chapter 2.1 + Arginine naming ambiguities are fixed (ensuring NH1 is always closer to CD than NH2). + + Args: + atom_array (AtomArray): Biotite AtomArray object. + + Returns: + AtomArray: Biotite AtomArray object after fix arginine . 
+ """ + + starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True) + for start_i, stop_i in zip(starts[:-1], starts[1:]): + if atom_array[start_i].res_name != "ARG": + continue + cd_idx, nh1_idx, nh2_idx = None, None, None + for idx in range(start_i, stop_i): + if atom_array.atom_name[idx] == "CD": + cd_idx = idx + if atom_array.atom_name[idx] == "NH1": + nh1_idx = idx + if atom_array.atom_name[idx] == "NH2": + nh2_idx = idx + if cd_idx and nh1_idx and nh2_idx: # all not None + cd_nh1 = atom_array.coord[nh1_idx] - atom_array.coord[cd_idx] + d2_cd_nh1 = np.sum(cd_nh1**2) + cd_nh2 = atom_array.coord[nh2_idx] - atom_array.coord[cd_idx] + d2_cd_nh2 = np.sum(cd_nh2**2) + if d2_cd_nh2 < d2_cd_nh1: + atom_array.coord[[nh1_idx, nh2_idx]] = atom_array.coord[ + [nh2_idx, nh1_idx] + ] + return atom_array + @functools.cached_property def methods(self) -> list[str]: """the methods to get the structure @@ -146,7 +307,9 @@ class MMCIFParser: seq_nums = entity_poly_seq.num[chain_mask].to_numpy(dtype=int) - if np.unique(seq_nums).size == seq_nums.size: + uniq_seq_num = np.unique(seq_nums).size + + if uniq_seq_num == seq_nums.size: # no altloc residues poly_res_names[entity_id] = seq_mon_ids continue @@ -171,10 +334,12 @@ class MMCIFParser: if select_mask[i]: matching_res_id += 1 - seq_mon_ids = seq_mon_ids[select_mask] - seq_nums = seq_nums[select_mask] - assert len(seq_nums) == max(seq_nums) - poly_res_names[entity_id] = seq_mon_ids + new_seq_mon_ids = seq_mon_ids[select_mask] + new_seq_nums = seq_nums[select_mask] + assert ( + len(new_seq_nums) == uniq_seq_num + ), f"seq_nums not match:\n{seq_nums=}\n{new_seq_nums=}\n{seq_mon_ids=}\n{new_seq_mon_ids=}" + poly_res_names[entity_id] = new_seq_mon_ids return poly_res_names def get_sequences(self, atom_array=None) -> dict: @@ -217,34 +382,99 @@ class MMCIFParser: return {i: t for i, t in zip(entity_poly.entity_id, entity_poly.type)} - def filter_altloc(self, atom_array: AtomArray, altloc: str = "first"): + def filter_altloc(self, atom_array: AtomArray, altloc: str = "first") -> AtomArray: """ - altloc: "first", "A", "B", "global_largest", etc + Filter alternate conformations (altloc) of a given AtomArray based on the specified criteria. + For example, in 2PXS, there are two res_name (XYG|DYG) at res_id 63. - Filter first alternate coformation (altloc) of a given AtomArray. - - normally first altloc_id is 'A' - - but in one case, first altloc_id is '1' in 6uwi.cif + Args: + atom_array : AtomArray + The array of atoms to filter. + altloc : str, optional + The criteria for filtering alternate conformations. Possible values are: + - "first": Keep the first alternate conformation. + - "all": Keep all alternate conformations. + - "A", "B", etc.: Keep the specified alternate conformation. + - "global_largest": Keep the alternate conformation with the largest average occupancy. - biotite v0.41 can not handle diff res_name at same res_id. - For example, in 2pxs.cif, there are two res_name (XYG|DYG) at res_id 63, - need to keep the first XYG. + Returns: + AtomArray + The filtered AtomArray based on the specified altloc criteria. 
""" if altloc == "all": return atom_array - altloc_id = altloc - if altloc == "first": + elif altloc == "first": letter_altloc_ids = np.unique(atom_array.label_alt_id) if len(letter_altloc_ids) == 1 and letter_altloc_ids[0] == ".": return atom_array letter_altloc_ids = letter_altloc_ids[letter_altloc_ids != "."] altloc_id = np.sort(letter_altloc_ids)[0] + return atom_array[np.isin(atom_array.label_alt_id, [altloc_id, "."])] + + elif altloc == "global_largest": + occ_dict = defaultdict(list) + res_altloc = defaultdict(list) + + res_starts = get_residue_starts(atom_array, add_exclusive_stop=True) + for res_start, _res_end in zip(res_starts[:-1], res_starts[1:]): + altloc_char = atom_array.label_alt_id[res_start] + if altloc_char == ".": + continue + + occupency = atom_array.occupancy[res_start] + occ_dict[altloc_char].append(occupency) + + chain_id = atom_array.chain_id[res_start] + res_id = atom_array.res_id[res_start] + res_altloc[(chain_id, res_id)].append(altloc_char) + + alt_and_avg_occ = [ + (altloc_char, np.mean(occ_list)) + for altloc_char, occ_list in occ_dict.items() + ] + sorted_altloc_chars = [ + i[0] for i in sorted(alt_and_avg_occ, key=lambda x: x[1], reverse=True) + ] - return atom_array[np.isin(atom_array.label_alt_id, [altloc_id, "."])] + selected_mask = np.zeros(len(atom_array), dtype=bool) + for res_start, res_end in zip(res_starts[:-1], res_starts[1:]): + chain_id = atom_array.chain_id[res_start] + res_id = atom_array.res_id[res_start] + altloc_char = atom_array.label_alt_id[res_start] + + if altloc_char == ".": + selected_mask[res_start:res_end] = True + else: + res_sorted_altloc = [ + i + for i in sorted_altloc_chars + if i in res_altloc[(chain_id, res_id)] + ] + selected_altloc = res_sorted_altloc[0] + if altloc_char == selected_altloc: + selected_mask[res_start:res_end] = True + return atom_array[selected_mask] + + else: + return atom_array[np.isin(atom_array.label_alt_id, [altloc, "."])] @staticmethod def replace_auth_with_label(atom_array: AtomArray) -> AtomArray: - # fix issue https://github.com/biotite-dev/biotite/issues/553 + """ + Replace the author-provided chain ID with the label asym ID in the given AtomArray. + + This function addresses the issue described in https://github.com/biotite-dev/biotite/issues/553. + It updates the `chain_id` of the `atom_array` to match the `label_asym_id` and resets the ligand + residue IDs (`res_id`) for chains where the `label_seq_id` is ".". The residue IDs are reset + sequentially starting from 1 within each chain. + + Args: + atom_array (AtomArray): The input AtomArray object to be modified. + + Returns: + AtomArray: The modified AtomArray with updated chain IDs and residue IDs. + """ atom_array.chain_id = atom_array.label_asym_id # reset ligand res_id @@ -444,6 +674,1093 @@ class MMCIFParser: assembly.set_annotation("assembly_1", np.array(assembly_1_mask)) return assembly + def _get_core_indices(self, atom_array): + if "assembly_1" in atom_array._annot: + core_indices = np.where(atom_array.assembly_1)[0] + else: + core_indices = None + return core_indices + + def get_bioassembly( + self, + assembly_id: str = "1", + max_assembly_chains: int = 1000, + ) -> dict[str, Any]: + """ + Build the given biological assembly. + + Args: + assembly_id (str, optional): Assembly ID. Defaults to "1". + max_assembly_chains (int, optional): Max allowed chains in the assembly. Defaults to 1000. + + Returns: + dict[str, Any]: A dictionary containing basic Bioassembly information, including: + - "pdb_id": The PDB ID. 
+ - "sequences": The sequences associated with the assembly. + - "release_date": The release date of the structure. + - "assembly_id": The assembly ID. + - "num_assembly_polymer_chains": The number of polymer chains in the assembly. + - "num_prot_chains": The number of protein chains in the assembly. + - "entity_poly_type": The type of polymer entities. + - "resolution": The resolution of the structure. Set to -1.0 if resolution not found. + - "atom_array": The AtomArray object representing the structure. + - "num_tokens": The number of tokens in the AtomArray. + """ + num_assembly_polymer_chains = self.num_assembly_polymer_chains(assembly_id) + bioassembly_dict = { + "pdb_id": self.pdb_id, + "sequences": self.get_sequences(), # label_entity_id --> canonical_sequence + "release_date": self.release_date, + "assembly_id": assembly_id, + "num_assembly_polymer_chains": num_assembly_polymer_chains, + "num_prot_chains": -1, + "entity_poly_type": self.entity_poly_type, + "resolution": self.resolution, + "atom_array": None, + } + if (not num_assembly_polymer_chains) or ( + num_assembly_polymer_chains > max_assembly_chains + ): + return bioassembly_dict + + # created AtomArray of first model from mmcif atom_site (Asymmetric Unit) + atom_array = self.get_structure() + + # convert MSE to MET to consistent with MMCIFParser.get_poly_res_names() + atom_array = self.mse_to_met(atom_array) + + # update sequences: keep same altloc residue with atom_array + bioassembly_dict["sequences"] = self.get_sequences(atom_array) + + pipeline_functions = [ + Filter.remove_water, + Filter.remove_hydrogens, + lambda aa: Filter.remove_polymer_chains_all_residues_unknown( + aa, self.entity_poly_type + ), + # Note: Filter.remove_polymer_chains_too_short not being used + lambda aa: Filter.remove_polymer_chains_with_consecutive_c_alpha_too_far_away( + aa, self.entity_poly_type + ), + self.fix_arginine, + self.add_missing_atoms_and_residues, # and add annotation is_resolved (False for missing atoms) + Filter.remove_element_X, # remove X element (including ASX->ASP, GLX->GLU) after add_missing_atoms_and_residues() + ] + + if set(self.methods) & CRYSTALLIZATION_METHODS: + # AF3 SI 2.5.4 Crystallization aids are removed if the mmCIF method information indicates that crystallography was used. + pipeline_functions.append( + lambda aa: Filter.remove_crystallization_aids(aa, self.entity_poly_type) + ) + + for func in pipeline_functions: + atom_array = func(atom_array) + if len(atom_array) == 0: + # no atoms left + return bioassembly_dict + + atom_array = AddAtomArrayAnnot.add_token_mol_type( + atom_array, self.entity_poly_type + ) + atom_array = AddAtomArrayAnnot.add_centre_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_atom_mol_type_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_distogram_rep_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_plddt_m_rep_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_cano_seq_resname(atom_array) + atom_array = AddAtomArrayAnnot.add_tokatom_idx(atom_array) + atom_array = AddAtomArrayAnnot.add_modified_res_mask(atom_array) + assert ( + atom_array.centre_atom_mask.sum() + == atom_array.distogram_rep_atom_mask.sum() + ) + + # expand created AtomArray by expand bioassembly + atom_array = self.expand_assembly(atom_array, assembly_id) + + if len(atom_array) == 0: + # If no chains corresponding to the assembly_id remain in the AtomArray + # expand_assembly will return an empty AtomArray. 
+ return bioassembly_dict + + # reset the coords after expand assembly + atom_array.coord[~atom_array.is_resolved, :] = 0.0 + + # rename chain_ids from A A B to A0 A1 B0 and add asym_id_int, entity_id_int, sym_id_int + atom_array = AddAtomArrayAnnot.unique_chain_and_add_ids(atom_array) + + # get chain id before remove chains + core_indices = self._get_core_indices(atom_array) + if core_indices is not None: + ori_chain_ids = np.unique(atom_array.chain_id[core_indices]) + else: + ori_chain_ids = np.unique(atom_array.chain_id) + + atom_array = AddAtomArrayAnnot.add_mol_id(atom_array) + atom_array = Filter.remove_unresolved_mols(atom_array) + + # update core indices after remove unresolved mols + core_indices = np.where(np.isin(atom_array.chain_id, ori_chain_ids))[0] + + # If the number of chains has already reached `max_chains_num`, but the token count hasn't reached `max_tokens_num`, + # chains will continue to be added until `max_tokens_num` is exceeded. + atom_array, _input_chains_num = Filter.too_many_chains_filter( + atom_array, + core_indices=core_indices, + max_chains_num=20, + max_tokens_num=5120, + ) + + if atom_array is None: + # The distance between the central atoms in any two chains is greater than 15 angstroms. + return bioassembly_dict + + # update core indices after too_many_chains_filter + core_indices = np.where(np.isin(atom_array.chain_id, ori_chain_ids))[0] + + atom_array, _removed_chain_ids = Filter.remove_clashing_chains( + atom_array, core_indices=core_indices + ) + + # remove asymmetric polymer ligand bonds (including protein-protein bond, like disulfide bond) + # apply to assembly atom array + atom_array = Filter.remove_asymmetric_polymer_ligand_bonds( + atom_array, self.entity_poly_type + ) + + # add_mol_id before applying the two filters below to ensure that covalent components are not removed as individual chains. + atom_array = AddAtomArrayAnnot.find_equiv_mol_and_assign_ids( + atom_array, self.entity_poly_type + ) + + # numerical encoding of (chain id, residue index) + atom_array = AddAtomArrayAnnot.add_ref_space_uid(atom_array) + atom_array = AddAtomArrayAnnot.add_ref_info_and_res_perm(atom_array) + + # the number of protein chains in the assembly + prot_label_entity_ids = [ + k for k, v in self.entity_poly_type.items() if "polypeptide" in v + ] + num_prot_chains = len( + np.unique( + atom_array.chain_id[ + np.isin(atom_array.label_entity_id, prot_label_entity_ids) + ] + ) + ) + bioassembly_dict["num_prot_chains"] = num_prot_chains + + bioassembly_dict["atom_array"] = atom_array + bioassembly_dict["num_tokens"] = atom_array.centre_atom_mask.sum() + return bioassembly_dict + + @staticmethod + def create_empty_annotation_like( + source_array: AtomArray, target_array: AtomArray + ) -> AtomArray: + """create empty annotation like source_array""" + # create empty annotation, atom array addition only keep common annotation + for k, v in source_array._annot.items(): + if k not in target_array._annot: + target_array._annot[k] = np.zeros(len(target_array), dtype=v.dtype) + return target_array + + @staticmethod + def find_non_ccd_leaving_atoms( + atom_array: AtomArray, + select_dict: dict[str, Any], + component: AtomArray, + ) -> list[str]: + """ " + handle mismatch bettween CCD and mmcif + some residue has bond in non-central atom (without leaving atoms in CCD) + and its neighbors should be removed like atom_array from mmcif. + + Args: + atom_array (AtomArray): Biotite AtomArray object from mmcif. + select_dict dict[str, Any]: entity_id, res_id, atom_name,... 
of central atom in atom_array. + component (AtomArray): CCD component AtomArray object. + + Returns: + list[str]: list of atom_name to be removed. + """ + # find non-CCD central atoms in atom_array + indices_in_atom_array = atom_select(atom_array, select_dict) + + if len(indices_in_atom_array) == 0: + return [] + + if component.bonds is None: + return [] + + # atom_name not in CCD component, return [] + atom_name = select_dict["atom_name"] + idx_in_comp = np.where(component.atom_name == atom_name)[0] + if len(idx_in_comp) == 0: + return [] + idx_in_comp = idx_in_comp[0] + + # find non-CCD leaving atoms in atom_array + remove_atom_names = [] + for idx in indices_in_atom_array: + neighbor_idx, types = atom_array.bonds.get_bonds(idx) + ref_neighbor_idx, types = component.bonds.get_bonds(idx_in_comp) + # neighbor_atom only bond to central atom in CCD component + ref_neighbor_idx = [ + i for i in ref_neighbor_idx if len(component.bonds.get_bonds(i)[0]) == 1 + ] + removed_mask = ~np.isin( + component.atom_name[ref_neighbor_idx], + atom_array.atom_name[neighbor_idx], + ) + remove_atom_names.append( + component.atom_name[ref_neighbor_idx][removed_mask].tolist() + ) + max_id = np.argmax(map(len, remove_atom_names)) + return remove_atom_names[max_id] + + def build_ref_chain_with_atom_array(self, atom_array: AtomArray) -> AtomArray: + """ + build ref chain with atom_array and poly_res_names + """ + # count inter residue bonds of each potential central atom for removing leaving atoms later + central_bond_count = Counter() # (entity_id,res_id,atom_name) -> bond_count + + # build reference entity atom array, including missing residues + poly_res_names = self.get_poly_res_names(atom_array) + entity_atom_array = {} + for entity_id, poly_type in self.entity_poly_type.items(): + chain = struc.AtomArray(0) + for res_id, res_name in enumerate(poly_res_names[entity_id]): + # keep all leaving atoms, will remove leaving atoms later in this function + residue = ccd.get_component_atom_array( + res_name, keep_leaving_atoms=True, keep_hydrogens=False + ) + residue.res_id[:] = res_id + 1 + chain += residue + res_starts = struc.get_residue_starts(chain, add_exclusive_stop=True) + inter_bonds = ccd._connect_inter_residue(chain, res_starts) + + # filter out non-std polymer bonds + bond_mask = np.ones(len(inter_bonds._bonds), dtype=bool) + for b_idx, (atom_i, atom_j, b_type) in enumerate(inter_bonds._bonds): + idx_i = atom_select( + atom_array, + { + "label_entity_id": entity_id, + "res_id": chain.res_id[atom_i], + "atom_name": chain.atom_name[atom_i], + }, + ) + idx_j = atom_select( + atom_array, + { + "label_entity_id": entity_id, + "res_id": chain.res_id[atom_j], + "atom_name": chain.atom_name[atom_j], + }, + ) + for i in idx_i: + for j in idx_j: + # both i, j exist in same chain but not bond in atom_array, non-std polymer bonds, remove from chain + if atom_array.chain_id[i] == atom_array.chain_id[j]: + bonds, types = atom_array.bonds.get_bonds(i) + if j not in bonds: + bond_mask[b_idx] = False + break + + if bond_mask[b_idx]: + # keep this bond, add to central_bond_count + central_atom_idx = ( + atom_i if chain.atom_name[atom_i] in ("C", "P") else atom_j + ) + atom_key = ( + entity_id, + chain.res_id[central_atom_idx], + chain.atom_name[central_atom_idx], + ) + # use ref chain bond count if no inter bond in atom_array. 
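+                    # Initialised to 1 from the reference-chain polymer connectivity;
+                    # raised below (via max) to the largest inter-residue bond count
+                    # observed for this atom in any copy of the entity.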
+ central_bond_count[atom_key] = 1 + + inter_bonds._bonds = inter_bonds._bonds[bond_mask] + chain.bonds = chain.bonds.merge(inter_bonds) + + chain.hetero[:] = False + entity_atom_array[entity_id] = chain + + # remove leaving atoms of residues based on atom_array + + # count inter residue bonds from atom_array for removing leaving atoms later + inter_residue_bonds = get_inter_residue_bonds(atom_array) + for i in inter_residue_bonds.flat: + bonds, types = atom_array.bonds.get_bonds(i) + bond_count = ( + (atom_array.res_id[bonds] != atom_array.res_id[i]) + | (atom_array.chain_id[bonds] != atom_array.chain_id[i]) + ).sum() + atom_key = ( + atom_array.label_entity_id[i], + atom_array.res_id[i], + atom_array.atom_name[i], + ) + # remove leaving atoms if central atom has inter residue bond in any copy of a entity + central_bond_count[atom_key] = max(central_bond_count[atom_key], bond_count) + + # remove leaving atoms for each central atom based in atom_array info + # so the residue in reference chain can be used directly. + for entity_id, chain in entity_atom_array.items(): + keep_atom_mask = np.ones(len(chain), dtype=bool) + starts = struc.get_residue_starts(chain, add_exclusive_stop=True) + for start, stop in zip(starts[:-1], starts[1:]): + res_name = chain.res_name[start] + remove_atom_names = [] + for i in range(start, stop): + central_atom_name = chain.atom_name[i] + atom_key = (entity_id, chain.res_id[i], central_atom_name) + inter_bond_count = central_bond_count[atom_key] + + if inter_bond_count == 0: + continue + + # num of remove leaving groups equals to num of inter residue bonds (inter_bond_count) + component = ccd.get_component_atom_array( + res_name, keep_leaving_atoms=True + ) + + if component.central_to_leaving_groups is None: + # The leaving atoms might be labeled wrongly. The residue remains as it is. + break + + # central_to_leaving_groups:dict[str, list[list[str]]], central atom name to leaving atom groups (atom names). + if central_atom_name in component.central_to_leaving_groups: + leaving_groups = component.central_to_leaving_groups[ + central_atom_name + ] + # removed only when there are leaving atoms. 
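+                        # One leaving group is removed per inter-residue bond; groups with
+                        # no resolved atom in atom_array are removed first, so resolved
+                        # atoms are kept whenever possible.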
+ if inter_bond_count >= len(leaving_groups): + remove_groups = leaving_groups + else: + # subsample leaving atoms, keep resolved leaving atoms first + exist_group = [] + not_exist_group = [] + for group in leaving_groups: + for leaving_atom_name in group: + atom_idx = atom_select( + atom_array, + select_dict={ + "label_entity_id": entity_id, + "res_id": chain.res_id[i], + "atom_name": leaving_atom_name, + }, + ) + if len(atom_idx) > 0: # resolved + exist_group.append(group) + break + else: + not_exist_group.append(group) + if inter_bond_count <= len(not_exist_group): + remove_groups = random.sample( + not_exist_group, inter_bond_count + ) + else: + remove_groups = not_exist_group + random.sample( + exist_group, inter_bond_count - len(not_exist_group) + ) + names = [name for group in remove_groups for name in group] + remove_atom_names.extend(names) + + else: + # may has non-std leaving atom + non_std_leaving_atoms = self.find_non_ccd_leaving_atoms( + atom_array=atom_array, + select_dict={ + "label_entity_id": entity_id, + "res_id": chain.res_id[i], + "atom_name": chain.atom_name[i], + }, + component=component, + ) + if len(non_std_leaving_atoms) > 0: + remove_atom_names.extend(non_std_leaving_atoms) + + # remove leaving atoms of this residue + remove_mask = np.isin(chain.atom_name[start:stop], remove_atom_names) + keep_atom_mask[np.arange(start, stop)[remove_mask]] = False + + entity_atom_array[entity_id] = chain[keep_atom_mask] + return entity_atom_array + + @staticmethod + def make_new_residue( + atom_array, res_start, res_stop, ref_chain=None + ) -> tuple[AtomArray, dict[int, int]]: + """ + make new residue from atom_array[res_start:res_stop], ref_chain is the reference chain. + 1. only remove leavning atom when central atom covalent to other residue. + 2. if ref_chain is provided, remove all atoms not match the residue in ref_chain. + """ + res_id = atom_array.res_id[res_start] + res_name = atom_array.res_name[res_start] + ref_residue = ccd.get_component_atom_array( + res_name, + keep_leaving_atoms=True, + keep_hydrogens=False, + ) + if ref_residue is None: # only https://www.rcsb.org/ligand/UNL + return atom_array[res_start:res_stop] + + if ref_residue.central_to_leaving_groups is None: + # ambiguous: one leaving group bond to more than one central atom, keep same atoms with PDB entry. + return atom_array[res_start:res_stop] + + if ref_chain is not None: + return ref_chain[ref_chain.res_id == res_id] + + keep_atom_mask = np.ones(len(ref_residue), dtype=bool) + + # remove leavning atoms when covalent to other residue + for i in range(res_start, res_stop): + central_name = atom_array.atom_name[i] + old_atom_names = atom_array.atom_name[res_start:res_stop] + idx = np.where(old_atom_names == central_name)[0] + if len(idx) == 0: + # central atom is not resolved in atom_array, not remove leaving atoms + continue + idx = idx[0] + res_start + bonds, types = atom_array.bonds.get_bonds(idx) + bond_count = (res_id != atom_array.res_id[bonds]).sum() + if bond_count == 0: + # central atom is not covalent to other residue, not remove leaving atoms + continue + + if central_name in ref_residue.central_to_leaving_groups: + leaving_groups = ref_residue.central_to_leaving_groups[central_name] + # removed only when there are leaving atoms. 
+ if bond_count >= len(leaving_groups): + remove_groups = leaving_groups + else: + # subsample leaving atoms, remove unresolved leaving atoms first + exist_group = [] + not_exist_group = [] + for group in leaving_groups: + for leaving_atom_name in group: + atom_idx = atom_select( + atom_array, + select_dict={ + "chain_id": atom_array.chain_id[i], + "res_id": atom_array.res_id[i], + "atom_name": leaving_atom_name, + }, + ) + if len(atom_idx) > 0: # resolved + exist_group.append(group) + break + else: + not_exist_group.append(group) + + # not remove leaving atoms of B and BE, if all leaving atoms is exist in atom_array + if central_name in ["B", "BE"]: + if not not_exist_group: + continue + + if bond_count <= len(not_exist_group): + remove_groups = random.sample(not_exist_group, bond_count) + else: + remove_groups = not_exist_group + random.sample( + exist_group, bond_count - len(not_exist_group) + ) + else: + leaving_atoms = MMCIFParser.find_non_ccd_leaving_atoms( + atom_array=atom_array, + select_dict={ + "chain_id": atom_array.chain_id[i], + "res_id": atom_array.res_id[i], + "atom_name": atom_array.atom_name[i], + }, + component=ref_residue, + ) + remove_groups = [leaving_atoms] + + names = [name for group in remove_groups for name in group] + remove_mask = np.isin(ref_residue.atom_name, names) + keep_atom_mask &= ~remove_mask + + return ref_residue[keep_atom_mask] + + def add_missing_atoms_and_residues(self, atom_array: AtomArray) -> AtomArray: + """add missing atoms and residues based on CCD and mmcif info. + + Args: + atom_array (AtomArray): structure with missing residues and atoms, from PDB entry. + + Returns: + AtomArray: structure added missing residues and atoms (label atom_array.is_resolved as False). + """ + # build reference entity atom array, including missing residues + entity_atom_array = self.build_ref_chain_with_atom_array(atom_array) + + # build new atom array and copy info from input atom array to it (new_array). + new_array = None + new_global_start = 0 + o2n_amap = {} # old to new atom map + chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True) + res_starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True) + for c_start, c_stop in zip(chain_starts[:-1], chain_starts[1:]): + # get reference chain atom array + entity_id = atom_array.label_entity_id[c_start] + has_ref_chain = False + if entity_id in entity_atom_array: + has_ref_chain = True + ref_chain_array = entity_atom_array[entity_id].copy() + ref_chain_array = self.create_empty_annotation_like( + atom_array, ref_chain_array + ) + + chain_array = None + c_res_starts = res_starts[(c_start <= res_starts) & (res_starts <= c_stop)] + + # add missing residues + prev_res_id = 0 + for r_start, r_stop in zip(c_res_starts[:-1], c_res_starts[1:]): + curr_res_id = atom_array.res_id[r_start] + if has_ref_chain and curr_res_id - prev_res_id > 1: + # missing residue in head or middle, res_id is 1-based int. 
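+                    # e.g. if the previous observed residue is 3 and the current one is 7,
+                    # residues 4-6 are copied from the reference chain; their atoms never
+                    # enter o2n_amap and are therefore marked is_resolved=False later.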
+ segment = ref_chain_array[ + (prev_res_id < ref_chain_array.res_id) + & (ref_chain_array.res_id < curr_res_id) + ] + if chain_array is None: + chain_array = segment + else: + chain_array += segment + + new_global_start = 0 if new_array is None else len(new_array) + new_global_start += 0 if chain_array is None else len(chain_array) + + # add missing atoms of existing residue + ref_chain = ref_chain_array if has_ref_chain else None + new_residue = self.make_new_residue( + atom_array, r_start, r_stop, ref_chain + ) + + new_residue = self.create_empty_annotation_like(atom_array, new_residue) + + # copy residue level info + residue_fields = ["res_id", "hetero", "label_seq_id", "auth_seq_id"] + for k in residue_fields: + v = atom_array._annot[k][r_start] + new_residue._annot[k][:] = v + + # make o2n_amap: old to new atom map + name_to_index_new = { + name: idx for idx, name in enumerate(new_residue.atom_name) + } + res_o2n_amap = {} + res_mismatch_idx = [] + for old_idx in range(r_start, r_stop): + old_name = atom_array.atom_name[old_idx] + if old_name not in name_to_index_new: + # AF3 SI 2.5.4 Filtering + # For residues or small molecules with CCD codes, atoms outside of the CCD code’s defined set of atom names are removed. + res_mismatch_idx.append(old_idx) + else: + new_idx = name_to_index_new[old_name] + res_o2n_amap[old_idx] = new_global_start + new_idx + if len(res_o2n_amap) > len(res_mismatch_idx): + # Match residues only if more than half of their resolved atoms are matched. + # e.g. 1gbt GBS shows 2/12 match, not add to o2n_amap, all atoms are marked as is_resolved=False. + o2n_amap.update(res_o2n_amap) + + if chain_array is None: + chain_array = new_residue + else: + chain_array += new_residue + + prev_res_id = curr_res_id + + # missing residue in tail + if has_ref_chain: + last_res_id = ref_chain_array.res_id[-1] + if last_res_id > curr_res_id: + chain_array += ref_chain_array[ref_chain_array.res_id > curr_res_id] + + # copy chain level info + chain_fields = [ + "chain_id", + "label_asym_id", + "label_entity_id", + "auth_asym_id", + # "asym_id_int", + # "entity_id_int", + # "sym_id_int", + ] + for k in chain_fields: + chain_array._annot[k][:] = atom_array._annot[k][c_start] + + if new_array is None: + new_array = chain_array + else: + new_array += chain_array + + # copy atom level info + old_idx = list(o2n_amap.keys()) + new_idx = list(o2n_amap.values()) + atom_fields = ["b_factor", "occupancy", "charge"] + for k in atom_fields: + if k not in atom_array._annot: + continue + new_array._annot[k][new_idx] = atom_array._annot[k][old_idx] + + # add is_resolved annotation + is_resolved = np.zeros(len(new_array), dtype=bool) + is_resolved[new_idx] = True + new_array.set_annotation("is_resolved", is_resolved) + + # copy coord + new_array.coord[:] = 0.0 + new_array.coord[new_idx] = atom_array.coord[old_idx] + # copy bonds + old_bonds = atom_array.bonds.as_array() # *n x 3* np.ndarray (i,j,bond_type) + + # some non-leaving atoms are not in the new_array for atom name mismatch, e.g. 
4msw TYF + # only keep bonds of matching atoms + old_bonds = old_bonds[ + np.isin(old_bonds[:, 0], old_idx) & np.isin(old_bonds[:, 1], old_idx) + ] + + old_bonds[:, 0] = [o2n_amap[i] for i in old_bonds[:, 0]] + old_bonds[:, 1] = [o2n_amap[i] for i in old_bonds[:, 1]] + new_bonds = struc.BondList(len(new_array), old_bonds) + if new_array.bonds is None: + new_array.bonds = new_bonds + else: + new_array.bonds = new_array.bonds.merge(new_bonds) + + # add peptide bonds and nucleic acid bonds based on CCD type + new_array = ccd.add_inter_residue_bonds( + new_array, exclude_struct_conn_pairs=True, remove_far_inter_chain_pairs=True + ) + return new_array + + def make_chain_indices( + self, atom_array: AtomArray, pdb_cluster_file: Union[str, Path] = None + ) -> list: + """ + Make chain indices. + + Args: + atom_array (AtomArray): Biotite AtomArray object. + pdb_cluster_file (Union[str, Path]): cluster info txt file. + """ + if pdb_cluster_file is None: + pdb_cluster_dict = {} + else: + pdb_cluster_dict = parse_pdb_cluster_file_to_dict(pdb_cluster_file) + poly_res_names = self.get_poly_res_names(atom_array) + starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True) + chain_indices_list = [] + + is_centre_atom_and_is_resolved = ( + atom_array.is_resolved & atom_array.centre_atom_mask.astype(bool) + ) + for start, stop in zip(starts[:-1], starts[1:]): + chain_id = atom_array.chain_id[start] + entity_id = atom_array.label_entity_id[start] + + # skip if centre atoms within a chain are all unresolved, e.g. 1zc8 + if ~np.any(is_centre_atom_and_is_resolved[start:stop]): + continue + + # AF3 SI 2.5.1 Weighted PDB dataset + entity_type = self.entity_poly_type.get(entity_id, "non-poly") + + res_names = poly_res_names.get(entity_id, None) + if res_names is None: + chain_atoms = atom_array[start:stop] + res_ids, res_names = struc.get_residues(chain_atoms) + + if "polypeptide" in entity_type: + mol_type = "prot" + sequence = ccd.res_names_to_sequence(res_names) + if len(sequence) < 10: + cluster_id = sequence + else: + pdb_entity = f"{self.pdb_id}_{entity_id}" + if pdb_entity in pdb_cluster_dict: + cluster_id, _ = pdb_cluster_dict[pdb_entity] + elif entity_type == "polypeptide(D)": + cluster_id = sequence + elif sequence == "X" * len(sequence): + chain_atoms = atom_array[start:stop] + res_ids, res_names = struc.get_residues(chain_atoms) + if np.all(res_names == "UNK"): + cluster_id = "poly_UNK" + else: + cluster_id = "_".join(res_names) + else: + cluster_id = "NotInClusterTxt" + + elif "ribonucleotide" in entity_type: + mol_type = "nuc" + cluster_id = ccd.res_names_to_sequence(res_names) + else: + mol_type = "ligand" + cluster_id = "_".join(res_names) + + chain_dict = { + "entity_id": entity_id, # str + "chain_id": chain_id, + "mol_type": mol_type, + "cluster_id": cluster_id, + } + chain_indices_list.append(chain_dict) + return chain_indices_list + + def make_interface_indices( + self, atom_array: AtomArray, chain_indices_list: list + ) -> list: + """make interface indices + As described in SI 2.5.1, interfaces defined as pairs of chains with minimum heavy atom + (i.e. 
non-hydrogen) separation less than 5 Ã… + Args: + atom_array (AtomArray): _description_ + chain_indices_list (List): _description_ + """ + + chain_indices_dict = {i["chain_id"]: i for i in chain_indices_list} + interface_indices_dict = {} + + cell_list = struc.CellList( + atom_array, cell_size=5, selection=atom_array.is_resolved + ) + for chain_i, chain_i_dict in chain_indices_dict.items(): + chain_mask = atom_array.chain_id == chain_i + coord = atom_array.coord[chain_mask & atom_array.is_resolved] + neighbors_indices_2d = cell_list.get_atoms( + coord, radius=5 + ) # shape:(n_coord, max_n_neighbors), padding with -1 + neighbors_indices = np.unique(neighbors_indices_2d) + neighbors_indices = neighbors_indices[neighbors_indices != -1] + + chain_j_list = np.unique(atom_array.chain_id[neighbors_indices]) + for chain_j in chain_j_list: + if chain_i == chain_j: + continue + + # skip if centre atoms within a chain are all unresolved, e.g. 1zc8 + if chain_j not in chain_indices_dict: + continue + + interface_id = "_".join(sorted([chain_i, chain_j])) + if interface_id in interface_indices_dict: + continue + chain_j_dict = chain_indices_dict[chain_j] + interface_dict = {} + # chain_id --> chain_1_id + # mol_type --> mol_1_type + # entity_id --> entity_1_id + # cluster_id --> cluster_1_id + interface_dict.update( + {k.replace("_", "_1_"): v for k, v in chain_i_dict.items()} + ) + interface_dict.update( + {k.replace("_", "_2_"): v for k, v in chain_j_dict.items()} + ) + interface_indices_dict[interface_id] = interface_dict + return list(interface_indices_dict.values()) + + @staticmethod + def add_sub_mol_type( + atom_array: AtomArray, + indices_dict: dict[str, Any], + ) -> dict[str, Any]: + """ + Add a "sub_mol_[i]_type" field to indices_dict. + It includes the following mol_types and sub_mol_types: + + prot + - prot + - glycosylation_prot + - modified_prot + + nuc + - dna + - rna + - modified_dna + - modified_rna + - dna_rna_hybrid + + ligand + - bonded_ligand + - non_bonded_ligand + + excluded_ligand + - excluded_ligand + + glycans + - glycans + + ions + - ions + + Args: + atom_array (AtomArray): Biotite AtomArray object of bioassembly. + indices_dict (dict[str, Any]): A dict of chain or interface indices info. + + Returns: + dict[str, Any]: A dict of chain or interface indices info with "sub_mol_[i]_type" field. 
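+
+        Example:
+            A chain with mol_1_type == "ligand" whose CCD code is in neither GLYCANS nor
+            LIGAND_EXCLUSION gets sub_mol_1_type == "bonded_ligand" if it is covalently
+            bonded to a polymer chain, and "non_bonded_ligand" otherwise.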
+ """ + polymer_lig_bonds = get_ligand_polymer_bond_mask(atom_array) + if len(polymer_lig_bonds) == 0: + lig_polymer_bond_chain_id = [] + else: + lig_polymer_bond_chain_id = atom_array.chain_id[ + np.unique(polymer_lig_bonds[:, :2]) + ] + + for i in ["1", "2"]: + if indices_dict[f"entity_{i}_id"] == "": + indices_dict[f"sub_mol_{i}_type"] = "" + continue + entity_type = indices_dict[f"mol_{i}_type"] + mol_id = atom_array.mol_id[ + atom_array.label_entity_id == indices_dict[f"entity_{i}_id"] + ][0] + mol_all_res_name = atom_array.res_name[atom_array.mol_id == mol_id] + chain_all_mol_type = atom_array.mol_type[ + atom_array.chain_id == indices_dict[f"chain_{i}_id"] + ] + chain_all_res_name = atom_array.res_name[ + atom_array.chain_id == indices_dict[f"chain_{i}_id"] + ] + + if entity_type == "ligand": + ccd_code = indices_dict[f"cluster_{i}_id"] + if ccd_code in GLYCANS: + indices_dict[f"sub_mol_{i}_type"] = "glycans" + + elif ccd_code in LIGAND_EXCLUSION: + indices_dict[f"sub_mol_{i}_type"] = "excluded_ligand" + + elif indices_dict[f"chain_{i}_id"] in lig_polymer_bond_chain_id: + indices_dict[f"sub_mol_{i}_type"] = "bonded_ligand" + else: + indices_dict[f"sub_mol_{i}_type"] = "non_bonded_ligand" + + elif entity_type == "prot": + # glycosylation + if np.any(np.isin([mol_all_res_name], list(GLYCANS))): + indices_dict[f"sub_mol_{i}_type"] = "glycosylation_prot" + + if ~np.all(np.isin(chain_all_res_name, list(PRO_STD_RESIDUES.keys()))): + indices_dict[f"sub_mol_{i}_type"] = "modified_prot" + + elif entity_type == "nuc": + if np.all(chain_all_mol_type == "dna"): + if np.any( + np.isin(chain_all_res_name, list(DNA_STD_RESIDUES.keys())) + ): + indices_dict[f"sub_mol_{i}_type"] = "dna" + else: + indices_dict[f"sub_mol_{i}_type"] = "modified_dna" + + elif np.all(chain_all_mol_type == "rna"): + if np.any( + np.isin(chain_all_res_name, list(RNA_STD_RESIDUES.keys())) + ): + indices_dict[f"sub_mol_{i}_type"] = "rna" + else: + indices_dict[f"sub_mol_{i}_type"] = "modified_rna" + else: + indices_dict[f"sub_mol_{i}_type"] = "dna_rna_hybrid" + + else: + indices_dict[f"sub_mol_{i}_type"] = [f"mol_{i}_type"] + + if indices_dict.get(f"sub_mol_{i}_type") is None: + indices_dict[f"sub_mol_{i}_type"] = indices_dict[f"mol_{i}_type"] + return indices_dict + + @staticmethod + def add_eval_type(indices_dict: dict[str, Any]) -> dict[str, Any]: + """ + Differentiate DNA and RNA from the nucleus. + + Args: + indices_dict (dict[str, Any]): A dict of chain or interface indices info. + + Returns: + dict[str, Any]: A dict of chain or interface indices info with "eval_type" field. 
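+
+        Example:
+            An interface with mol_type_group == "nuc_prot", sub_mol_1_type == "rna" and
+            sub_mol_2_type == "prot" gets eval_type == "rna_prot"; an intra_nuc chain
+            with sub_mol_1_type == "dna" gets "intra_dna".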
+ """ + if indices_dict["mol_type_group"] not in ["intra_nuc", "nuc_prot"]: + eval_type = indices_dict["mol_type_group"] + elif "dna_rna_hybrid" in [ + indices_dict["sub_mol_1_type"], + indices_dict["sub_mol_2_type"], + ]: + eval_type = indices_dict["mol_type_group"] + else: + if indices_dict["mol_type_group"] == "intra_nuc": + nuc_type = str(indices_dict["sub_mol_1_type"]).split("_")[-1] + eval_type = f"intra_{nuc_type}" + else: + nuc_type1 = str(indices_dict["sub_mol_1_type"]).split("_")[-1] + nuc_type2 = str(indices_dict["sub_mol_2_type"]).split("_")[-1] + if "dna" in [nuc_type1, nuc_type2]: + eval_type = "dna_prot" + else: + eval_type = "rna_prot" + indices_dict["eval_type"] = eval_type + return indices_dict + + def make_indices( + self, + bioassembly_dict: dict[str, Any], + pdb_cluster_file: Union[str, Path] = None, + ) -> list: + """generate indices of chains and interfaces for sampling data + + Args: + bioassembly_dict (dict): dict from MMCIFParser.get_bioassembly(). + cluster_file (str): PDB cluster file. Defaults to None. + Return: + List(Dict(str, str)): sample_indices_list + """ + atom_array = bioassembly_dict["atom_array"] + if atom_array is None: + print( + f"Warning: make_indices() input atom_array is None, return empty list (PDB Code:{bioassembly_dict['pdb_id']})" + ) + return [] + chain_indices_list = self.make_chain_indices(atom_array, pdb_cluster_file) + interface_indices_list = self.make_interface_indices( + atom_array, chain_indices_list + ) + meta_dict = { + "pdb_id": bioassembly_dict["pdb_id"], + "assembly_id": bioassembly_dict["assembly_id"], + "release_date": self.release_date, + "num_tokens": bioassembly_dict["num_tokens"], + "num_prot_chains": bioassembly_dict["num_prot_chains"], + "resolution": self.resolution, + } + sample_indices_list = [] + for chain_dict in chain_indices_list: + chain_dict_out = {k.replace("_", "_1_"): v for k, v in chain_dict.items()} + chain_dict_out.update( + {k.replace("_", "_2_"): "" for k, v in chain_dict.items()} + ) + chain_dict_out["cluster_id"] = chain_dict["cluster_id"] + chain_dict_out.update(meta_dict) + chain_dict_out["type"] = "chain" + sample_indices_list.append(chain_dict_out) + + for interface_dict in interface_indices_list: + cluster_ids = [ + interface_dict["cluster_1_id"], + interface_dict["cluster_2_id"], + ] + interface_dict["cluster_id"] = ":".join(sorted(cluster_ids)) + interface_dict.update(meta_dict) + interface_dict["type"] = "interface" + sample_indices_list.append(interface_dict) + + for indices in sample_indices_list: + for i in ["1", "2"]: + chain_id = indices[f"chain_{i}_id"] + if chain_id == "": + continue + chain_atom_num = np.sum([atom_array.chain_id == chain_id]) + if chain_atom_num == 1: + indices[f"mol_{i}_type"] = "ions" + + if indices["type"] == "chain": + indices["mol_type_group"] = f'intra_{indices["mol_1_type"]}' + else: + indices["mol_type_group"] = "_".join( + sorted([indices["mol_1_type"], indices["mol_2_type"]]) + ) + indices = self.add_sub_mol_type(atom_array, indices) + indices = self.add_eval_type(indices) + return sample_indices_list + + +class DistillationMMCIFParser(MMCIFParser): + + def get_structure_dict(self) -> dict[str, Any]: + """ + Get an AtomArray from a CIF file of distillation data. + + Returns: + Dict[str, Any]: a dict of asymmetric unit structure info. 
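+            The dict holds "pdb_id", "atom_array", "assembly_id" (always None here),
+            "sequences", "entity_poly_type", "num_tokens" and "num_prot_chains".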
+ """ + # created AtomArray of first model from mmcif atom_site (Asymmetric Unit) + atom_array = self.get_structure() + + # convert MSE to MET to consistent with MMCIFParser.get_poly_res_names() + atom_array = self.mse_to_met(atom_array) + + structure_dict = { + "pdb_id": self.pdb_id, + "atom_array": None, + "assembly_id": None, + "sequences": self.get_sequences(atom_array), + "entity_poly_type": self.entity_poly_type, + "num_tokens": -1, + "num_prot_chains": -1, + } + + pipeline_functions = [ + self.fix_arginine, + self.add_missing_atoms_and_residues, # add UNK + ] + + for func in pipeline_functions: + atom_array = func(atom_array) + if len(atom_array) == 0: + # no atoms left + return structure_dict + + atom_array = AddAtomArrayAnnot.add_token_mol_type( + atom_array, self.entity_poly_type + ) + atom_array = AddAtomArrayAnnot.add_centre_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_atom_mol_type_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_distogram_rep_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_plddt_m_rep_atom_mask(atom_array) + atom_array = AddAtomArrayAnnot.add_cano_seq_resname(atom_array) + atom_array = AddAtomArrayAnnot.add_tokatom_idx(atom_array) + atom_array = AddAtomArrayAnnot.add_modified_res_mask(atom_array) + assert ( + atom_array.centre_atom_mask.sum() + == atom_array.distogram_rep_atom_mask.sum() + ) + + # rename chain_ids from A A B to A0 A1 B0 and add asym_id_int, entity_id_int, sym_id_int + atom_array = AddAtomArrayAnnot.unique_chain_and_add_ids(atom_array) + atom_array = AddAtomArrayAnnot.find_equiv_mol_and_assign_ids( + atom_array, self.entity_poly_type + ) + + # numerical encoding of (chain id, residue index) + atom_array = AddAtomArrayAnnot.add_ref_space_uid(atom_array) + atom_array = AddAtomArrayAnnot.add_ref_info_and_res_perm(atom_array) + + # the number of protein chains in the structure + prot_label_entity_ids = [ + k for k, v in self.entity_poly_type.items() if "polypeptide" in v + ] + num_prot_chains = len( + np.unique( + atom_array.chain_id[ + np.isin(atom_array.label_entity_id, prot_label_entity_ids) + ] + ) + ) + structure_dict["num_prot_chains"] = num_prot_chains + structure_dict["atom_array"] = atom_array + structure_dict["num_tokens"] = atom_array.centre_atom_mask.sum() + return structure_dict + class AddAtomArrayAnnot(object): """ @@ -951,7 +2268,7 @@ class AddAtomArrayAnnot(object): """ Unique chain ID and add asym_id, entity_id, sym_id. Adds a number to the chain ID to make chain IDs in the assembly unique. - Example: [A, B, A, B, C] ==> [A0, B0, A1, B1, C0] + Example: [A, B, A, B, C] -> [A, B, A.1, B.1, C] Args: atom_array (AtomArray): Biotite AtomArray object. 
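The new chain-ID scheme in the docstring above keeps the first copy of a chain ID unchanged and suffixes later copies with `.1`, `.2`, ...; the hunk below implements this with a `Counter` over chain starts. A minimal standalone sketch of the same idea, using plain numpy arrays instead of a biotite AtomArray (names are illustrative only):

```python
import numpy as np
from collections import Counter

def rename_duplicate_chains(chain_ids: np.ndarray, chain_starts: np.ndarray) -> np.ndarray:
    """Suffix repeated chain IDs with .1, .2, ...; the first copy keeps its original name."""
    new_ids = np.empty(len(chain_ids), dtype="<U8")
    counter = Counter()
    for start, stop in zip(chain_starts[:-1], chain_starts[1:]):
        ori = chain_ids[start]
        cnt = counter[ori]
        new_ids[start:stop] = ori if cnt == 0 else f"{ori}.{cnt}"
        counter[ori] += 1
    return new_ids

# per-atom chain IDs for five single-atom chains A, B, A, B, C
chain_ids = np.array(["A", "B", "A", "B", "C"])
chain_starts = np.array([0, 1, 2, 3, 4, 5])  # start indices plus exclusive stop
print(rename_duplicate_chains(chain_ids, chain_starts).tolist())
# ['A', 'B', 'A.1', 'B.1', 'C']
```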
@@ -962,49 +2279,26 @@ class AddAtomArrayAnnot(object): - entity_id_int: np.array(int) - sym_id_int: np.array(int) """ - entity_id_uniq = np.sort(np.unique(atom_array.label_entity_id)) - entity_id_dict = {e: i for i, e in enumerate(entity_id_uniq)} - asym_ids = np.zeros(len(atom_array), dtype=int) - entity_ids = np.zeros(len(atom_array), dtype=int) - sym_ids = np.zeros(len(atom_array), dtype=int) - chain_ids = np.zeros(len(atom_array), dtype="U4") - counter = Counter() - start_indices = struc.get_chain_starts(atom_array, add_exclusive_stop=True) - for i in range(len(start_indices) - 1): - start_i = start_indices[i] - stop_i = start_indices[i + 1] - asym_ids[start_i:stop_i] = i - - entity_id = atom_array.label_entity_id[start_i] - entity_ids[start_i:stop_i] = entity_id_dict[entity_id] + chain_ids = np.zeros(len(atom_array), dtype="<U8") + chain_starts = get_chain_starts(atom_array, add_exclusive_stop=True) - sym_ids[start_i:stop_i] = counter[entity_id] - counter[entity_id] += 1 - new_chain_id = f"{atom_array.chain_id[start_i]}{sym_ids[start_i]}" - chain_ids[start_i:stop_i] = new_chain_id + chain_counter = Counter() + for start, stop in zip(chain_starts[:-1], chain_starts[1:]): + ori_chain_id = atom_array.chain_id[start] + cnt = chain_counter[ori_chain_id] + if cnt == 0: + new_chain_id = ori_chain_id + else: + new_chain_id = f"{ori_chain_id}.{chain_counter[ori_chain_id]}" - atom_array.set_annotation("asym_id_int", asym_ids) - atom_array.set_annotation("entity_id_int", entity_ids) - atom_array.set_annotation("sym_id_int", sym_ids) - atom_array.chain_id = chain_ids - return atom_array + chain_ids[start:stop] = new_chain_id + chain_counter[ori_chain_id] += 1 - @staticmethod - def add_int_id(atom_array): - """ - Unique chain ID and add asym_id, entity_id, sym_id. - Adds a number to the chain ID to make chain IDs in the assembly unique. - Example: [A, B, A, B, C] ==> [A0, B0, A1, B1, C0] + assert "" not in chain_ids + # reset chain id + atom_array.del_annotation("chain_id") + atom_array.set_annotation("chain_id", chain_ids) - Args: - atom_array (AtomArray): Biotite AtomArray object. 
- - Returns: - AtomArray: Biotite AtomArray object with new annotations: - - asym_id_int: np.array(int) - - entity_id_int: np.array(int) - - sym_id_int: np.array(int) - """ entity_id_uniq = np.sort(np.unique(atom_array.label_entity_id)) entity_id_dict = {e: i for i, e in enumerate(entity_id_uniq)} asym_ids = np.zeros(len(atom_array), dtype=int) diff --git a/protenix/protenix/data/utils.py b/protenix/protenix/data/utils.py index 18f98aa2d76605c1099b5d00f6e530d50ca55bb1..b484e704dcbe75bb291d11cf49464868f1b9f370 100644 --- a/protenix/protenix/data/utils.py +++ b/protenix/protenix/data/utils.py @@ -14,6 +14,7 @@ import argparse import copy +import functools import os import re from collections import defaultdict @@ -59,6 +60,27 @@ def int_to_letters(n: int) -> str: return result +def get_inter_residue_bonds(atom_array: AtomArray) -> np.ndarray: + """get inter residue bonds by checking chain_id and res_id + + Args: + atom_array (AtomArray): Biotite AtomArray, must have chain_id and res_id + + Returns: + np.ndarray: inter residue bonds, shape = (n,2) + """ + if atom_array.bonds is None: + return [] + idx_i = atom_array.bonds._bonds[:, 0] + idx_j = atom_array.bonds._bonds[:, 1] + chain_id_diff = atom_array.chain_id[idx_i] != atom_array.chain_id[idx_j] + res_id_diff = atom_array.res_id[idx_i] != atom_array.res_id[idx_j] + diff_mask = chain_id_diff | res_id_diff + inter_residue_bonds = atom_array.bonds._bonds[diff_mask] + inter_residue_bonds = inter_residue_bonds[:, :2] # remove bond type + return inter_residue_bonds + + def get_starts_by( atom_array: AtomArray, by_annot: str, add_exclusive_stop=False ) -> np.ndarray: @@ -88,6 +110,26 @@ def get_starts_by( return np.concatenate(([0], starts)) +def atom_select(atom_array: AtomArray, select_dict: dict, as_mask=False) -> np.ndarray: + """return index of atom_array that match select_dict + + Args: + atom_array (AtomArray): Biotite AtomArray + select_dict (dict): select dict, eg: {'element': 'C'} + as_mask (bool, optional): return mask of atom_array. Defaults to False. + + Returns: + np.ndarray: index of atom_array that match select_dict + """ + mask = np.ones(len(atom_array), dtype=bool) + for k, v in select_dict.items(): + mask = mask & (getattr(atom_array, k) == v) + if as_mask: + return mask + else: + return np.where(mask)[0] + + def get_ligand_polymer_bond_mask( atom_array: AtomArray, lig_include_ions=False ) -> np.ndarray: @@ -140,6 +182,38 @@ def get_ligand_polymer_bond_mask( return lig_polymer_bonds +@functools.lru_cache +def parse_pdb_cluster_file_to_dict( + cluster_file: str, remove_uniprot: bool = True +) -> dict[str, tuple]: + """parse PDB cluster file, and return a pandas dataframe + example cluster file: + https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt + + Args: + cluster_file (str): cluster_file path + Returns: + dict(str, tuple(str, str)): {pdb_id}_{entity_id} --> [cluster_id, cluster_size] + """ + pdb_cluster_dict = {} + with open(cluster_file) as f: + for line in f: + pdb_clusters = [] + for ids in line.strip().split(): + if remove_uniprot: + if ids.startswith("AF_") or ids.startswith("MA_"): + continue + pdb_clusters.append(ids) + cluster_size = len(pdb_clusters) + if cluster_size == 0: + continue + # use first member as cluster id. + cluster_id = f"pdb_cluster_{pdb_clusters[0]}" + for ids in pdb_clusters: + pdb_cluster_dict[ids.lower()] = (cluster_id, cluster_size) + return pdb_cluster_dict + + def get_clean_data(atom_array: AtomArray) -> AtomArray: """ Removes unresolved atoms from the AtomArray. 
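The helpers added to utils.py above are used throughout the parser; a short usage sketch for `atom_select`, assuming a hand-built three-atom biotite `AtomArray` whose annotation values are made up for illustration:

```python
import numpy as np
import biotite.structure as struc
from protenix.data.utils import atom_select

# minimal AtomArray: two atoms in entity "1" / chain A and one atom in entity "2" / chain B
arr = struc.AtomArray(3)
arr.chain_id = np.array(["A", "A", "B"])
arr.res_id = np.array([1, 1, 1])
arr.res_name = np.array(["ALA", "ALA", "GLY"])
arr.atom_name = np.array(["N", "CA", "CA"])
arr.element = np.array(["N", "C", "C"])
arr.set_annotation("label_entity_id", np.array(["1", "1", "2"]))

# indices of CA atoms that belong to entity "1"
print(atom_select(arr, {"label_entity_id": "1", "atom_name": "CA"}))  # [1]
# the same query returned as a boolean mask
print(atom_select(arr, {"label_entity_id": "1", "atom_name": "CA"}, as_mask=True))
# [False  True False]
```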
@@ -230,6 +304,21 @@ class CIFWriter: self.atom_array = atom_array self.entity_poly_type = entity_poly_type + def _get_entity_block(self): + if self.entity_poly_type is None: + return {} + entity_ids_in_atom_array = np.sort(np.unique(self.atom_array.label_entity_id)) + entity_block_dict = defaultdict(list) + for entity_id in entity_ids_in_atom_array: + if entity_id not in self.entity_poly_type: + entity_type = "non-polymer" + else: + entity_type = "polymer" + entity_block_dict["id"].append(entity_id) + entity_block_dict["pdbx_description"].append(".") + entity_block_dict["type"].append(entity_type) + return pdbx.CIFCategory(entity_block_dict) + def _get_entity_poly_and_entity_poly_seq_block(self): entity_poly = defaultdict(list) for entity_id, entity_type in self.entity_poly_type.items(): @@ -247,6 +336,9 @@ class CIFWriter: entity_poly["entity_id"].append(entity_id) entity_poly["pdbx_strand_id"].append(label_asym_ids_str) entity_poly["type"].append(entity_type) + + if not entity_poly: + return {} entity_poly_seq = defaultdict(list) for entity_id, label_asym_ids_str in zip( @@ -305,6 +397,7 @@ class CIFWriter: block_dict = {"entry": pdbx.CIFCategory({"id": entry_id})} if self.entity_poly_type: + block_dict["entity"] = self._get_entity_block() block_dict.update(self._get_entity_poly_and_entity_poly_seq_block()) block = pdbx.CIFBlock(block_dict) @@ -317,6 +410,11 @@ class CIFWriter: pdbx.set_structure(cif, self.atom_array, include_bonds=include_bonds) block = cif.block atom_site = block.get("atom_site") + + occ = atom_site.get("occupancy") + if occ is None: + atom_site["occupancy"] = np.ones(len(self.atom_array), dtype=float) + atom_site["label_entity_id"] = self.atom_array.label_entity_id cif.write(output_path) @@ -678,7 +776,11 @@ def pdb_to_cif(input_fname: str, output_fname: str, entry_id: str = None): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--pdb_file", type=str, required=True, help="The pdb file to parse") - parser.add_argument("--cif_file", type=str, required=True, help="The cif file path to generate") + parser.add_argument( + "--pdb_file", type=str, required=True, help="The pdb file to parse" + ) + parser.add_argument( + "--cif_file", type=str, required=True, help="The cif file path to generate" + ) args = parser.parse_args() - pdb_to_cif(args.pdb_file, args.cif_file) \ No newline at end of file + pdb_to_cif(args.pdb_file, args.cif_file) diff --git a/protenix/protenix/model/modules/pairformer.py b/protenix/protenix/model/modules/pairformer.py index ad40f4b683c4e5699eb4e81ad5cc69b4cc3294b6..4ee3054887081798b39d5bef134eea880fa97e74 100644 --- a/protenix/protenix/model/modules/pairformer.py +++ b/protenix/protenix/model/modules/pairformer.py @@ -141,7 +141,7 @@ class PairformerBlock(nn.Module): z = z.transpose(-2, -3).contiguous() z += self.tri_att_end( z, - mask=pair_mask.tranpose(-1, -2) if pair_mask is not None else None, + mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None, use_memory_efficient_kernel=use_memory_efficient_kernel, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_lma=use_lma, @@ -184,7 +184,7 @@ class PairformerBlock(nn.Module): z = z + self.dropout_row( self.tri_att_end( z, - mask=pair_mask.tranpose(-1, -2) if pair_mask is not None else None, + mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None, use_memory_efficient_kernel=use_memory_efficient_kernel, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_lma=use_lma, diff --git 
a/protenix/protenix/openfold_local/model/primitives.py b/protenix/protenix/openfold_local/model/primitives.py index 1d0fb68d34cd1acf5cd85e7197884ab2c9bdd1f6..f1ad20a0756a2c069600087b8168208c95bfb199 100644 --- a/protenix/protenix/openfold_local/model/primitives.py +++ b/protenix/protenix/openfold_local/model/primitives.py @@ -544,13 +544,6 @@ class Attention(nn.Module): if use_memory_efficient_kernel: raise Exception(f"use_memory_efficient_kernel=True not supported!!!") - if len(biases) > 2: - raise ValueError( - "If use_memory_efficient_kernel is True, you may only " - "provide up to two bias terms" - ) - o = attention_core(q, k, v, *((biases + [None] * 2)[:2])) - o = o.transpose(-2, -3) elif use_deepspeed_evo_attention: if len(biases) > 2: raise ValueError( diff --git a/protenix/protenix/openfold_local/np/residue_constants.py b/protenix/protenix/openfold_local/np/residue_constants.py index dfafab7381a4ebc5d80a9084dff19a4a7c4c9eeb..9aff5ffa69ef0de297ca9a3529d693d2b80ecf39 100644 --- a/protenix/protenix/openfold_local/np/residue_constants.py +++ b/protenix/protenix/openfold_local/np/residue_constants.py @@ -22,7 +22,7 @@ from importlib import resources from typing import List, Mapping, Tuple import numpy as np -import tree +import optree # Distance from one CA to next CA [trans configuration: omega = 180]. ca_ca = 3.80209737096 @@ -1072,7 +1072,7 @@ chi_atom_2_one_hot = chi_angle_atom(2) # An array like chi_angles_atoms but using indices rather than names. chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] -chi_angles_atom_indices = tree.map_structure( +chi_angles_atom_indices = optree.tree_map( lambda atom_name: atom_order[atom_name], chi_angles_atom_indices ) chi_angles_atom_indices = np.array( diff --git a/protenix/protenix/openfold_local/utils/checkpointing.py b/protenix/protenix/openfold_local/utils/checkpointing.py index 149fdec69a5f57a5db1ea96e87f26e49198ddd41..7ba7dfb51e06925f8da3b376f58109ffee04016f 100644 --- a/protenix/protenix/openfold_local/utils/checkpointing.py +++ b/protenix/protenix/openfold_local/utils/checkpointing.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import partial import importlib from typing import Any, List, Callable, Optional @@ -33,7 +34,7 @@ def get_checkpoint_fn(): if deepspeed_is_configured: checkpoint = deepspeed.checkpointing.checkpoint else: - checkpoint = torch.utils.checkpoint.checkpoint + checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) return checkpoint diff --git a/protenix/protenix/openfold_local/utils/chunk_utils.py b/protenix/protenix/openfold_local/utils/chunk_utils.py index 3a7dc885961c3347280ff078314a2fe6338a508f..777ab68a4c126a8b0eedf521e163f8ad98961b76 100644 --- a/protenix/protenix/openfold_local/utils/chunk_utils.py +++ b/protenix/protenix/openfold_local/utils/chunk_utils.py @@ -14,28 +14,15 @@ import logging import math from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple +import optree import torch - from protenix.openfold_local.utils.tensor_utils import tensor_tree_map, tree_map def _fetch_dims(tree): - shapes = [] - tree_type = type(tree) - if tree_type is dict: - for v in tree.values(): - shapes.extend(_fetch_dims(v)) - elif tree_type is list or tree_type is tuple: - for t in tree: - shapes.extend(_fetch_dims(t)) - elif tree_type is torch.Tensor: - shapes.append(tree.shape) - else: - raise ValueError("Not supported") - - return shapes + return optree.tree_flatten(optree.tree_map(lambda x: x.shape, tree))[0] @torch.jit.ignore diff --git a/protenix/protenix/web_service/colab_request_parser.py b/protenix/protenix/web_service/colab_request_parser.py index a7cc4f7f359c5f92937d9945b96607bf59fb832d..937ba9df13e59e12c0666e2eaf04bba030982c3e 100644 --- a/protenix/protenix/web_service/colab_request_parser.py +++ b/protenix/protenix/web_service/colab_request_parser.py @@ -32,7 +32,9 @@ from protenix.data.json_to_feature import SampleDictToFeatures from protenix.web_service.colab_request_utils import run_mmseqs2_service from protenix.web_service.dependency_url import URL -MMSEQS_SERVICE_HOST_URL = os.getenv("MMSEQS_SERVICE_HOST_URL", "http://101.126.11.40:80") +MMSEQS_SERVICE_HOST_URL = os.getenv( + "MMSEQS_SERVICE_HOST_URL", "https://protenix-server.com/api/msa" +) MAX_ATOM_NUM = 60000 MAX_TOKEN_NUM = 5000 DATA_CACHE_DIR = "/af3-dev/release_data/" @@ -76,7 +78,9 @@ class TooLargeComplexError(Exception): class RequestParser(object): - def __init__(self, request_json_path: str, request_dir: str, email: str = "") -> None: + def __init__( + self, request_json_path: str, request_dir: str, email: str = "" + ) -> None: with open(request_json_path, "r") as f: self.request = json.load(f) self.request_dir = request_dir @@ -207,7 +211,10 @@ class RequestParser(object): @staticmethod def msa_search( - seqs_pending_msa: Sequence[str], tmp_fasta_fpath: str, msa_res_dir: str, email: str = "" + seqs_pending_msa: Sequence[str], + tmp_fasta_fpath: str, + msa_res_dir: str, + email: str = "", ) -> None: lines = [] for idx, seq in enumerate(seqs_pending_msa): @@ -229,7 +236,7 @@ class RequestParser(object): use_templates=False, host_url=MMSEQS_SERVICE_HOST_URL, user_agent="colabfold/1.5.5", - email=email + email=email, ) except Exception as e: error_message = f"MMSEQS2 failed with the following error message:\n{traceback.format_exc()}" @@ -250,13 +257,20 @@ class RequestParser(object): def read_a3m(a3m_file: str) -> Tuple[List[str], List[str]]: heads = [] seqs = [] + # Record the row index. 
The index before this index is the MSA of Uniref30 DB, + # and the index after this index is the MSA of ColabfoldDB. + uniref_index = 0 with open(a3m_file, "r") as infile: - for line in infile: + for idx, line in enumerate(infile): if line.startswith(">"): heads.append(line) + if idx == 0: + query_name = line + elif idx > 0 and line == query_name: + uniref_index = idx else: seqs.append(line) - return heads, seqs + return heads, seqs, uniref_index def make_pairing_and_non_pairing_msa( query_seq: str, @@ -265,18 +279,22 @@ class RequestParser(object): uniref_to_ncbi_taxid: Mapping[str, str], ) -> List[str]: - heads, msa_seqs = read_a3m(raw_a3m_path) + heads, msa_seqs, uniref_index = read_a3m(raw_a3m_path) uniref100_lines = [">query\n", f"{query_seq}\n"] other_lines = [">query\n", f"{query_seq}\n"] - for head, msa_seq in zip(heads, msa_seqs): + for idx, (head, msa_seq) in enumerate(zip(heads, msa_seqs)): if msa_seq.rstrip("\n") == query_seq: continue - if "UniRef" in head: - uniref_id = head.split("\t")[0][1:] - ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id, None) - if ncbi_taxid is not None: + uniref_id = head.split("\t")[0][1:] + ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id, None) + if (ncbi_taxid is not None) and (idx < (uniref_index // 2)): + if not uniref_id.startswith("UniRef100_"): + head = head.replace( + uniref_id, f"UniRef100_{uniref_id}_{ncbi_taxid}/" + ) + else: head = head.replace(uniref_id, f"{uniref_id}_{ncbi_taxid}/") uniref100_lines.extend([head, msa_seq]) else: @@ -294,7 +312,7 @@ class RequestParser(object): seq_dir: str, raw_a3m_path: str, ): - heads, msa_seqs = read_a3m(raw_a3m_path) + heads, msa_seqs, _ = read_a3m(raw_a3m_path) other_lines = [">query\n", f"{query_seq}\n"] for head, msa_seq in zip(heads, msa_seqs): if msa_seq.rstrip("\n") == query_seq: @@ -390,9 +408,18 @@ class RequestParser(object): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--request_json_path", type=str, required=True, help="Path to the request JSON file.") - parser.add_argument("--request_dir", type=str, required=True, help="Path to the request directory.") - parser.add_argument("--email", type=str, required=False, default="", help="Your email address.") + parser.add_argument( + "--request_json_path", + type=str, + required=True, + help="Path to the request JSON file.", + ) + parser.add_argument( + "--request_dir", type=str, required=True, help="Path to the request directory." + ) + parser.add_argument( + "--email", type=str, required=False, default="", help="Your email address." 
+ ) args = parser.parse_args() parser = RequestParser( diff --git a/protenix/protenix/web_service/colab_request_utils.py b/protenix/protenix/web_service/colab_request_utils.py index 5090bbcafe34c62f0faacf00cb79425a382dc12e..fd798a4e5aed3f0becf581c6125313392602a59f 100644 --- a/protenix/protenix/web_service/colab_request_utils.py +++ b/protenix/protenix/web_service/colab_request_utils.py @@ -14,16 +14,14 @@ import logging import os -import random import tarfile import time from typing import List, Tuple +import requests from requests.auth import HTTPBasicAuth from tqdm import tqdm -import requests - TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} estimate remaining: {remaining}]" logger = logging.getLogger(__name__) @@ -42,7 +40,7 @@ def run_mmseqs2_service( pairing_strategy="greedy", host_url="https://api.colabfold.com", user_agent: str = "", - email: str = "" + email: str = "", ) -> Tuple[List[str], List[str]]: submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" diff --git a/protenix/protenix/web_service/prediction_visualization.py b/protenix/protenix/web_service/prediction_visualization.py index c24e0b6cc2557813e2cd8d34feb10fd53566e185..37a3ab517d0b9d3091cc60826248aa0913743829 100644 --- a/protenix/protenix/web_service/prediction_visualization.py +++ b/protenix/protenix/web_service/prediction_visualization.py @@ -282,63 +282,73 @@ def plot_best_confidence_measure(prediction_fpath: str, output_dir: str = None): ): print("no full confidence data found, skip ploting.") return - full_confidence = pred_loader.full_confidence_data[:1] - summary_confidence = pred_loader.summary_confidence_data[:1] - atom_plddts = [d["atom_plddt"] for d in full_confidence] - token_pdes = [d["token_pair_pde"] for d in full_confidence] - token_paes = [d["token_pair_pae"] for d in full_confidence] - summary_keys = ["plddt", "gpde", "ptm", "iptm"] - - fig, axes = plt.subplots( - nrows=3, - ncols=1, - figsize=(10, 20), - gridspec_kw={"height_ratios": [1.5, 2, 2]}, - ) - for i in range(3): - summary_text = ", ".join( - [ - f"{k}: {v:.4f}" - for k, v in summary_confidence[0].items() - if k in summary_keys - ] + data_len = len(pred_loader.full_confidence_data) + for data_idx in range(data_len): + full_confidence = pred_loader.full_confidence_data[data_idx : data_idx + 1] + summary_confidence = pred_loader.summary_confidence_data[ + data_idx : data_idx + 1 + ] + atom_plddts = [d["atom_plddt"] for d in full_confidence] + token_pdes = [d["token_pair_pde"] for d in full_confidence] + token_paes = [d["token_pair_pae"] for d in full_confidence] + summary_keys = ["plddt", "gpde", "ptm", "iptm"] + + fig, axes = plt.subplots( + nrows=3, + ncols=1, + figsize=(10, 20), + gridspec_kw={"height_ratios": [1.5, 2, 2]}, ) - if i == 0: - axes[i].plot(atom_plddts[0], color="k") - axes[i].text( - 0.5, - 1.07, - summary_text, - ha="center", - va="center", - transform=axes[i].transAxes, - fontsize=10, + for i in range(3): + summary_text = ", ".join( + [ + f"{k}: {v:.4f}" + for k, v in summary_confidence[0].items() + if k in summary_keys + ] ) - axes[i].set_xlabel("Atom ID", fontsize=12) - axes[i].set_ylabel("pLDDT", fontsize=12) - axes[i].set_ylim([0, 1]) - axes[i].spines[["right", "top"]].set_visible(False) + if i == 0: + axes[i].plot(atom_plddts[0], color="k") + axes[i].text( + 0.5, + 1.07, + summary_text, + ha="center", + va="center", + transform=axes[i].transAxes, + fontsize=10, + ) + axes[i].set_xlabel("Atom ID", fontsize=12) + axes[i].set_ylabel("pLDDT", fontsize=12) + axes[i].set_ylim([0, 1]) 
+ axes[i].spines[["right", "top"]].set_visible(False) - else: - data_to_plot = token_pdes[0] if i == 1 else token_paes[0] - cax = axes[i].matshow(data_to_plot, origin="lower") - axes[i].xaxis.tick_bottom() - axes[i].set_aspect("equal") - axes[i].set_xlabel("Scored Residue", fontsize=12) - axes[i].set_ylabel("Aligned Residue", fontsize=12) - axes[i].xaxis.set_major_locator(MaxNLocator(3)) # Max 5 ticks on the x-axis - axes[i].yaxis.set_major_locator(MaxNLocator(3)) # Max 5 ticks on the y-axis - color_bar = fig.colorbar( - cax, ax=axes[i], orientation="vertical", pad=0.1, shrink=0.6 - ) - cbar_label = ( - "Predicted Distance Error" if i == 1 else "Predicted Aligned Error" - ) - color_bar.set_label(cbar_label) + else: + data_to_plot = token_pdes[0] if i == 1 else token_paes[0] + cax = axes[i].matshow(data_to_plot, origin="lower") + axes[i].xaxis.tick_bottom() + axes[i].set_aspect("equal") + axes[i].set_xlabel("Scored Residue", fontsize=12) + axes[i].set_ylabel("Aligned Residue", fontsize=12) + axes[i].xaxis.set_major_locator( + MaxNLocator(3) + ) # Max 5 ticks on the x-axis + axes[i].yaxis.set_major_locator( + MaxNLocator(3) + ) # Max 5 ticks on the y-axis + color_bar = fig.colorbar( + cax, ax=axes[i], orientation="vertical", pad=0.1, shrink=0.6 + ) + cbar_label = ( + "Predicted Distance Error" if i == 1 else "Predicted Aligned Error" + ) + color_bar.set_label(cbar_label) - out_png_fpath = os.path.join(output_dir, "best_sample_confidence.png") - plt.savefig(out_png_fpath) - print(f"save best sample confidence plot successfully to {out_png_fpath}") + out_png_fpath = os.path.join( + output_dir, f"best_sample_confidence_{data_idx}.png" + ) + plt.savefig(out_png_fpath) + print(f"save {data_len} best sample confidence plot successfully to {output_dir}") if __name__ == "__main__": diff --git a/protenix/runner/batch_inference.py b/protenix/runner/batch_inference.py index acebea432b90423bf5d1c41afd1de3b0acd3c88f..8dc16d2b83c8ad578b9dd979a258a0b36fb1205f 100644 --- a/protenix/runner/batch_inference.py +++ b/protenix/runner/batch_inference.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import json import logging import os @@ -23,18 +24,18 @@ from typing import List, Optional, Union import click import tqdm from Bio import SeqIO +from rdkit import Chem + from configs.configs_base import configs as configs_base from configs.configs_data import data_configs from configs.configs_inference import inference_configs -from rdkit import Chem -from runner.inference import InferenceRunner, download_infercence_cache, infer_predict -from runner.msa_search import contain_msa_res, msa_search, msa_search_update - from protenix.config import parse_configs from protenix.data.json_maker import cif_to_input_json from protenix.data.json_parser import lig_file_to_atom_info from protenix.data.utils import pdb_to_cif from protenix.utils.logger import get_logger +from runner.inference import InferenceRunner, download_infercence_cache, infer_predict +from runner.msa_search import msa_search, update_infer_json logger = get_logger(__name__) @@ -49,9 +50,7 @@ def init_logging(): ) -def generate_infer_jsons( - protein_msa_res: dict, ligand_file: str, seeds: List[int] = [101] -) -> List[str]: +def generate_infer_jsons(protein_msa_res: dict, ligand_file: str) -> List[str]: protein_chains = [] if len(protein_msa_res) <= 0: raise RuntimeError(f"invalid `protein_msa_res` data in {protein_msa_res}") @@ -135,7 +134,6 @@ def generate_infer_jsons( infer_json_files.append(json_file_name) for smi_ligand_file in smi_ligand_files: - one_infer_seq = protein_chains[:] with open(smi_ligand_file, "r") as f: smile_list = f.readlines() one_infer_seq = protein_chains[:] @@ -161,13 +159,18 @@ def generate_infer_jsons( return infer_json_files -def get_default_runner(seeds: Optional[list] = None) -> InferenceRunner: +def get_default_runner( + seeds: Optional[tuple] = None, + n_cycle: int = 10, + n_step: int = 200, + n_sample: int = 5, +) -> InferenceRunner: configs_base["use_deepspeed_evo_attention"] = ( - os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true" + os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true" ) - configs_base["model"]["N_cycle"] = 10 - configs_base["sample_diffusion"]["N_sample"] = 5 - configs_base["sample_diffusion"]["N_step"] = 200 + configs_base["model"]["N_cycle"] = n_cycle + configs_base["sample_diffusion"]["N_sample"] = n_sample + configs_base["sample_diffusion"]["N_step"] = n_step configs = {**configs_base, **{"data": data_configs}, **inference_configs} configs = parse_configs( configs=configs, @@ -183,7 +186,10 @@ def inference_jsons( json_file: str, out_dir: str = "./output", use_msa_server: bool = False, - seeds: list = [101], + seeds: tuple = (101,), + n_cycle: int = 10, + n_step: int = 200, + n_sample: int = 5, ) -> None: """ infer_json: json file or directory, will run infer with these jsons @@ -210,66 +216,13 @@ def inference_jsons( infer_errors = {} inference_configs["dump_dir"] = out_dir inference_configs["input_json_path"] = infer_jsons[0] - runner = get_default_runner(seeds) + runner = get_default_runner(seeds, n_cycle, n_step, n_sample) configs = runner.configs for idx, infer_json in enumerate(tqdm.tqdm(infer_jsons)): try: - if use_msa_server: - infer_json = msa_search_update( - infer_json, os.path.join(out_dir, f"msa_res_{idx}") - ) - elif not contain_msa_res(infer_json): - raise RuntimeError(f"can not find msa for {infer_json}") - configs["input_json_path"] = infer_json - if not contain_msa_res(infer_json): - raise RuntimeError( - f"`{infer_json}` has no msa result for `proteinChain`, please add first." 
- ) - infer_predict(runner, configs) - except Exception as exc: - infer_errors[infer_json] = str(exc) - if len(infer_errors) > 0: - logger.warning(f"run inference failed: {infer_errors}") - - -def batch_inference( - protein_msa_res: dict, - ligand_file: str, - out_dir: str = "./output", - seeds: List[int] = [101], -) -> None: - """ - ligand_file: ligand file or directory, should be in sdf format or smi with smlies list; - protein_msa_res: the msa result for `protein`, like: - { "MGHHHHHHHHHHSSGH": { - "precomputed_msa_dir": "/path/to/msa_pairing/result/msa/1", - "pairing_db": "uniref100" - }, - "MAEVIRSSAFWRSFPIFEEFDSE": { - "precomputed_msa_dir": "/path/to/msa_pairing/result/msa/2", - "pairing_db": "uniref100" - } - } - out_dir: the infer outout dir, default is `./output` - """ - - infer_jsons = generate_infer_jsons(protein_msa_res, ligand_file, seeds) - logger.info(f"will infer with {len(infer_jsons)} jsons") - if len(infer_jsons) == 0: - return - - infer_errors = {} - inference_configs["dump_dir"] = out_dir - inference_configs["input_json_path"] = infer_jsons[0] - runner = get_default_runner(seeds=seeds) - configs = runner.configs - for infer_json in tqdm.tqdm(infer_jsons): - try: - configs["input_json_path"] = infer_json - if not contain_msa_res(infer_json): - raise RuntimeError( - f"`{infer_json}` has no msa result for `proteinChain`, please add first." - ) + configs["input_json_path"] = update_infer_json( + infer_json, out_dir=out_dir, use_msa_server=use_msa_server + ) infer_predict(runner, configs) except Exception as exc: infer_errors[infer_json] = str(exc) @@ -288,8 +241,11 @@ def protenix_cli(): @click.option( "--seeds", type=str, default="101", help="the inference seed, split by comma" ) +@click.option("--cycle", type=int, default=10, help="pairformer cycle number") +@click.option("--step", type=int, default=200, help="diffusion step") +@click.option("--sample", type=int, default=5, help="sample number") @click.option("--use_msa_server", is_flag=True, help="do msa search or not") -def predict(input, out_dir, seeds, use_msa_server): +def predict(input, out_dir, seeds, cycle, step, sample, use_msa_server): """ predict: Run predictions with protenix. 
:param input, out_dir, use_msa_server @@ -297,10 +253,18 @@ def predict(input, out_dir, seeds, use_msa_server): """ init_logging() logger.info( - f"run infer with input={input}, out_dir={out_dir}, use_msa_server={use_msa_server}" + f"run infer with input={input}, out_dir={out_dir}, cycle={cycle}, step={step}, sample={sample}, use_msa_server={use_msa_server}" ) seeds = list(map(int, seeds.split(","))) - inference_jsons(input, out_dir, use_msa_server, seeds=seeds) + inference_jsons( + input, + out_dir, + use_msa_server, + seeds=seeds, + n_cycle=cycle, + n_step=step, + n_sample=sample, + ) @click.command() @@ -395,11 +359,9 @@ def msa(input, out_dir) -> Union[str, dict]: :return: """ init_logging() - out_dir = os.path.join(out_dir, uuid.uuid4().hex) - os.makedirs(out_dir, exist_ok=True) logger.info(f"run msa with input={input}, out_dir={out_dir}") if input.endswith(".json"): - msa_input_json = msa_search_update(input, out_dir) + msa_input_json = update_infer_json(input, out_dir, use_msa_server=True) logger.info(f"msa results have been update to {msa_input_json}") return msa_input_json elif input.endswith(".fasta"): @@ -424,22 +386,5 @@ protenix_cli.add_command(tojson) protenix_cli.add_command(msa) -def test_batch_inference(): - ligands_dir = "../examples/ligands" - protein_msa_res = { - "MASWSHPQFEKGGTHVAETSAPTRSEPDTRVLTLPGTASAPEFRLIDIDGLLNNRATTDVRDLGSGRLNAWGNSFPAAELPAPGSLITVAGIPFTWANAHARGDNIRCEGQVVDIPPGQYDWIYLLAASERRSEDTIWAHYDDGHADPLRVGISDFLDGTPAFGELSAFRTSRMHYPHHVQEGLPTTMWLTRVGMPRHGVARSLRLPRSVAMHVFALTLRTAAAVRLAEGATT": { - "precomputed_msa_dir": "../examples/7wux/msa/1", - "pairing_db": "uniref100", - }, - "MGSSHHHHHHSQDPNSTTTAPPVELWTRDLGSCLHGTLATALIRDGHDPVTVLGAPWEFRRRPGAWSSEEYFFFAEPDSLAGRLALYHPFESTWHRSDGDGVDDLREALAAGVLPIAAVDNFHLPFRPAFHDVHAAHLLVVYRITETEVYVSDAQPPAFQGAIPLADFLASWGSLNPPDDADVFFSASPSGRRWLRTRMTGPVPEPDRHWVGRVIRENVARYRQEPPADTQTGLPGLRRYLDELCALTPGTNAASEALSELYVISWNIQAQSGLHAEFLRAHSVKWRIPELAEAAAGVDAVAHGWTGVRMTGAHSRVWQRHRPAELRGHATALVRRLEAALDLLELAADAVS": { - "precomputed_msa_dir": "../examples/7wux/msa/2", - "pairing_db": "uniref100", - }, - } - out_dir = "./infer_output" - batch_inference(protein_msa_res, ligands_dir, out_dir=out_dir) - - if __name__ == "__main__": - init_logging() - test_batch_inference() + predict() diff --git a/protenix/runner/dumper.py b/protenix/runner/dumper.py index 06ac2e57381c8857971f978d301f79f2fdb868cc..e6e42c2878319a2054755db7d1954f4529dc8fc3 100644 --- a/protenix/runner/dumper.py +++ b/protenix/runner/dumper.py @@ -44,9 +44,15 @@ def get_clean_full_confidence(full_confidence_dict: dict) -> dict: class DataDumper: - def __init__(self, base_dir, need_atom_confidence: bool = False): + def __init__( + self, + base_dir, + need_atom_confidence: bool = False, + sorted_by_ranking_score: bool = True, + ) -> None: self.base_dir = base_dir self.need_atom_confidence = need_atom_confidence + self.sorted_by_ranking_score = sorted_by_ranking_score def dump( self, @@ -156,7 +162,7 @@ class DataDumper: N_sample = pred_coordinates.shape[0] if sorted_indices is None: sorted_indices = range(N_sample) # do not rank the output file - for rank, idx in enumerate(sorted_indices): + for idx, rank in enumerate(sorted_indices): output_fpath = os.path.join( prediction_save_dir, f"{sample_name}_seed_{seed}_sample_{rank}.cif", @@ -173,15 +179,20 @@ class DataDumper: pdb_id=sample_name, ) - def _get_ranker_indices(self, data: dict, sorted_by_ranking_score: bool = True): + def _get_ranker_indices(self, data: dict): N_sample = len(data["summary_confidence"]) - sorted_indices = range(N_sample) - 
if sorted_by_ranking_score: - sorted_indices = sorted( - range(N_sample), - key=lambda i: data["summary_confidence"][i]["ranking_score"], - reverse=True, + if self.sorted_by_ranking_score: + value = torch.tensor( + [ + data["summary_confidence"][i]["ranking_score"] + for i in range(N_sample) + ] ) + sorted_indices = [ + i for i in torch.argsort(torch.argsort(value, descending=True)) + ] + else: + sorted_indices = [i for i in range(N_sample)] return sorted_indices def _save_confidence( @@ -200,7 +211,7 @@ class DataDumper: ) if sorted_indices is None: sorted_indices = range(N_sample) - for rank, idx in enumerate(sorted_indices): + for idx, rank in enumerate(sorted_indices): output_fpath = os.path.join( prediction_save_dir, f"{sample_name}_seed_{seed}_summary_confidence_sample_{rank}.json", diff --git a/protenix/runner/inference.py b/protenix/runner/inference.py index 3c77b7eb2f6268ff9ffcca5d74ccc9e67be1e203..5b3e1b172126c4a26d9581cbef31e9f24005f119 100644 --- a/protenix/runner/inference.py +++ b/protenix/runner/inference.py @@ -23,10 +23,11 @@ from typing import Any, Mapping import torch import torch.distributed as dist - from configs.configs_base import configs as configs_base from configs.configs_data import data_configs from configs.configs_inference import inference_configs +from runner.dumper import DataDumper + from protenix.config import parse_configs, parse_sys_args from protenix.data.infer_data_pipeline import get_inference_dataloader from protenix.model.protenix import Protenix @@ -34,8 +35,6 @@ from protenix.utils.distributed import DIST_WRAPPER from protenix.utils.seed import seed_everything from protenix.utils.torch_utils import to_device from protenix.web_service.dependency_url import URL -from runner.dumper import DataDumper -from runner.msa_search import contain_msa_res, msa_search_update logger = logging.getLogger(__name__) @@ -47,7 +46,10 @@ class InferenceRunner(object): self.init_basics() self.init_model() self.load_checkpoint() - self.init_dumper(need_atom_confidence=configs.need_atom_confidence) + self.init_dumper( + need_atom_confidence=configs.need_atom_confidence, + sorted_by_ranking_score=configs.sorted_by_ranking_score, + ) def init_env(self) -> None: self.print( @@ -112,14 +114,18 @@ class InferenceRunner(object): } self.model.load_state_dict( state_dict=checkpoint["model"], - strict=True, + strict=self.configs.load_strict, ) self.model.eval() self.print(f"Finish loading checkpoint.") - def init_dumper(self, need_atom_confidence: bool = False): + def init_dumper( + self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True + ): self.dumper = DataDumper( - base_dir=self.dump_dir, need_atom_confidence=need_atom_confidence + base_dir=self.dump_dir, + need_atom_confidence=need_atom_confidence, + sorted_by_ranking_score=sorted_by_ranking_score, ) # Adapted from runner.train.Trainer.evaluate @@ -157,30 +163,29 @@ class InferenceRunner(object): def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> None: - current_file_path = os.path.abspath(__file__) - current_directory = os.path.dirname(current_file_path) - code_directory = os.path.dirname(current_directory) - - data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache") - os.makedirs(data_cache_dir, exist_ok=True) - for cache_name, fname in [ - ("ccd_components_file", "components.v20240608.cif"), - ("ccd_components_rdkit_mol_file", "components.v20240608.cif.rdkit_mol.pkl"), - ]: - if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))): + + 
for cache_name in ("ccd_components_file", "ccd_components_rdkit_mol_file"): + cur_cache_fpath = configs["data"][cache_name] + if not opexists(cur_cache_fpath): + os.makedirs(os.path.dirname(cur_cache_fpath), exist_ok=True) tos_url = URL[cache_name] - logger.info(f"Downloading data cache from\n {tos_url}...") - urllib.request.urlretrieve(tos_url, cache_path) + assert os.path.basename(tos_url) == os.path.basename(cur_cache_fpath), ( + f"{cache_name} file name is incorrect, `{tos_url}` and " + f"`{cur_cache_fpath}`. Please check and try again." + ) + logger.info( + f"Downloading data cache from\n {tos_url}... to {cur_cache_fpath}" + ) + urllib.request.urlretrieve(tos_url, cur_cache_fpath) checkpoint_path = configs.load_checkpoint_path if not opexists(checkpoint_path): - checkpoint_path = os.path.join( - code_directory, f"release_data/checkpoint/model_{model_version}.pt" - ) os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) tos_url = URL[f"model_{model_version}"] - logger.info(f"Downloading model checkpoint from\n {tos_url}...") + logger.info( + f"Downloading model checkpoint from\n {tos_url}... to {checkpoint_path}" + ) urllib.request.urlretrieve(tos_url, checkpoint_path) try: ckpt = torch.load(checkpoint_path) @@ -191,7 +196,6 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No "Download model checkpoint failed, please download by yourself with " f"wget {tos_url} -O {checkpoint_path}" ) - configs.load_checkpoint_path = checkpoint_path def update_inference_configs(configs: Any, N_token: int): @@ -211,38 +215,31 @@ def update_inference_configs(configs: Any, N_token: int): def infer_predict(runner: InferenceRunner, configs: Any) -> None: - # update msa result if not contains precomputed msa dir - if not contain_msa_res(configs.input_json_path): - logger.info( - f"{configs.input_json_path} dose not contain precomputed msa dir, now searching it." 
- ) - configs.input_json_path = msa_search_update( - configs.input_json_path, configs.dump_dir - ) - logger.info( - f"msa searching completed, new input json is {configs.input_json_path}" - ) # Data logger.info(f"Loading data from\n{configs.input_json_path}") - dataloader = get_inference_dataloader(configs=configs) + try: + dataloader = get_inference_dataloader(configs=configs) + except Exception as e: + error_message = f"{e}:\n{traceback.format_exc()}" + logger.info(error_message) + with open(opjoin(runner.error_dir, "error.txt"), "a") as f: + f.write(error_message) + return num_data = len(dataloader.dataset) for seed in configs.seeds: - seed_everything(seed=seed, deterministic=False) + seed_everything(seed=seed, deterministic=configs.deterministic) for batch in dataloader: try: data, atom_array, data_error_message = batch[0] + sample_name = data["sample_name"] if len(data_error_message) > 0: logger.info(data_error_message) - with open( - opjoin(runner.error_dir, f"{data['sample_name']}.txt"), - "w", - ) as f: + with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(data_error_message) continue - sample_name = data["sample_name"] logger.info( ( f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: " @@ -271,15 +268,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None: error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}" logger.info(error_message) # Save error info - if opexists( - error_path := opjoin(runner.error_dir, f"{sample_name}.txt") - ): - os.remove(error_path) - with open(error_path, "w") as f: + with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f: f.write(error_message) if hasattr(torch.cuda, "empty_cache"): torch.cuda.empty_cache() - raise RuntimeError(f"run infer failed: {str(e)}") def main(configs: Any) -> None: @@ -297,7 +289,7 @@ def run() -> None: filemode="w", ) configs_base["use_deepspeed_evo_attention"] = ( - os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true" + os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true" ) configs = {**configs_base, **{"data": data_configs}, **inference_configs} configs = parse_configs( diff --git a/protenix/runner/msa_search.py b/protenix/runner/msa_search.py index 4c1008c95d2871470aa11b71783a313bd7c7442f..d352a052dc7c2f29cb4389d430b197f96873cb97 100644 --- a/protenix/runner/msa_search.py +++ b/protenix/runner/msa_search.py @@ -22,33 +22,18 @@ from protenix.web_service.colab_request_parser import RequestParser logger = get_logger(__name__) -def contain_msa_res(json_file: str) -> bool: - """ - check the json_path data has msa result or not. 
- """ - if not os.path.exists(json_file): - raise RuntimeError(f"`{json_file}` not exists.") - with open(json_file, "r") as f: - json_data = json.load(f) - for seq in json_data: - for sequence in seq["sequences"]: - if "proteinChain" in sequence.keys(): - proteinChain = sequence["proteinChain"] - if "msa" not in proteinChain.keys() or len(proteinChain["msa"]) == 0: - return False - return True - - -def update_msa_res(seq: dict, protein_msa_res: dict) -> dict: - for sequence in seq["sequences"]: +def need_msa_search(json_data: dict) -> bool: + need_msa = json_data.get("use_msa", True) + # TODO: add esm check + if not need_msa: + return need_msa + need_msa = False + for sequence in json_data["sequences"]: if "proteinChain" in sequence.keys(): - sequence["proteinChain"]["msa"] = { - "precomputed_msa_dir": protein_msa_res[ - sequence["proteinChain"]["sequence"] - ], - "pairing_db": "uniref100", - } - return seq + proteinChain = sequence["proteinChain"] + if "msa" not in proteinChain.keys() or len(proteinChain["msa"]) == 0: + need_msa = True + return need_msa def msa_search(seqs: Sequence[str], msa_res_dir: str) -> Sequence[str]: @@ -69,32 +54,68 @@ def msa_search(seqs: Sequence[str], msa_res_dir: str) -> Sequence[str]: return msa_res_subdirs -def msa_search_update(json_file: str, out_dir: str) -> str: +def update_seq_msa(infer_seq: dict, msa_res_dir: str) -> dict: + protein_seqs = [] + for sequence in infer_seq["sequences"]: + if "proteinChain" in sequence.keys(): + protein_seqs.append(sequence["proteinChain"]["sequence"]) + if len(protein_seqs) > 0: + protein_seqs = sorted(protein_seqs) + msa_res_subdirs = msa_search(protein_seqs, msa_res_dir) + assert len(msa_res_subdirs) == len(msa_res_subdirs), "msa search failed" + protein_msa_res = dict(zip(protein_seqs, msa_res_subdirs)) + for sequence in infer_seq["sequences"]: + if "proteinChain" in sequence.keys(): + sequence["proteinChain"]["msa"] = { + "precomputed_msa_dir": protein_msa_res[ + sequence["proteinChain"]["sequence"] + ], + "pairing_db": "uniref100", + } + return infer_seq + + +def update_infer_json( + json_file: str, out_dir: str, use_msa_server: bool = False +) -> str: """ - do msa search with mmseqs from json input and update it. + update json file for inference. + for every infer_data, if it needs to update msa result info, + it will run msa searching if use_msa_server is True, + else it will raise error. + if it does not need to update msa result info, then pass. """ - assert os.path.exists(json_file), f"input file {json_file} not exists." 
- if contain_msa_res(json_file): - logger.warning(f"{json_file} has already msa result, skip.") - return json_file + if not os.path.exists(json_file): + raise RuntimeError(f"`{json_file}` not exists.") with open(json_file, "r") as f: - input_json_data = json.load(f) - for seq_idx, seq in enumerate(input_json_data): - protein_seqs = [] - for sequence in seq["sequences"]: - if "proteinChain" in sequence.keys(): - protein_seqs.append(sequence["proteinChain"]["sequence"]) - if len(protein_seqs) > 0: - protein_seqs = sorted(protein_seqs) - msa_res_subdirs = msa_search( - protein_seqs, os.path.join(out_dir, f"msa_seq_{seq_idx}") - ) - assert len(msa_res_subdirs) == len(msa_res_subdirs), "msa search failed" - update_msa_res(seq, dict(zip(protein_seqs, msa_res_subdirs))) - msa_input_json = os.path.join( - os.path.dirname(json_file), - f"{os.path.splitext(os.path.basename(json_file))[0]}-add-msa.json", - ) - with open(msa_input_json, "w") as f: - json.dump(input_json_data, f, indent=4) - return msa_input_json + json_data = json.load(f) + + actual_updated = False + for seq_idx, infer_data in enumerate(json_data): + if need_msa_search(infer_data): + actual_updated = True + if use_msa_server: + seq_name = infer_data.get("name", f"seq_{seq_idx}") + logger.info( + f"starting to update msa result for seq {seq_idx} in {json_file}" + ) + update_seq_msa( + infer_data, + os.path.join(out_dir, seq_name, "msa_res" f"msa_seq_{seq_idx}"), + ) + else: + raise RuntimeError( + f"infer seq {seq_idx} in `{json_file}` has no msa result, please add first." + ) + if actual_updated: + updated_json = os.path.join( + os.path.dirname(os.path.abspath(json_file)), + f"{os.path.splitext(os.path.basename(json_file))[0]}-add-msa.json", + ) + with open(updated_json, "w") as f: + json.dump(json_data, f, indent=4) + logger.info(f"update msa result success and save to {updated_json}") + return updated_json + else: + logger.info(f"do not need to update msa result, so return itself {json_file}") + return json_file \ No newline at end of file diff --git a/protenix/runner/train.py b/protenix/runner/train.py index d96f6bf8ca500aea227169ea77c7dbe6fb06e0c9..5241e9fcc712423536d28982f0d15048ec2b09b3 100644 --- a/protenix/runner/train.py +++ b/protenix/runner/train.py @@ -552,7 +552,7 @@ def main(): filemode="w", ) configs_base["use_deepspeed_evo_attention"] = ( - os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true" + os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true" ) configs = {**configs_base, **{"data": data_configs}} configs = parse_configs( diff --git a/run_all.py b/run_all.py index 356b3228add482370f52ea14c9e1f43ea52f16c0..376990f80e9dd6c9ac97ae4893120438a72593f4 100644 --- a/run_all.py +++ b/run_all.py @@ -37,30 +37,29 @@ class Runner: print(command) os.system(command) - def run_chai(self, config_names, input_dir, output_dir, seeds, num_diffusion_samples, msa_pqt_dirs): - for config_name, msa_pqt_dir in zip(config_names, msa_pqt_dirs): - fasta_path = f'{input_dir}/{config_name}.fasta' - for seed in seeds: - command = ( - f'CHAI_DOWNLOADS_DIR=/ai/share/workspace/zhangcb/Chai/ckpts python run_chai.py {fasta_path} {msa_pqt_dir} ' - f' {output_dir}/{config_name}/seed_{seed} {num_diffusion_samples} {seed} ' - ) - print(command) - os.system(command) + def run_chai(self, config_names, input_dir, output_dir, seeds, num_diffusion_samples, msa_dir): + str_seeds = ','.join([repr(seed) for seed in seeds]) + command = ( + 'CHAI_DOWNLOADS_DIR=/ai/share/workspace/zhangcb/Chai/ckpts ' + f' 
/ai/share/workspace/zhangcb/conda_env/chai/bin/python run_chai.py {input_dir} {msa_dir} ' + f' {output_dir} {num_diffusion_samples} {str_seeds} ' + ) + print(command) + os.system(command) def run_protenix(self, json_path, output_dir, seeds, num_diffusion_samples): - for seed in seeds: - command = ( - 'LAYERNORM_TYPE=fast_layernorm /ai/share/workspace/zhangcb/conda_env/protenix/bin/python protenix/inference.py ' - f' --input_json_path {json_path} ' - f' --dump_dir {output_dir} ' - f' --seeds {seed} ' - ' --model.N_cycle 10 ' - f' --sample_diffusion.N_sample {num_diffusion_samples} ' - ' --sample_diffusion.N_step 200 ' - ) - print(command) - os.system(command) + str_seeds = ','.join([repr(seed) for seed in seeds]) + command = ( + 'LAYERNORM_TYPE=fast_layernorm /ai/share/workspace/zhangcb/conda_env/protenix/bin/python protenix/inference.py ' + f' --input_json_path {json_path} ' + f' --dump_dir {output_dir} ' + f' --seeds {str_seeds} ' + ' --model.N_cycle 10 ' + f' --sample_diffusion.N_sample {num_diffusion_samples} ' + ' --sample_diffusion.N_step 200 ' + ) + print(command) + os.system(command) def run( self, *, @@ -79,7 +78,6 @@ class Runner: config_names = [inputs['name'] for inputs in inputs_list] config_dir = f'{output_dir}/configs' struct_dir = f'{output_dir}/preds' - msa_pqt_dirs = [] print('[Searching for MSA]') if not os.path.exists(f'{output_dir}/msa/raw/{inputs_list[0]["name"]}.a3m'): msa_batch_size = 200 @@ -97,16 +95,18 @@ class Runner: print('[Making configs]') for inputs in inputs_list: msa_proc_dir = f'{output_dir}/msa/processed/{inputs["name"]}' - get_configs( - name=inputs['name'], - config_dir=config_dir, - msa_raw_file=f'{output_dir}/msa/raw/{inputs["name"]}.a3m', - msa_proc_dir=msa_proc_dir, - seqs=inputs['seqs'], seq_cnts=inputs['seq_cnts'], - ccd_codes=inputs['ccd_codes'], smis=inputs['smis'], - seeds=seeds, tos_tag=tos_tag - ) - msa_pqt_dirs.append(f'{msa_proc_dir}/protenix/') + if not os.path.exists(f'{config_dir}/af3/{inputs["name"]}.json'): + get_configs( + name=inputs['name'], + config_dir=config_dir, + msa_raw_file=f'{output_dir}/msa/raw/{inputs["name"]}.a3m', + msa_proc_dir=msa_proc_dir, + seqs=inputs['seqs'], seq_cnts=inputs['seq_cnts'], + ccd_codes=inputs['ccd_codes'], smis=inputs['smis'], + seeds=seeds, tos_tag=tos_tag + ) + else: + print('Skip making config for', inputs['name']) # print('[Predicting]') if not disable_af3: @@ -117,7 +117,7 @@ class Runner: self.run_boltz(f'{config_dir}/boltz', f'{struct_dir}/boltz', seeds, num_diffusion_samples) if not disable_chai: print('\n\t[Chai-1]') - self.run_chai(config_names, f'{config_dir}/chai', f'{struct_dir}/chai', seeds, num_diffusion_samples, msa_pqt_dirs) + self.run_chai(config_names, f'{config_dir}/chai', f'{struct_dir}/chai', seeds, num_diffusion_samples, f'{output_dir}/msa/processed/') if not disable_protenix: print('\n\t[Protenix]') self.run_protenix(f'{config_dir}/protenix.json', f'{struct_dir}/protenix', seeds, num_diffusion_samples) @@ -174,4 +174,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/run_alphafold.py b/run_alphafold.py index 23720baa9dd088830a520ff1139aacdf7741b87c..7aeadb55d4511d2ed5ca8390022e46df39d78bd6 100755 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -572,7 +572,10 @@ def process_fold_input( print('Skipping data pipeline...') else: print('Running data pipeline...') - fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input) + try: + fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input) + 
except: + return 1 print(f'Output directory: {output_dir}') print(f'Writing model input JSON to {output_dir}') diff --git a/run_chai.py b/run_chai.py index 443f0f0e690fbe11546afb07adaa296621c6c310..89078d195671879e072dff935b7bd7687d28027b 100644 --- a/run_chai.py +++ b/run_chai.py @@ -1,27 +1,34 @@ +import os import sys +from glob import glob from pathlib import Path import numpy as np from chai_lab.chai1 import run_inference def main(): - fasta_path, msa_dir, output_dir, num_diffn_samples, seed = sys.argv[1 :] + fasta_dir, msa_dir, output_dir, num_diffn_samples, str_seeds = sys.argv[1 :] + fasta_paths = glob(f'{fasta_dir}/*.fasta') # Generate structure - candidates = run_inference( - fasta_file=Path(fasta_path), - output_dir=Path(output_dir), - # 'default' setup - num_trunk_recycles=3, - num_diffn_timesteps=200, - num_diffn_samples=int(num_diffn_samples), - seed=(seed), - device="cuda:0", - use_esm_embeddings=True, - # See example .aligned.pqt files in this directory - msa_directory=Path(msa_dir), - # Exclusive with msa_directory; can be used for MMseqs2 server MSA generation - use_msa_server=False, - ) + for fasta_path in fasta_paths: + config_name = os.path.basename(fasta_path).split('.')[0] + for seed in str_seeds.split(','): + print(fasta_path, seed) + candidates = run_inference( + fasta_file=Path(fasta_path), + output_dir=Path(f'{output_dir}/{config_name}/seed_{seed}'), + # 'default' setup + num_trunk_recycles=3, + num_diffn_timesteps=200, + num_diffn_samples=int(num_diffn_samples), + seed=int(seed), + device="cuda:0", + use_esm_embeddings=True, + # See example .aligned.pqt files in this directory + msa_directory=Path(f'{msa_dir}/{config_name}/chai/'), + # Exclusive with msa_directory; can be used for MMseqs2 server MSA generation + use_msa_server=False, + ) # cif_paths = candidates.cif_paths # scores = [rd.aggregate_score for rd in candidates.ranking_data]
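
Note on the ranking change in protenix/runner/dumper.py: the patch replaces the `sorted(..., reverse=True)` ordering with a double `torch.argsort`, and the dump loops switch from `for rank, idx in enumerate(...)` to `for idx, rank in enumerate(...)`. The sketch below is illustrative only (not part of the patch; the `ranking_score` values are made up) and checks that both conventions yield the same (sample index, rank) pairs, so the `{sample_name}_seed_{seed}_sample_{rank}.cif` file names stay the same.

```python
import torch

ranking_score = torch.tensor([0.2, 0.9, 0.5])

# New convention: rank_of[i] is the rank of sample i (0 = best score).
rank_of = torch.argsort(torch.argsort(ranking_score, descending=True))
assert rank_of.tolist() == [2, 0, 1]

# Old convention: order[rank] is the sample index that holds that rank.
order = sorted(range(len(ranking_score)), key=lambda i: ranking_score[i], reverse=True)
assert order == [1, 2, 0]

# Both produce identical (sample index, rank) pairs.
pairs_new = sorted((idx, int(rank)) for idx, rank in enumerate(rank_of))
pairs_old = sorted((idx, rank) for rank, idx in enumerate(order))
assert pairs_new == pairs_old
```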