diff --git a/README.md b/README.md
index 5e14e8a646b186988df1851508a8858108df06fb..f9fc06717dff90ee78abc20986860edd0d90c806 100644
--- a/README.md
+++ b/README.md
@@ -11,21 +11,4 @@
 ### Engineering
 1. Implement multi-process prediction, e.g. while one model is running inference, load the checkpoints of the other models into memory in parallel
     - The idea is to maintain 4 independently running uvicorn services or MOS services (see the sketch below)
-2. Package as a docker image; on every startup, download the env and checkpoint from AOSS to the local machine
-
-## Draft
-### Configuration
-#### Chai-1
-```bash
-mkdir -p ~/.cache/
-cd ~/.cache
-ln -s /ai/share/workspace/zhangcb/.cache/huggingface huggingface
-
-source /ai/share/workspace/zhangcb/conda_env/chai/bin/activate
-chai fold --use-msa-server $input_path $output_dir
-```
-
-#### ColabFold MSA
-```bash
-export PATH=$PATH:/ai/share/workspace/zhangcb/LocalBins
-```
\ No newline at end of file
+2. Package as a docker image; on every startup, download the env and checkpoint from AOSS to the local machine
\ No newline at end of file
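For item 1 above, a minimal sketch of what maintaining several independent uvicorn services could look like; the module path `serving.app:app`, the port range, and the service count are assumptions, not part of this repository.

```python
# Hypothetical launcher: start 4 independent uvicorn workers, one per model service.
import subprocess

procs = [
    subprocess.Popen(
        ["uvicorn", "serving.app:app", "--host", "0.0.0.0", "--port", str(8000 + i)]
    )
    for i in range(4)
]
for p in procs:
    p.wait()  # keep the parent process alive while the services run
```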
diff --git a/protenix/configs/configs_base.py b/protenix/configs/configs_base.py
index 4a6e219ba19237316ef251731f9f83449f982cb7..327b2eccd81a9a5b8be3918cf3aeb5a418df042b 100644
--- a/protenix/configs/configs_base.py
+++ b/protenix/configs/configs_base.py
@@ -31,9 +31,9 @@ basic_configs = {
     "eval_first": False,  # run evaluate() before training steps
     "iters_to_accumulate": 1,
     "eval_only": False,
-    "load_checkpoint_path": "",
+    "load_checkpoint_path": "/ai/share/workspace/zhangcb/ABCPFold/src/protenix/release_data/checkpoint/model_v0.2.0.pt",
     "load_ema_checkpoint_path": "",
-    "load_strict": False,
+    "load_strict": True,
     "load_params_only": True,
     "skip_load_step": False,
     "skip_load_optimizer": False,
diff --git a/protenix/configs/configs_data.py b/protenix/configs/configs_data.py
index ecf50725ae27e9f0f9a39bc3a690700ed01e3a0d..99f3c7a09781cb06d99fdcfd94db22a0eae1b7d9 100644
--- a/protenix/configs/configs_data.py
+++ b/protenix/configs/configs_data.py
@@ -60,12 +60,24 @@ default_weighted_pdb_configs = {
     "shuffle_sym_ids": GlobalConfigValue("train_shuffle_sym_ids"),
 }
 
-DATA_ROOT_DIR = "/af3-dev/release_data/"
-CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.v20240608.cif")
+DATA_ROOT_DIR = os.environ.get("PROTENIX_DATA_ROOT_DIR", "/af3-dev/release_data/")
+
+# Prefer the CCD cache created by scripts/gen_ccd_cache.py (filenames without a date).
+# See: docs/prepare_data.md
+CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.cif")
 CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
-    DATA_ROOT_DIR, "components.v20240608.cif.rdkit_mol.pkl"
+    DATA_ROOT_DIR, "components.cif.rdkit_mol.pkl"
 )
 
+if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or (
+    not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH)
+):
+    CCD_COMPONENTS_FILE_PATH = os.path.join(DATA_ROOT_DIR, "components.v20240608.cif")
+    CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
+        DATA_ROOT_DIR, "components.v20240608.cif.rdkit_mol.pkl"
+    )
+
+
 # This is a patch in inference stage for users that do not have root permission.
 # If you run
 # ```
@@ -80,17 +92,25 @@ if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or (
     not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH)
 ):
     print("Try to find the ccd cache data in the code directory for inference.")
-    # current_file_path = os.path.abspath(__file__)
-    # current_directory = os.path.dirname(current_file_path)
-    # code_directory = os.path.dirname(current_directory)
-    code_directory = '/ai/share/workspace/zhangcb/Protenix/'
+    current_file_path = os.path.abspath(__file__)
+    current_directory = os.path.dirname(current_file_path)
+    code_directory = os.path.dirname(current_directory)
 
     data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
-    CCD_COMPONENTS_FILE_PATH = os.path.join(data_cache_dir, "components.v20240608.cif")
+    CCD_COMPONENTS_FILE_PATH = os.path.join(data_cache_dir, "components.cif")
     CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
-        data_cache_dir, "components.v20240608.cif.rdkit_mol.pkl"
+        data_cache_dir, "components.cif.rdkit_mol.pkl"
     )
+    if (not os.path.exists(CCD_COMPONENTS_FILE_PATH)) or (
+        not os.path.exists(CCD_COMPONENTS_RDKIT_MOL_FILE_PATH)
+    ):
 
+        CCD_COMPONENTS_FILE_PATH = os.path.join(
+            data_cache_dir, "components.v20240608.cif"
+        )
+        CCD_COMPONENTS_RDKIT_MOL_FILE_PATH = os.path.join(
+            data_cache_dir, "components.v20240608.cif.rdkit_mol.pkl"
+        )
 
 data_configs = {
     "num_dl_workers": 16,
diff --git a/protenix/configs/configs_inference.py b/protenix/configs/configs_inference.py
index b52f7a4333d5c38c40ae066d809d8d4d8f2e7026..2155148694d4561aaf860dc35b8018c05554cbcf 100644
--- a/protenix/configs/configs_inference.py
+++ b/protenix/configs/configs_inference.py
@@ -24,10 +24,9 @@ code_directory = os.path.dirname(current_directory)
 # "./release_data/checkpoint/model_v0.2.0.pt"
 inference_configs = {
     "seeds": ListValue([101]),
-    # "seeds":  ListValue(range(101, 111)),
-    # "seeds": ListValue(range(1, 101)),
     "dump_dir": "./output",
     "need_atom_confidence": False,
+    "sorted_by_ranking_score": True,
     "input_json_path": RequiredValue(str),
     "load_checkpoint_path": os.path.join(
         code_directory, "./release_data/checkpoint/model_v0.2.0.pt"
diff --git a/protenix/inference.py b/protenix/inference.py
index 0d12c365b561455cbbb215f544addd8e9ec3c2bf..5b3e1b172126c4a26d9581cbef31e9f24005f119 100644
--- a/protenix/inference.py
+++ b/protenix/inference.py
@@ -23,10 +23,11 @@ from typing import Any, Mapping
 
 import torch
 import torch.distributed as dist
-import sys
 from configs.configs_base import configs as configs_base
 from configs.configs_data import data_configs
 from configs.configs_inference import inference_configs
+from runner.dumper import DataDumper
+
 from protenix.config import parse_configs, parse_sys_args
 from protenix.data.infer_data_pipeline import get_inference_dataloader
 from protenix.model.protenix import Protenix
@@ -34,8 +35,6 @@ from protenix.utils.distributed import DIST_WRAPPER
 from protenix.utils.seed import seed_everything
 from protenix.utils.torch_utils import to_device
 from protenix.web_service.dependency_url import URL
-from runner.dumper import DataDumper
-from runner.msa_search import contain_msa_res, msa_search_update
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +46,10 @@ class InferenceRunner(object):
         self.init_basics()
         self.init_model()
         self.load_checkpoint()
-        self.init_dumper(need_atom_confidence=configs.need_atom_confidence)
+        self.init_dumper(
+            need_atom_confidence=configs.need_atom_confidence,
+            sorted_by_ranking_score=configs.sorted_by_ranking_score,
+        )
 
     def init_env(self) -> None:
         self.print(
@@ -112,14 +114,18 @@ class InferenceRunner(object):
             }
         self.model.load_state_dict(
             state_dict=checkpoint["model"],
-            strict=True,
+            strict=self.configs.load_strict,
         )
         self.model.eval()
         self.print(f"Finish loading checkpoint.")
 
-    def init_dumper(self, need_atom_confidence: bool = False):
+    def init_dumper(
+        self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True
+    ):
         self.dumper = DataDumper(
-            base_dir=self.dump_dir, need_atom_confidence=need_atom_confidence
+            base_dir=self.dump_dir,
+            need_atom_confidence=need_atom_confidence,
+            sorted_by_ranking_score=sorted_by_ranking_score,
         )
 
     # Adapted from runner.train.Trainer.evaluate
@@ -157,32 +163,29 @@ class InferenceRunner(object):
 
 
 def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> None:
-    # current_file_path = os.path.abspath(__file__)
-    # current_directory = os.path.dirname(current_file_path)
-    # code_directory = os.path.dirname(current_directory)
-    code_directory = os.path.dirname(os.path.abspath(__file__))
-    # code_directory = '/ai/share/workspace/zhangcb/Protenix/'
-
-    data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
-    os.makedirs(data_cache_dir, exist_ok=True)
-    for cache_name, fname in [
-        ("ccd_components_file", "components.v20240608.cif"),
-        ("ccd_components_rdkit_mol_file", "components.v20240608.cif.rdkit_mol.pkl"),
-    ]:
-        if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))):
+
+    for cache_name in ("ccd_components_file", "ccd_components_rdkit_mol_file"):
+        cur_cache_fpath = configs["data"][cache_name]
+        if not opexists(cur_cache_fpath):
+            os.makedirs(os.path.dirname(cur_cache_fpath), exist_ok=True)
             tos_url = URL[cache_name]
-            logger.info(f"Downloading data cache from\n {tos_url} to {cache_path}")
-            urllib.request.urlretrieve(tos_url, cache_path)
+            assert os.path.basename(tos_url) == os.path.basename(cur_cache_fpath), (
+                f"{cache_name} file name is incorrect, `{tos_url}` and "
+                f"`{cur_cache_fpath}`. Please check and try again."
+            )
+            logger.info(
+                f"Downloading data cache from\n {tos_url}... to {cur_cache_fpath}"
+            )
+            urllib.request.urlretrieve(tos_url, cur_cache_fpath)
 
     checkpoint_path = configs.load_checkpoint_path
 
     if not opexists(checkpoint_path):
-        checkpoint_path = os.path.join(
-            code_directory, f"release_data/checkpoint/model_{model_version}.pt"
-        )
         os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
         tos_url = URL[f"model_{model_version}"]
-        logger.info(f"Downloading model checkpoint from\n {tos_url} to {checkpoint_path}")
+        logger.info(
+            f"Downloading model checkpoint from\n {tos_url}... to {checkpoint_path}"
+        )
         urllib.request.urlretrieve(tos_url, checkpoint_path)
         try:
             ckpt = torch.load(checkpoint_path)
@@ -193,7 +196,6 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
                 "Download model checkpoint failed, please download by yourself with "
                 f"wget {tos_url} -O {checkpoint_path}"
             )
-        configs.load_checkpoint_path = checkpoint_path
 
 
 def update_inference_configs(configs: Any, N_token: int):
@@ -213,38 +215,31 @@ def update_inference_configs(configs: Any, N_token: int):
 
 
 def infer_predict(runner: InferenceRunner, configs: Any) -> None:
-    # update msa result if not contains precomputed msa dir
-    if not contain_msa_res(configs.input_json_path):
-        logger.info(
-            f"{configs.input_json_path} dose not contain precomputed msa dir, now searching it."
-        )
-        configs.input_json_path = msa_search_update(
-            configs.input_json_path, configs.dump_dir
-        )
-        logger.info(
-            f"msa searching completed, new input json is {configs.input_json_path}"
-        )
     # Data
     logger.info(f"Loading data from\n{configs.input_json_path}")
-    dataloader = get_inference_dataloader(configs=configs)
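+    # If building the dataloader fails (e.g. a malformed input JSON), record the error under error_dir and return instead of raising.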
+    try:
+        dataloader = get_inference_dataloader(configs=configs)
+    except Exception as e:
+        error_message = f"{e}:\n{traceback.format_exc()}"
+        logger.info(error_message)
+        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
+            f.write(error_message)
+        return
 
     num_data = len(dataloader.dataset)
     for seed in configs.seeds:
-        seed_everything(seed=seed, deterministic=False)
+        seed_everything(seed=seed, deterministic=configs.deterministic)
         for batch in dataloader:
             try:
                 data, atom_array, data_error_message = batch[0]
+                sample_name = data["sample_name"]
 
                 if len(data_error_message) > 0:
                     logger.info(data_error_message)
-                    with open(
-                        opjoin(runner.error_dir, f"{data['sample_name']}.txt"),
-                        "w",
-                    ) as f:
+                    with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                         f.write(data_error_message)
                     continue
 
-                sample_name = data["sample_name"]
                 logger.info(
                     (
                         f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
@@ -273,15 +268,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
                 error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}"
                 logger.info(error_message)
                 # Save error info
-                if opexists(
-                    error_path := opjoin(runner.error_dir, f"{sample_name}.txt")
-                ):
-                    os.remove(error_path)
-                with open(error_path, "w") as f:
+                with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                     f.write(error_message)
                 if hasattr(torch.cuda, "empty_cache"):
                     torch.cuda.empty_cache()
-                raise RuntimeError(f"run infer failed: {str(e)}")
 
 
 def main(configs: Any) -> None:
@@ -299,7 +289,7 @@ def run() -> None:
         filemode="w",
     )
     configs_base["use_deepspeed_evo_attention"] = (
-        os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
+        os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true"
     )
     configs = {**configs_base, **{"data": data_configs}, **inference_configs}
     configs = parse_configs(
diff --git a/protenix/protenix/data/constants.py b/protenix/protenix/data/constants.py
index 4e786a83999d5ab463f24ad7f6922e044bda9398..308adf6c30be434c2a3eeb2ac85fd49ceee2b66a 100644
--- a/protenix/protenix/data/constants.py
+++ b/protenix/protenix/data/constants.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from rdkit.Chem import GetPeriodicTable
+
 EvaluationChainInterface = [
     "intra_ligand",
     "intra_dna",
@@ -36,10 +38,16 @@ EntityPolyTypeDict = {
     "ligand": ["cyclic-pseudo-peptide", "other"],
 }
 
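+# Experimental methods regarded as crystallography; presumably used to decide when crystallization aids (see CRYSTALLIZATION_AIDS) should be filtered out.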
+CRYSTALLIZATION_METHODS = {
+    "X-RAY DIFFRACTION",
+    "NEUTRON DIFFRACTION",
+    "ELECTRON CRYSTALLOGRAPHY",
+    "POWDER CRYSTALLOGRAPHY",
+    "FIBER DIFFRACTION",
+}
 
 ### Protein Constants ###
 # https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v40.dic/Items/_entity_poly.pdbx_seq_one_letter_code_can.html
-from rdkit.Chem import GetPeriodicTable
 
 mmcif_restype_1to3 = {
     "A": "ALA",
diff --git a/protenix/protenix/data/data_pipeline.py b/protenix/protenix/data/data_pipeline.py
index 8dad0b7ec8f45c35d7850d996e9e0054891fad52..e254655dfcde2f1aa7ba6d0996cb5877a21e4f96 100644
--- a/protenix/protenix/data/data_pipeline.py
+++ b/protenix/protenix/data/data_pipeline.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 from collections import defaultdict
 from pathlib import Path
@@ -24,7 +25,8 @@ import torch
 from biotite.structure import AtomArray
 
 from protenix.data.msa_featurizer import MSAFeaturizer
-from protenix.data.tokenizer import TokenArray
+from protenix.data.parser import DistillationMMCIFParser, MMCIFParser
+from protenix.data.tokenizer import AtomArrayTokenizer, TokenArray
 from protenix.utils.cropping import CropData
 from protenix.utils.file_io import load_gzip_pickle
 
@@ -32,6 +34,64 @@ torch.multiprocessing.set_sharing_strategy("file_system")
 
 
 class DataPipeline(object):
+    """
+    DataPipeline class provides static methods to handle various data processing tasks related to bioassembly structures.
+    """
+
+    @staticmethod
+    def get_data_from_mmcif(
+        mmcif: Union[str, Path],
+        pdb_cluster_file: Union[str, Path, None] = None,
+        dataset: str = "WeightedPDB",
+    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
+        """
+        Get raw data from mmcif with tokenizer and a list of chains and interfaces for sampling.
+
+        Args:
+            mmcif (Union[str, Path]): The raw mmcif file.
+            pdb_cluster_file (Union[str, Path, None], optional): Cluster info txt file. Defaults to None.
+            dataset (str, optional): The dataset type, either "WeightedPDB" or "Distillation". Defaults to "WeightedPDB".
+
+        Returns:
+            tuple[list[dict[str, Any]], dict[str, Any]]:
+                sample_indices_list (list[dict[str, Any]]): The sample indices list (each one is a chain or an interface).
+                bioassembly_dict (dict[str, Any]): The bioassembly dict with sequence, atom_array, and token_array.
+        """
+        try:
+            if dataset == "WeightedPDB":
+                parser = MMCIFParser(mmcif_file=mmcif)
+                bioassembly_dict = parser.get_bioassembly()
+            elif dataset == "Distillation":
+                parser = DistillationMMCIFParser(mmcif_file=mmcif)
+                bioassembly_dict = parser.get_structure_dict()
+            else:
+                raise NotImplementedError(
+                    'Unsupported "dataset", please input either "WeightedPDB" or "Distillation".'
+                )
+
+            sample_indices_list = parser.make_indices(
+                bioassembly_dict=bioassembly_dict, pdb_cluster_file=pdb_cluster_file
+            )
+            if len(sample_indices_list) == 0:
+                # empty indices and AtomArray
+                return [], bioassembly_dict
+
+            atom_array = bioassembly_dict["atom_array"]
+            atom_array.set_annotation(
+                "resolution", [parser.resolution] * len(atom_array)
+            )
+
+            tokenizer = AtomArrayTokenizer(atom_array)
+            token_array = tokenizer.get_token_array()
+            bioassembly_dict["msa_features"] = None
+            bioassembly_dict["template_features"] = None
+
+            bioassembly_dict["token_array"] = token_array
+            return sample_indices_list, bioassembly_dict
+
+        except Exception as e:
+            logging.warning("Gen data failed for %s due to %s", mmcif, e)
+            return [], {}
 
     @staticmethod
     def get_label_entity_id_to_asym_id_int(atom_array: AtomArray) -> dict[str, int]:
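For reference, a minimal usage sketch of the new entry point; the mmCIF path is hypothetical, and an empty indices list signals that parsing or index generation failed.

```python
from protenix.data.data_pipeline import DataPipeline

# Hypothetical mmCIF path; pdb_cluster_file is optional and omitted here.
sample_indices, bioassembly = DataPipeline.get_data_from_mmcif(
    mmcif="examples/7pzb.cif",
    dataset="WeightedPDB",
)
if not sample_indices:
    print("parsing failed or produced no chain/interface indices")
else:
    print(f"{len(sample_indices)} sampling entries, "
          f"{len(bioassembly['atom_array'])} atoms in the bioassembly")
```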
diff --git a/protenix/protenix/data/dataset.py b/protenix/protenix/data/dataset.py
index 62878cbf072d4cd71ee15f89b110342fa6c45136..fcaf3164de54281232cdaa012efdbdff1eeb6d97 100644
--- a/protenix/protenix/data/dataset.py
+++ b/protenix/protenix/data/dataset.py
@@ -894,25 +894,34 @@ def get_weighted_pdb_weight(
     cluster_size: int,
     chain_count: dict,
     eps: float = 1e-9,
-    beta_dict: dict = {
-        "chain": 0.5,
-        "interface": 1,
-    },
-    alpha_dict: dict = {
-        "prot": 3,
-        "nuc": 3,
-        "ligand": 1,
-    },
+    beta_dict: Optional[dict] = None,
+    alpha_dict: Optional[dict] = None,
 ) -> float:
     """
-    Get sample weight for each examples in weighted pdb dataset. AF3-SI (1)
-    Args:
-        data_type: chain or interface
-        cluster_size: cluster size of this chain/interface
-        chain_count: count of each kinds of chains, {"prot": int, "nuc": int, "ligand": int}
+    Get sample weight for each example in a weighted PDB dataset.
+
+    Args:
+        data_type (str): Type of data, either 'chain' or 'interface'.
+        cluster_size (int): Cluster size of this chain/interface.
+        chain_count (dict): Count of each kind of chain, e.g., {"prot": int, "nuc": int, "ligand": int}.
+        eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-9.
+        beta_dict (Optional[dict], optional): Dictionary containing beta values for 'chain' and 'interface'.
+        alpha_dict (Optional[dict], optional): Dictionary containing alpha values for different chain types.
+
     Returns:
-        weights: float
+        float: Calculated weight for the given chain/interface.
     """
+    if not beta_dict:
+        beta_dict = {
+            "chain": 0.5,
+            "interface": 1,
+        }
+    if not alpha_dict:
+        alpha_dict = {
+            "prot": 3,
+            "nuc": 3,
+            "ligand": 1,
+        }
+
     assert cluster_size > 0
     assert data_type in ["chain", "interface"]
     beta = beta_dict[data_type]
diff --git a/protenix/protenix/data/filter.py b/protenix/protenix/data/filter.py
index d9d05b43007dea83d425e039ed4064ae66426f81..1066b8f4a9cc11aeee7e5356336f73eb11de3b0d 100644
--- a/protenix/protenix/data/filter.py
+++ b/protenix/protenix/data/filter.py
@@ -14,7 +14,8 @@
 
 import biotite.structure as struc
 import numpy as np
-from biotite.structure import AtomArray
+from biotite.structure import AtomArray, get_molecule_indices
+from scipy.spatial.distance import cdist
 
 from protenix.data.constants import CRYSTALLIZATION_AIDS
 
@@ -80,3 +81,557 @@ class Filter(object):
         non_aids_mask = ~np.isin(atom_array.res_name, CRYSTALLIZATION_AIDS)
         poly_mask = np.isin(atom_array.label_entity_id, list(entity_poly_type.keys()))
         return atom_array[poly_mask | non_aids_mask]
+
+    @staticmethod
+    def _get_clashing_chains(
+        atom_array: AtomArray, chain_ids: list[str]
+    ) -> tuple[np.ndarray, list[int]]:
+        """
+        Calculate the number of atoms clashing with other chains for each chain
+        and return a matrix that records the count of clashing atoms.
+
+        Note: if two chains are covalent, they are not considered as clashing.
+
+        Args:
+            atom_array (AtomArray): All atoms, including those not resolved.
+            chain_ids (list[str]): Unique chain indices of resolved atoms.
+
+        Returns:
+            tuple:
+                clash_records (numpy.ndarray): Matrix of clashing-atom counts.
+                                               Entry (i, j) is the number of chain i's atoms that clash with chain j's atoms.
+                                               Note: (i, j) != (j, i).
+                chain_resolved_atom_nums (list[int]): The number of resolved atoms corresponding to each chain ID.
+        """
+        is_resolved_centre_atom = (
+            atom_array.centre_atom_mask == 1
+        ) & atom_array.is_resolved
+        cell_list = struc.CellList(
+            atom_array, cell_size=1.7, selection=is_resolved_centre_atom
+        )
+
+        # (i, j) is the number of chain i's atoms that clash with chain j's atoms
+        clash_records = np.zeros((len(chain_ids), len(chain_ids)))
+
+        # record the number of resolved atoms for each chain
+        chain_resolved_atom_nums = []
+
+        # record covalent relationship between chains
+        chains_covalent_dict = {}
+        for idx, chain_id_i in enumerate(chain_ids):
+            for chain_id_j in chain_ids[idx + 1 :]:
+                mol_indices = get_molecule_indices(
+                    atom_array[np.isin(atom_array.chain_id, [chain_id_i, chain_id_j])]
+                )
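+                # A single connected component means the two chains are covalently linked.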
+                if len(mol_indices) == 1:
+                    covalent = 1
+                else:
+                    covalent = 0
+                chains_covalent_dict[(chain_id_i, chain_id_j)] = covalent
+                chains_covalent_dict[(chain_id_j, chain_id_i)] = covalent
+
+        for i, chain_id in enumerate(chain_ids):
+            coords = atom_array.coord[
+                (atom_array.chain_id == chain_id) & is_resolved_centre_atom
+            ]
+            chain_resolved_atom_nums.append(len(coords))
+            chain_atom_ids = np.where(atom_array.chain_id == chain_id)[0]
+            chain_atom_ids_set = set(chain_atom_ids) | {-1}
+
+            # Get atom indices from the current cell and the eight surrounding cells.
+            neighbors_ids_2d = cell_list.get_atoms_in_cells(coords, cell_radius=1)
+            neighbors_ids = np.unique(neighbors_ids_2d)
+
+            # Remove the atom indices of the current chain.
+            other_chain_atom_ids = list(set(neighbors_ids) - chain_atom_ids_set)
+
+            if not other_chain_atom_ids:
+                continue
+            else:
+                # Calculate the distance matrix with neighboring atoms.
+                other_chain_atom_coords = atom_array.coord[other_chain_atom_ids]
+                dist_mat = cdist(coords, other_chain_atom_coords, metric="euclidean")
+                clash_mat = dist_mat < 1.6  # change 1.7 to 1.6 for more compatibility
+                if np.any(clash_mat):
+                    clashed_other_chain_ids = atom_array.chain_id[other_chain_atom_ids]
+
+                    for other_chain_id in set(clashed_other_chain_ids):
+
+                        # two chains covalent with each other
+                        if chains_covalent_dict[(chain_id, other_chain_id)]:
+                            continue
+
+                        cols = np.where(clashed_other_chain_ids == other_chain_id)[0]
+
+                        # how many i's atoms clashed with j
+                        any_atom_clashed = np.any(
+                            clash_mat[:, cols].astype(int), axis=1
+                        )
+                        clashed_atom_num = np.sum(any_atom_clashed.astype(int))
+
+                        if clashed_atom_num > 0:
+                            j = chain_ids.index(other_chain_id)
+                            clash_records[i][j] += clashed_atom_num
+        return clash_records, chain_resolved_atom_nums
+
+    @staticmethod
+    def _get_removed_clash_chain_ids(
+        clash_records: np.ndarray,
+        chain_ids: list[str],
+        chain_resolved_atom_nums: list[int],
+        core_chain_id: np.ndarray = [],
+    ) -> list[str]:
+        """
+        Perform pairwise comparisons on the chains, and select the chain IDs
+        to be deleted according to the clahsing chain rules.
+
+        Args:
+            clash_records (numpy.ndarray): Matrix of clashing-atom counts.
+                                           Entry (i, j) is the number of chain i's atoms that clash with chain j's atoms.
+                                           Note: (i, j) != (j, i).
+            chain_ids (list[str]): Unique chain indices of resolved atoms.
+            chain_resolved_atom_nums (list[int]): The number of resolved atoms corresponding to each chain ID.
+            core_chain_id (np.ndarray): The chain ID of the core chain.
+
+        Returns:
+            list[str]: A list of chain IDs that have been determined for deletion.
+        """
+        removed_chain_ids = []
+        for i in range(len(chain_ids)):
+            atom_num_i = chain_resolved_atom_nums[i]
+            chain_idx_i = chain_ids[i]
+
+            if chain_idx_i in removed_chain_ids:
+                continue
+
+            for j in range(i + 1, len(chain_ids)):
+                atom_num_j = chain_resolved_atom_nums[j]
+                chain_idx_j = chain_ids[j]
+
+                if chain_idx_j in removed_chain_ids:
+                    continue
+
+                clash_num_ij, clash_num_ji = (
+                    clash_records[i][j],
+                    clash_records[j][i],
+                )
+
+                clash_ratio_ij = clash_num_ij / atom_num_i
+                clash_ratio_ji = clash_num_ji / atom_num_j
+
+                if clash_ratio_ij <= 0.3 and clash_ratio_ji <= 0.3:
+                    # not reaches the threshold
+                    continue
+                else:
+                    # clashing chains
+                    if (
+                        chain_idx_i in core_chain_id
+                        and chain_idx_j not in core_chain_id
+                    ):
+                        removed_chain_idx = chain_idx_j
+                    elif (
+                        chain_idx_i not in core_chain_id
+                        and chain_idx_j in core_chain_id
+                    ):
+                        removed_chain_idx = chain_idx_i
+
+                    elif clash_ratio_ij > clash_ratio_ji:
+                        removed_chain_idx = chain_idx_i
+                    elif clash_ratio_ij < clash_ratio_ji:
+                        removed_chain_idx = chain_idx_j
+                    else:
+                        if atom_num_i < atom_num_j:
+                            removed_chain_idx = chain_idx_i
+                        elif atom_num_i > atom_num_j:
+                            removed_chain_idx = chain_idx_j
+                        else:
+                            removed_chain_idx = sorted([chain_idx_i, chain_idx_j])[1]
+
+                    removed_chain_ids.append(removed_chain_idx)
+
+                    if removed_chain_idx == chain_idx_i:
+                        # chain i already removed
+                        break
+        return removed_chain_ids
+
+    @staticmethod
+    def remove_polymer_chains_all_residues_unknown(
+        atom_array: AtomArray,
+        entity_poly_type: dict,
+    ) -> AtomArray:
+        """remove chains with all residues unknown"""
+        chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
+        invalid_chains = []  # list of [start, end)
+        for index in range(len(chain_starts) - 1):
+            start, end = chain_starts[index], chain_starts[index + 1]
+            entity_id = atom_array[start].label_entity_id
+            if (
+                entity_poly_type.get(entity_id, "non-poly") == "polypeptide(L)"
+                and np.all(atom_array.res_name[start:end] == "UNK")
+            ) or (
+                entity_poly_type.get(entity_id, "non-poly")
+                in (
+                    "polyribonucleotide",
+                    "polydeoxyribonucleotide",
+                )
+                and np.all(atom_array.res_name[start:end] == "N")
+            ):
+                invalid_chains.append((start, end))
+        mask = np.ones(len(atom_array), dtype=bool)
+        for start, end in invalid_chains:
+            mask[start:end] = False
+        atom_array = atom_array[mask]
+        return atom_array
+
+    @staticmethod
+    def remove_polymer_chains_too_short(
+        atom_array: AtomArray, entity_poly_type: dict
+    ) -> AtomArray:
+        chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
+        invalid_chains = []  # list of [start, end)
+        for index in range(len(chain_starts) - 1):
+            start, end = chain_starts[index], chain_starts[index + 1]
+            entity_id = atom_array[start].label_entity_id
+            num_residue_ids = len(set(atom_array.label_seq_id[start:end]))
+            if (
+                entity_poly_type.get(entity_id, "non-poly")
+                in (
+                    "polypeptide(L)",  # TODO: how to handle polypeptide(D)?
+                    "polyribonucleotide",
+                    "polydeoxyribonucleotide",
+                )
+                and num_residue_ids < 4
+            ):
+                invalid_chains.append((start, end))
+        mask = np.ones(len(atom_array), dtype=bool)
+        for start, end in invalid_chains:
+            mask[start:end] = False
+        atom_array = atom_array[mask]
+        return atom_array
+
+    @staticmethod
+    def remove_polymer_chains_with_consecutive_c_alpha_too_far_away(
+        atom_array: AtomArray, entity_poly_type: dict, max_distance: float = 10.0
+    ) -> AtomArray:
+        chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
+        invalid_chains = []  # list of [start, end)
+        for index in range(len(chain_starts) - 1):
+            start, end = chain_starts[index], chain_starts[index + 1]
+            entity_id = atom_array.label_entity_id[start]
+            if entity_poly_type.get(entity_id, "non-poly") == "polypeptide(L)":
+                peptide_atoms = atom_array[start:end]
+                ca_atoms = peptide_atoms[peptide_atoms.atom_name == "CA"]
+                seq_ids = ca_atoms.label_seq_id
+                seq_ids[seq_ids == "."] = "-100"
+                seq_ids = seq_ids.astype(np.int64)
+                dist_square = np.sum(
+                    (ca_atoms[:-1].coord - ca_atoms[1:].coord) ** 2, axis=-1
+                )
+                invalid_neighbor_mask = (dist_square > max_distance**2) & (
+                    seq_ids[:-1] + 1 == seq_ids[1:]
+                )
+                if np.any(invalid_neighbor_mask):
+                    invalid_chains.append((start, end))
+        mask = np.ones(len(atom_array), dtype=bool)
+        for start, end in invalid_chains:
+            mask[start:end] = False
+        atom_array = atom_array[mask]
+        return atom_array
+
+    @staticmethod
+    def too_many_chains_filter(
+        atom_array: AtomArray,
+        interface_radius: int = 15,
+        max_chains_num: int = 20,
+        core_indices: list[int] = None,
+        max_tokens_num: int = None,
+    ) -> tuple[AtomArray, int]:
+        """
+        Ref: AlphaFold3 SI Chapter 2.5.4
+
+        For bioassemblies with greater than 20 chains, we select a random interface token
+        (with a centre atom <15 Å to the centre atom of a token in another chain)
+        and select the closest 20 chains to this token based on
+        minimum distance between any tokens centre atom.
+
+        Note: due to the presence of covalent small molecules,
+        treat the covalent small molecule and the polymer it is attached to
+        as a single chain to avoid inadvertently removing the covalent small molecules.
+        Use the mol_id added to the AtomArray to differentiate between the various
+        parts of the structure composed of covalent bonds.
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray Object of a Bioassembly.
+            interface_radius (int, optional): Atoms within this distance of the central atom are considered interface atoms.
+                                            Defaults to 15.
+            max_chains_num (int, optional): The maximum number of chains permitted in a bioassembly.
+                                            Filtration is applied if the number of chains exceeds this value. Defaults to 20.
+            core_indices (list[int], optional): A list of atom indices from which the central atom is chosen.
+                                                     Chains containing these indices are selected with priority.
+                                                     If None, a random index from the whole AtomArray is used. Defaults to None.
+            max_tokens_num (int, optional): The maximum number of tokens permitted in a bioassembly.
+                                            If not None, chains keep being appended after max_chains_num is reached
+                                            as long as max_tokens_num is not exceeded.
+
+        Returns:
+            tuple:
+                - atom_array (AtomArray): An AtomArray that has been processed through this filter.
+                - input_chains_num (int): The number of chains in the input AtomArray.
+                                          This is logged to indicate whether the filter was applied.
+        """
+        # each mol is a so called "chain" in the context of this filter.
+        input_chains_num = len(np.unique(atom_array.mol_id))
+        if input_chains_num <= max_chains_num:
+            # no change
+            return atom_array, input_chains_num
+
+        is_resolved_centre_atom = (
+            atom_array.centre_atom_mask == 1
+        ) & atom_array.is_resolved
+
+        cell_list = struc.CellList(
+            atom_array, cell_size=interface_radius, selection=is_resolved_centre_atom
+        )
+        resolved_centre_atom = atom_array[is_resolved_centre_atom]
+
+        assert resolved_centre_atom, "There is no resolved central atom."
+
+        # random pick centre atom
+        if core_indices is None:
+            index_shuf = np.random.default_rng(seed=42).permutation(
+                len(resolved_centre_atom)
+            )
+        else:
+            index_shuf = np.array(core_indices)
+            resolved_centre_atom_indices = np.nonzero(is_resolved_centre_atom)[0]
+
+            # get indices of resolved_centre_atom
+            index_shuf = np.array(
+                [
+                    np.where(resolved_centre_atom_indices == idx)[0][0]
+                    for idx in index_shuf
+                    if idx in resolved_centre_atom_indices
+                ]
+            )
+            np.random.default_rng(seed=42).shuffle(index_shuf)
+
+        chosen_centre_atom = None
+        for idx in index_shuf:
+            centre_atom = resolved_centre_atom[idx]
+            neighbors_indices = cell_list.get_atoms(
+                centre_atom.coord, radius=interface_radius
+            )
+            neighbors_indices = neighbors_indices[neighbors_indices != -1]
+
+            neighbors_chain_ids = np.unique(atom_array.mol_id[neighbors_indices])
+            # neighbors include centre atom itself
+            if len(neighbors_chain_ids) > 1:
+                chosen_centre_atom = centre_atom
+                break
+
+        # The distance between the central atoms in any two chains is greater than 15 angstroms.
+        if chosen_centre_atom is None:
+            return None, input_chains_num
+
+        dist_mat = cdist(centre_atom.coord.reshape((1, -1)), resolved_centre_atom.coord)
+        sorted_chain_id = np.array(
+            [
+                chain_id
+                for chain_id, _dist in sorted(
+                    zip(resolved_centre_atom.mol_id, dist_mat[0]),
+                    key=lambda pair: pair[1],
+                )
+            ]
+        )
+
+        if core_indices is not None:
+            # select core chains with priority
+            core_mol_id = np.unique(atom_array.mol_id[core_indices])
+            in_core_mask = np.isin(sorted_chain_id, core_mol_id)
+            sorted_chain_id = np.concatenate(
+                (sorted_chain_id[in_core_mask], sorted_chain_id[~in_core_mask])
+            )
+
+        closest_chain_id = set()
+        chain_ids_to_token_num = {}
+        if max_tokens_num is None:
+            max_tokens_num = 0
+
+        tokens = 0
+        for chain_id in sorted_chain_id:
+            # get token num
+            if chain_id not in chain_ids_to_token_num:
+                chain_ids_to_token_num[chain_id] = atom_array.centre_atom_mask[
+                    atom_array.mol_id == chain_id
+                ].sum()
+            chain_token_num = chain_ids_to_token_num[chain_id]
+
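+            # Once the chain budget is filled, keep appending chains only while the token budget still allows it.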
+            if len(closest_chain_id) >= max_chains_num:
+                if tokens + chain_token_num > max_tokens_num:
+                    break
+
+            closest_chain_id.add(chain_id)
+            tokens += chain_token_num
+
+        atom_array = atom_array[np.isin(atom_array.mol_id, list(closest_chain_id))]
+        output_chains_num = len(np.unique(atom_array.mol_id))
+        assert (
+            output_chains_num == max_chains_num
+            or atom_array.centre_atom_mask.sum() <= max_tokens_num
+        )
+        return atom_array, input_chains_num
+
+    @staticmethod
+    def remove_clashing_chains(
+        atom_array: AtomArray,
+        core_indices: list[int] = None,
+    ) -> AtomArray:
+        """
+        Ref: AlphaFold3 SI Chapter 2.5.4
+
+        Clashing chains are removed.
+        Clashing chains are defined as those with >30% of atoms within 1.7 Å of an atom in another chain.
+        If two chains are clashing with each other, the chain with the greater percentage of clashing atoms will be removed.
+        If the same fraction of atoms are clashing, the chain with fewer total atoms is removed.
+        If the chains have the same number of atoms, then the chain with the larger chain id is removed.
+
+        Note: if two chains are covalent, they are not considered as clashing.
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray Object of a Bioassembly.
+            core_indices (list[int]): A list of indices for core structures,
+                                      where these indices correspond to structures that will be preferentially
+                                      retained when pairwise clash chain assessments are performed.
+
+        Returns:
+            atom_array (AtomArray): An AtomArray that has been processed through this filter.
+            removed_chain_ids (list[str]): A list of chain IDs that have been determined for deletion.
+                                           This is logged to indicate whether the filter was applied.
+        """
+        chain_ids = np.unique(atom_array.chain_id[atom_array.is_resolved]).tolist()
+
+        if core_indices is not None:
+            core_chain_id = np.unique(atom_array.chain_id[core_indices])
+        else:
+            core_chain_id = np.array([])
+
+        clash_records, chain_resolved_atom_nums = Filter._get_clashing_chains(
+            atom_array, chain_ids
+        )
+        removed_chain_ids = Filter._get_removed_clash_chain_ids(
+            clash_records,
+            chain_ids,
+            chain_resolved_atom_nums,
+            core_chain_id=core_chain_id,
+        )
+
+        atom_array = atom_array[~np.isin(atom_array.chain_id, removed_chain_ids)]
+        return atom_array, removed_chain_ids
+
+    @staticmethod
+    def remove_unresolved_mols(atom_array: AtomArray) -> AtomArray:
+        """
+        Remove molecules from a bioassembly in which none of the atoms are resolved.
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray Object of a bioassembly.
+
+        Returns:
+            AtomArray: An AtomArray object with unresolved molecules removed.
+        """
+        valid_mol_id = []
+        for mol_id in np.unique(atom_array.mol_id):
+            resolved = atom_array.is_resolved[atom_array.mol_id == mol_id]
+            if np.any(resolved):
+                valid_mol_id.append(mol_id)
+
+        atom_array = atom_array[np.isin(atom_array.mol_id, valid_mol_id)]
+        return atom_array
+
+    @staticmethod
+    def remove_asymmetric_polymer_ligand_bonds(
+        atom_array: AtomArray, entity_poly_type: dict
+    ) -> AtomArray:
+        """remove asymmetric polymer ligand bonds (including protein-protein bond, like disulfide bond).
+
+        AF3 SI 5.1 Structure filters
+        Bonds for structures with homomeric subcomplexes lacking the corresponding homomeric symmetry are also removed
+        - e.g. if a certain bonded ligand only exists for some of the symmetric copies, but not for all,
+        we remove the corresponding bond information from the input.
+        In consequence the model has to learn to infer these bonds by itself.
+
+        Args:
+            atom_array (AtomArray): input atom array
+
+        Returns:
+            AtomArray: output atom array with asymmetric polymer ligand bonds removed.
+        """
+        # get inter chain bonds
+        inter_chain_bonds = set()
+        for i, j, b in atom_array.bonds.as_array():
+            if atom_array.chain_id[i] != atom_array.chain_id[j]:
+                inter_chain_bonds.add((i, j))
+
+        # get asymmetric polymer ligand bonds
+        asymmetric_bonds = set()
+        chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=False)
+        for bond in inter_chain_bonds:
+
+            if bond in asymmetric_bonds:
+                continue
+
+            i, j = bond
+            atom_i = atom_array[i]
+            atom_j = atom_array[j]
+            i_is_polymer = atom_i.label_entity_id in entity_poly_type
+            j_is_polymer = atom_j.label_entity_id in entity_poly_type
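+            # Normalize so that atom_i is the polymer-side atom; swap if only atom_j belongs to a polymer entity.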
+            if i_is_polymer:
+                pass
+            elif j_is_polymer:
+                i, j = j, i
+                atom_i, atom_j = atom_j, atom_i
+                i_is_polymer, j_is_polymer = j_is_polymer, i_is_polymer
+            else:
+                # neither entity is a polymer
+                continue
+
+            # get atom i mask from all entity i copies
+            entity_mask_i = atom_array.label_entity_id == atom_i.label_entity_id
+            num_copies = np.isin(chain_starts, np.flatnonzero(entity_mask_i)).sum()
+            mask_i = (
+                entity_mask_i
+                & (atom_array.res_id == atom_i.res_id)
+                & (atom_array.atom_name == atom_i.atom_name)
+            )
+            indices_i = np.flatnonzero(mask_i)
+
+            if len(indices_i) != num_copies:
+                # not every copy of entity i has atom i.
+                asymmetric_bonds.add(bond)
+                continue
+
+            # check all atom i in entity i bond to an atom j in entity j.
+            target_bonds = []
+            for ii in indices_i:
+                ii_bonds = [b for b in inter_chain_bonds if ii in b]
+                for bond in ii_bonds:
+                    jj = bond[1] if ii == bond[0] else bond[0]
+                    atom_jj = atom_array[jj]
+                    if atom_jj.label_entity_id != atom_j.label_entity_id:
+                        continue
+                    if atom_jj.res_name != atom_j.res_name:
+                        continue
+                    if atom_jj.atom_name != atom_j.atom_name:
+                        continue
+                    if j_is_polymer and atom_jj.res_id != atom_j.res_id:
+                        # only for polymer, check res_id
+                        continue
+                    # found bond (ii, jj) with the same entity_id, res_name, atom_name as bond (i, j)
+                    target_bonds.append((min(ii, jj), max(ii, jj)))
+                    break
+            if len(target_bonds) != num_copies:
+                asymmetric_bonds |= set(target_bonds)
+
+        for bond in asymmetric_bonds:
+            atom_array.bonds.remove_bond(bond[0], bond[1])
+        return atom_array
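The new structure-level filters are independent static methods. A minimal sketch of how they might be chained is shown below; the call order is an assumption (the actual wiring inside the parser is not part of this hunk), and the AtomArray is assumed to already carry the `mol_id`, `centre_atom_mask`, and `is_resolved` annotations these filters rely on.

```python
from protenix.data.filter import Filter

def apply_structure_filters(atom_array, entity_poly_type):
    # Drop polymer chains that carry no usable signal.
    atom_array = Filter.remove_polymer_chains_all_residues_unknown(atom_array, entity_poly_type)
    atom_array = Filter.remove_polymer_chains_too_short(atom_array, entity_poly_type)
    atom_array = Filter.remove_polymer_chains_with_consecutive_c_alpha_too_far_away(
        atom_array, entity_poly_type
    )
    # Resolve clashes and oversized assemblies.
    atom_array, removed_chain_ids = Filter.remove_clashing_chains(atom_array)
    atom_array, input_chains_num = Filter.too_many_chains_filter(atom_array)
    # Drop fully unresolved molecules and asymmetric polymer-ligand bonds.
    atom_array = Filter.remove_unresolved_mols(atom_array)
    atom_array = Filter.remove_asymmetric_polymer_ligand_bonds(atom_array, entity_poly_type)
    return atom_array
```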
diff --git a/protenix/protenix/data/json_maker.py b/protenix/protenix/data/json_maker.py
index 99855fe39663abbfcf22a4347413285d095fd369..39b1213320e61ad7ecdce1889fe2c7bd21e8bc11 100644
--- a/protenix/protenix/data/json_maker.py
+++ b/protenix/protenix/data/json_maker.py
@@ -46,12 +46,12 @@ def merge_covalent_bonds(
     for bond_dict in covalent_bonds:
         bond_unique_string = []
         entity_counts = (
-            all_entity_counts[str(bond_dict["left_entity"])],
-            all_entity_counts[str(bond_dict["right_entity"])],
+            all_entity_counts[str(bond_dict["entity1"])],
+            all_entity_counts[str(bond_dict["entity2"])],
         )
-        for i in ["left", "right"]:
+        for i in range(2):
             for j in ["entity", "position", "atom"]:
-                k = f"{i}_{j}"
+                k = f"{j}{i+1}"
                 bond_unique_string.append(str(bond_dict[k]))
         bond_unique_string = "_".join(bond_unique_string)
         bonds_recorder[bond_unique_string].append(bond_dict)
@@ -59,12 +59,12 @@ def merge_covalent_bonds(
 
     merged_covalent_bonds = []
     for k, v in bonds_recorder.items():
-        left_counts = bonds_entity_counts[k][0]
-        right_counts = bonds_entity_counts[k][1]
-        if left_counts == right_counts == len(v):
+        counts1 = bonds_entity_counts[k][0]
+        counts2 = bonds_entity_counts[k][1]
+        if counts1 == counts2 == len(v):
             bond_dict_copy = copy.deepcopy(v[0])
-            del bond_dict_copy["left_copy"]
-            del bond_dict_copy["right_copy"]
+            del bond_dict_copy["copy1"]
+            del bond_dict_copy["copy2"]
             merged_covalent_bonds.append(bond_dict_copy)
         else:
             merged_covalent_bonds.extend(v)
@@ -217,15 +217,15 @@ def atom_array_to_input_json(
         covalent_bonds = []
         for atoms in inter_entity_bonds[:, :2]:
             bond_dict = {}
-            for idx, i in enumerate(["left", "right"]):
-                atom = atom_array[atoms[idx]]
+            for i in range(2):
+                atom = atom_array[atoms[i]]
                 positon = atom.res_id
-                bond_dict[f"{i}_entity"] = int(
+                bond_dict[f"entity{i+1}"] = int(
                     label_entity_id_to_entity_id_in_json[atom.label_entity_id]
                 )
-                bond_dict[f"{i}_position"] = int(positon)
-                bond_dict[f"{i}_atom"] = atom.atom_name
-                bond_dict[f"{i}_copy"] = int(atom.copy_id)
+                bond_dict[f"position{i+1}"] = int(positon)
+                bond_dict[f"atom{i+1}"] = atom.atom_name
+                bond_dict[f"copy{i+1}"] = int(atom.copy_id)
 
             covalent_bonds.append(bond_dict)
 
@@ -299,8 +299,15 @@ def cif_to_input_json(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--cif_file", type=str, required=True, help="The cif file to parse")
-    parser.add_argument("--json_file", type=str, required=False, default=None, help="The json file path to generate")
+    parser.add_argument(
+        "--cif_file", type=str, required=True, help="The cif file to parse"
+    )
+    parser.add_argument(
+        "--json_file",
+        type=str,
+        required=False,
+        default=None,
+        help="The json file path to generate",
+    )
     args = parser.parse_args()
     print(cif_to_input_json(args.cif_file, output_json=args.json_file))
-
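With the key renaming above, a covalent-bond entry in the generated input JSON uses positional suffixes instead of left_/right_ prefixes. A hypothetical merged entry (values invented for illustration; `copy1`/`copy2` are kept only when the bonds cannot be merged across all entity copies):

```python
covalent_bond = {
    "entity1": 1, "position1": 10, "atom1": "SG",
    "entity2": 2, "position2": 1, "atom2": "C1",
}
```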
diff --git a/protenix/protenix/data/json_parser.py b/protenix/protenix/data/json_parser.py
index a40bf42f5764844d4fa0b3a620ccd2cc09dee638..20fd3163943f31c47472d479c69ab23875aaec1b 100644
--- a/protenix/protenix/data/json_parser.py
+++ b/protenix/protenix/data/json_parser.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import concurrent.futures
 import copy
 import logging
 import random
@@ -500,7 +501,22 @@ def smiles_to_atom_info(smiles: str) -> dict:
     atom_info = {}
     mol = Chem.MolFromSmiles(smiles)
     mol = Chem.AddHs(mol)
-    ret_code = AllChem.EmbedMolecule(mol)
+
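+    # EmbedMolecule returns the new conformer ID (0 here on success) or -1 on failure; running it in a worker lets slow SMILES inputs time out.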
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future = executor.submit(AllChem.EmbedMolecule, mol)
+
+        try:
+            ret_code = future.result(timeout=90)
+        except concurrent.futures.TimeoutError as exc:
+            raise TimeoutError(
+                "Conformer generation timed out. "
+                'Please change the "ligand" input format to "CCD_" or "FILE_".'
+            ) from exc
+
+    if ret_code != 0:
+        # retry with random coords
+        ret_code = AllChem.EmbedMolecule(mol, useRandomCoords=True)
+
     assert ret_code == 0, f"Conformer generation failed for input SMILES: {smiles}"
     atom_info = rdkit_mol_to_atom_info(mol)
     return atom_info
diff --git a/protenix/protenix/data/json_to_feature.py b/protenix/protenix/data/json_to_feature.py
index 92cb2814860f6406fdb21768266d944a45127a06..e1f46a1fe8dc16d438e12d2df3d956dc1d19b2bd 100644
--- a/protenix/protenix/data/json_to_feature.py
+++ b/protenix/protenix/data/json_to_feature.py
@@ -167,11 +167,26 @@ class SampleDictToFeatures:
         bond_count = {}
         for bond_info_dict in self.input_dict["covalent_bonds"]:
             bond_atoms = []
-            for i in ["left", "right"]:
-                entity_id = int(bond_info_dict[f"{i}_entity"])
-                copy_id = int(bond_info_dict.get(f"{i}_copy"))
-                position = int(bond_info_dict[f"{i}_position"])
-                atom_name = bond_info_dict[f"{i}_atom"]
+            for idx, i in enumerate(["left", "right"]):
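+                # Accept both the legacy "left_*"/"right_*" keys and the new "*1"/"*2" key names.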
+                entity_id = int(
+                    bond_info_dict.get(
+                        f"{i}_entity", bond_info_dict.get(f"entity{idx+1}")
+                    )
+                )
+                copy_id = bond_info_dict.get(
+                    f"{i}_copy", bond_info_dict.get(f"copy{idx+1}")
+                )
+                position = int(
+                    bond_info_dict.get(
+                        f"{i}_position", bond_info_dict.get(f"position{idx+1}")
+                    )
+                )
+                atom_name = bond_info_dict.get(
+                    f"{i}_atom", bond_info_dict.get(f"atom{idx+1}")
+                )
+
+                if copy_id is not None:
+                    copy_id = int(copy_id)
 
                 if isinstance(atom_name, str):
                     if atom_name.isdigit():
@@ -194,11 +209,11 @@ class SampleDictToFeatures:
                     atom_indices.size > 0
                 ), f"No atom found for {atom_name} in entity {entity_id} at position {position}."
                 bond_atoms.append(atom_indices)
-
-            assert len(bond_atoms[0]) == len(
-                bond_atoms[1]
-            ), f'Can not create bonds because the "count" of entity {bond_info_dict["left_entity"]} \
-                and {bond_info_dict["right_entity"]} are not equal. '
+            assert len(bond_atoms[0]) == len(bond_atoms[1]), (
+                'Can not create bonds because the "count" of entity1 '
+                f'({bond_info_dict.get("left_entity", bond_info_dict.get("entity1"))}) '
+                f'and entity2 ({bond_info_dict.get("right_entity", bond_info_dict.get("entity2"))}) are not equal. '
+            )
 
             # Create bond between each asym chain pair
             for atom_idx1, atom_idx2 in zip(bond_atoms[0], bond_atoms[1]):
diff --git a/protenix/protenix/data/msa_featurizer.py b/protenix/protenix/data/msa_featurizer.py
index c2df9d3b4b4627ea5104c129c6fcd3a59eacd9b4..284d8758b43b9abc7c32cff80bd61a6544fab118 100644
--- a/protenix/protenix/data/msa_featurizer.py
+++ b/protenix/protenix/data/msa_featurizer.py
@@ -915,7 +915,6 @@ def merge_all_chain_features(
     )
     if msa_entity_type == "rna":
         np_example = rna_merge(
-            is_homomer_or_monomer=is_homomer_or_monomer,
             all_chain_features=all_chain_features,
             merge_method=merge_method,
             msa_crop_size=max_size,
diff --git a/protenix/protenix/data/parser.py b/protenix/protenix/data/parser.py
index a64047d450c32e08c164b8b70d9a5cf4fd028a23..682b89222efa027c6890c83a926252c954dd8f9c 100644
--- a/protenix/protenix/data/parser.py
+++ b/protenix/protenix/data/parser.py
@@ -16,9 +16,11 @@ import copy
 import functools
 import gzip
 import logging
-from collections import Counter
+import random
+from collections import Counter, defaultdict
+from datetime import datetime
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import biotite.structure as struc
 import biotite.structure.io.pdbx as pdbx
@@ -31,13 +33,24 @@ from biotite.structure.molecules import get_molecule_indices
 from protenix.data import ccd
 from protenix.data.ccd import get_ccd_ref_info
 from protenix.data.constants import (
+    CRYSTALLIZATION_METHODS,
     DNA_STD_RESIDUES,
+    GLYCANS,
+    LIGAND_EXCLUSION,
+    PRO_STD_RESIDUES,
     PROT_STD_RESIDUES_ONE_TO_THREE,
     RES_ATOMS_DICT,
     RNA_STD_RESIDUES,
     STD_RESIDUES,
 )
-from protenix.data.utils import get_starts_by
+from protenix.data.filter import Filter
+from protenix.data.utils import (
+    atom_select,
+    get_inter_residue_bonds,
+    get_ligand_polymer_bond_mask,
+    get_starts_by,
+    parse_pdb_cluster_file_to_dict,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +60,11 @@ if "metalc" in pdbx_convert.PDBX_COVALENT_TYPES:  # for reload
 
 
 class MMCIFParser:
-    def __init__(self, mmcif_file: Union[str, Path]) -> None:
+    """
+    Parsing and extracting information from mmCIF files.
+    """
+
+    def __init__(self, mmcif_file: Union[str, Path]):
         self.cif = self._parse(mmcif_file=mmcif_file)
 
     def _parse(self, mmcif_file: Union[str, Path]) -> pdbx.CIFFile:
@@ -61,12 +78,120 @@ class MMCIFParser:
         return cif_file
 
     def get_category_table(self, name: str) -> Union[pd.DataFrame, None]:
+        """
+        Retrieve a category table from the CIF block and return it as a pandas DataFrame.
+
+        Args:
+            name (str): The name of the category to retrieve from the CIF block.
+
+        Returns:
+            Union[pd.DataFrame, None]: A pandas DataFrame containing the category data if the category exists,
+                                       otherwise None.
+        """
         if name not in self.cif.block:
             return None
         category = self.cif.block[name]
         category_dict = {k: column.as_array() for k, column in category.items()}
         return pd.DataFrame(category_dict, dtype=str)
 
+    @functools.cached_property
+    def pdb_id(self) -> str:
+        """
+        Extracts and returns the PDB ID from the CIF block.
+
+        Returns:
+            str: The PDB ID in lowercase if present, otherwise an empty string.
+        """
+
+        if "entry" not in self.cif.block:
+            return ""
+        else:
+            return self.cif.block["entry"]["id"].as_item().lower()
+
+    def num_assembly_polymer_chains(self, assembly_id: str = "1") -> int:
+        """
+        Calculate the number of polymer chains in a specified assembly.
+
+        Args:
+            assembly_id (str): The ID of the assembly to count polymer chains for.
+                               Defaults to "1". If "all", counts chains for all assemblies.
+
+        Returns:
+            int: The total number of polymer chains in the specified assembly.
+                 If the oligomeric count is invalid (e.g., '?'), the function returns None.
+        """
+        chain_count = 0
+        for _assembly_id, _chain_count in zip(
+            self.cif.block["pdbx_struct_assembly"]["id"].as_array(),
+            self.cif.block["pdbx_struct_assembly"]["oligomeric_count"].as_array(),
+        ):
+            if assembly_id == "all" or _assembly_id == assembly_id:
+                try:
+                    chain_count += int(_chain_count)
+                except ValueError:
+                    # oligomeric_count == '?'.  e.g. 1hya.cif
+                    return
+        return chain_count
+
+    @functools.cached_property
+    def resolution(self) -> float:
+        """
+        Get resolution for X-ray and cryoEM.
+        Some experimental methods have no resolution; in that case -1.0 is returned.
+
+        Returns:
+            float: resolution (set to -1.0 if not found)
+        """
+        block = self.cif.block
+        resolution_names = [
+            "refine.ls_d_res_high",
+            "em_3d_reconstruction.resolution",
+            "reflns.d_resolution_high",
+        ]
+        for category_item in resolution_names:
+            category, item = category_item.split(".")
+            if category in block and item in block[category]:
+                try:
+                    resolution = block[category][item].as_array(float)[0]
+                    # "." will be converted to 0.0, but it is not a valid resolution.
+                    if resolution == 0.0:
+                        continue
+                    return resolution
+                except ValueError:
+                    # in some cases, resolution_str is "?"
+                    continue
+        return -1.0
+
+    @functools.cached_property
+    def release_date(self) -> str:
+        """
+        Get first release date.
+
+        Returns:
+            str: yyyy-mm-dd
+        """
+
+        def _is_valid_date_format(date_string):
+            try:
+                datetime.strptime(date_string, "%Y-%m-%d")
+                return True
+            except ValueError:
+                return False
+
+        if "pdbx_audit_revision_history" in self.cif.block:
+            history = self.cif.block["pdbx_audit_revision_history"]
+            # np.str_ inherits from str, so the returned value is a str
+            date = history["revision_date"].as_array()[0]
+        else:
+            # no release date
+            date = "9999-12-31"
+
+        valid_date = _is_valid_date_format(date)
+        assert (
+            valid_date
+        ), f"Invalid date format: {date}, it should be yyyy-mm-dd format"
+        return date
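+
+    # Quick metadata-access sketch (illustrative only; "1abc.cif" is a hypothetical local file):
+    #   parser = MMCIFParser("1abc.cif")
+    #   parser.pdb_id                              # e.g. "1abc"
+    #   parser.resolution                          # float, -1.0 if not reported
+    #   parser.release_date                        # "yyyy-mm-dd"
+    #   parser.num_assembly_polymer_chains("1")    # None if oligomeric_count is invalid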
+
     @staticmethod
     def mse_to_met(atom_array: AtomArray) -> AtomArray:
         """
@@ -87,6 +212,42 @@ class MMCIFParser:
         atom_array.hetero[mse] = False
         return atom_array
 
+    @staticmethod
+    def fix_arginine(atom_array: AtomArray) -> AtomArray:
+        """
+        Ref: AlphaFold3 SI chapter 2.1
+        Arginine naming ambiguities are fixed (ensuring NH1 is always closer to CD than NH2).
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray object.
+
+        Returns:
+            AtomArray: Biotite AtomArray object after fixing arginine atom naming.
+        """
+
+        starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True)
+        for start_i, stop_i in zip(starts[:-1], starts[1:]):
+            if atom_array[start_i].res_name != "ARG":
+                continue
+            cd_idx, nh1_idx, nh2_idx = None, None, None
+            for idx in range(start_i, stop_i):
+                if atom_array.atom_name[idx] == "CD":
+                    cd_idx = idx
+                if atom_array.atom_name[idx] == "NH1":
+                    nh1_idx = idx
+                if atom_array.atom_name[idx] == "NH2":
+                    nh2_idx = idx
+            if (cd_idx is not None) and (nh1_idx is not None) and (nh2_idx is not None):
+                cd_nh1 = atom_array.coord[nh1_idx] - atom_array.coord[cd_idx]
+                d2_cd_nh1 = np.sum(cd_nh1**2)
+                cd_nh2 = atom_array.coord[nh2_idx] - atom_array.coord[cd_idx]
+                d2_cd_nh2 = np.sum(cd_nh2**2)
+                if d2_cd_nh2 < d2_cd_nh1:
+                    atom_array.coord[[nh1_idx, nh2_idx]] = atom_array.coord[
+                        [nh2_idx, nh1_idx]
+                    ]
+        return atom_array
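+
+    # Usage sketch for fix_arginine (illustrative; "1abc.cif" is a hypothetical file):
+    #   parser = MMCIFParser("1abc.cif")
+    #   arr = parser.get_structure()
+    #   arr = MMCIFParser.fix_arginine(arr)  # NH1 is now guaranteed to be closer to CD than NH2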
+
     @functools.cached_property
     def methods(self) -> list[str]:
         """the methods to get the structure
@@ -146,7 +307,9 @@ class MMCIFParser:
 
             seq_nums = entity_poly_seq.num[chain_mask].to_numpy(dtype=int)
 
-            if np.unique(seq_nums).size == seq_nums.size:
+            uniq_seq_num = np.unique(seq_nums).size
+
+            if uniq_seq_num == seq_nums.size:
                 # no altloc residues
                 poly_res_names[entity_id] = seq_mon_ids
                 continue
@@ -171,10 +334,12 @@ class MMCIFParser:
                 if select_mask[i]:
                     matching_res_id += 1
 
-            seq_mon_ids = seq_mon_ids[select_mask]
-            seq_nums = seq_nums[select_mask]
-            assert len(seq_nums) == max(seq_nums)
-            poly_res_names[entity_id] = seq_mon_ids
+            new_seq_mon_ids = seq_mon_ids[select_mask]
+            new_seq_nums = seq_nums[select_mask]
+            assert (
+                len(new_seq_nums) == uniq_seq_num
+            ), f"seq_nums not match:\n{seq_nums=}\n{new_seq_nums=}\n{seq_mon_ids=}\n{new_seq_mon_ids=}"
+            poly_res_names[entity_id] = new_seq_mon_ids
         return poly_res_names
 
     def get_sequences(self, atom_array=None) -> dict:
@@ -217,34 +382,99 @@ class MMCIFParser:
 
         return {i: t for i, t in zip(entity_poly.entity_id, entity_poly.type)}
 
-    def filter_altloc(self, atom_array: AtomArray, altloc: str = "first"):
+    def filter_altloc(self, atom_array: AtomArray, altloc: str = "first") -> AtomArray:
         """
-        altloc: "first", "A", "B", "global_largest", etc
+        Filter alternate conformations (altloc) of a given AtomArray based on the specified criteria.
+        For example, in 2PXS, there are two res_name (XYG|DYG) at res_id 63.
 
-        Filter first alternate coformation (altloc) of a given AtomArray.
-        - normally first altloc_id is 'A'
-        - but in one case, first altloc_id is '1' in 6uwi.cif
+        Args:
+            atom_array : AtomArray
+                The array of atoms to filter.
+            altloc : str, optional
+                The criteria for filtering alternate conformations. Possible values are:
+                - "first": Keep the first alternate conformation.
+                - "all": Keep all alternate conformations.
+                - "A", "B", etc.: Keep the specified alternate conformation.
+                - "global_largest": Keep the alternate conformation with the largest average occupancy.
 
-        biotite v0.41 can not handle diff res_name at same res_id.
-        For example, in 2pxs.cif, there are two res_name (XYG|DYG) at res_id 63,
-        need to keep the first XYG.
+        Returns:
+            AtomArray
+                The filtered AtomArray based on the specified altloc criteria.
         """
         if altloc == "all":
             return atom_array
 
-        altloc_id = altloc
-        if altloc == "first":
+        elif altloc == "first":
             letter_altloc_ids = np.unique(atom_array.label_alt_id)
             if len(letter_altloc_ids) == 1 and letter_altloc_ids[0] == ".":
                 return atom_array
             letter_altloc_ids = letter_altloc_ids[letter_altloc_ids != "."]
             altloc_id = np.sort(letter_altloc_ids)[0]
+            return atom_array[np.isin(atom_array.label_alt_id, [altloc_id, "."])]
+
+        elif altloc == "global_largest":
+            occ_dict = defaultdict(list)
+            res_altloc = defaultdict(list)
+
+            res_starts = get_residue_starts(atom_array, add_exclusive_stop=True)
+            for res_start, _res_end in zip(res_starts[:-1], res_starts[1:]):
+                altloc_char = atom_array.label_alt_id[res_start]
+                if altloc_char == ".":
+                    continue
+
+                occupancy = atom_array.occupancy[res_start]
+                occ_dict[altloc_char].append(occupancy)
+
+                chain_id = atom_array.chain_id[res_start]
+                res_id = atom_array.res_id[res_start]
+                res_altloc[(chain_id, res_id)].append(altloc_char)
+
+            alt_and_avg_occ = [
+                (altloc_char, np.mean(occ_list))
+                for altloc_char, occ_list in occ_dict.items()
+            ]
+            sorted_altloc_chars = [
+                i[0] for i in sorted(alt_and_avg_occ, key=lambda x: x[1], reverse=True)
+            ]
 
-        return atom_array[np.isin(atom_array.label_alt_id, [altloc_id, "."])]
+            selected_mask = np.zeros(len(atom_array), dtype=bool)
+            for res_start, res_end in zip(res_starts[:-1], res_starts[1:]):
+                chain_id = atom_array.chain_id[res_start]
+                res_id = atom_array.res_id[res_start]
+                altloc_char = atom_array.label_alt_id[res_start]
+
+                if altloc_char == ".":
+                    selected_mask[res_start:res_end] = True
+                else:
+                    res_sorted_altloc = [
+                        i
+                        for i in sorted_altloc_chars
+                        if i in res_altloc[(chain_id, res_id)]
+                    ]
+                    selected_altloc = res_sorted_altloc[0]
+                    if altloc_char == selected_altloc:
+                        selected_mask[res_start:res_end] = True
+            return atom_array[selected_mask]
+
+        else:
+            return atom_array[np.isin(atom_array.label_alt_id, [altloc, "."])]
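+
+    # Usage sketch for filter_altloc (illustrative; "1abc.cif" is a hypothetical file):
+    #   parser = MMCIFParser("1abc.cif")
+    #   arr = parser.get_structure()
+    #   arr_first = parser.filter_altloc(arr, altloc="first")
+    #   arr_occ = parser.filter_altloc(arr, altloc="global_largest")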
 
     @staticmethod
     def replace_auth_with_label(atom_array: AtomArray) -> AtomArray:
-        # fix issue https://github.com/biotite-dev/biotite/issues/553
+        """
+        Replace the author-provided chain ID with the label asym ID in the given AtomArray.
+
+        This function addresses the issue described in https://github.com/biotite-dev/biotite/issues/553.
+        It updates the `chain_id` of the `atom_array` to match the `label_asym_id` and resets the ligand
+        residue IDs (`res_id`) for chains where the `label_seq_id` is ".". The residue IDs are reset
+        sequentially starting from 1 within each chain.
+
+        Args:
+            atom_array (AtomArray): The input AtomArray object to be modified.
+
+        Returns:
+            AtomArray: The modified AtomArray with updated chain IDs and residue IDs.
+        """
         atom_array.chain_id = atom_array.label_asym_id
 
         # reset ligand res_id
@@ -444,6 +674,1093 @@ class MMCIFParser:
             assembly.set_annotation("assembly_1", np.array(assembly_1_mask))
         return assembly
 
+    def _get_core_indices(self, atom_array):
+        if "assembly_1" in atom_array._annot:
+            core_indices = np.where(atom_array.assembly_1)[0]
+        else:
+            core_indices = None
+        return core_indices
+
+    def get_bioassembly(
+        self,
+        assembly_id: str = "1",
+        max_assembly_chains: int = 1000,
+    ) -> dict[str, Any]:
+        """
+        Build the given biological assembly.
+
+        Args:
+            assembly_id (str, optional): Assembly ID. Defaults to "1".
+            max_assembly_chains (int, optional): Max allowed chains in the assembly. Defaults to 1000.
+
+        Returns:
+            dict[str, Any]: A dictionary containing basic Bioassembly information, including:
+                - "pdb_id": The PDB ID.
+                - "sequences": The sequences associated with the assembly.
+                - "release_date": The release date of the structure.
+                - "assembly_id": The assembly ID.
+                - "num_assembly_polymer_chains": The number of polymer chains in the assembly.
+                - "num_prot_chains": The number of protein chains in the assembly.
+                - "entity_poly_type": The type of polymer entities.
+                - "resolution": The resolution of the structure. Set to -1.0 if resolution not found.
+                - "atom_array": The AtomArray object representing the structure.
+                - "num_tokens": The number of tokens in the AtomArray.
+        """
+        num_assembly_polymer_chains = self.num_assembly_polymer_chains(assembly_id)
+        bioassembly_dict = {
+            "pdb_id": self.pdb_id,
+            "sequences": self.get_sequences(),  # label_entity_id --> canonical_sequence
+            "release_date": self.release_date,
+            "assembly_id": assembly_id,
+            "num_assembly_polymer_chains": num_assembly_polymer_chains,
+            "num_prot_chains": -1,
+            "entity_poly_type": self.entity_poly_type,
+            "resolution": self.resolution,
+            "atom_array": None,
+        }
+        if (not num_assembly_polymer_chains) or (
+            num_assembly_polymer_chains > max_assembly_chains
+        ):
+            return bioassembly_dict
+
+        # create AtomArray of the first model from mmCIF atom_site (asymmetric unit)
+        atom_array = self.get_structure()
+
+        # convert MSE to MET to be consistent with MMCIFParser.get_poly_res_names()
+        atom_array = self.mse_to_met(atom_array)
+
+        # update sequences: keep same altloc residue with atom_array
+        bioassembly_dict["sequences"] = self.get_sequences(atom_array)
+
+        pipeline_functions = [
+            Filter.remove_water,
+            Filter.remove_hydrogens,
+            lambda aa: Filter.remove_polymer_chains_all_residues_unknown(
+                aa, self.entity_poly_type
+            ),
+            # Note: Filter.remove_polymer_chains_too_short not being used
+            lambda aa: Filter.remove_polymer_chains_with_consecutive_c_alpha_too_far_away(
+                aa, self.entity_poly_type
+            ),
+            self.fix_arginine,
+            self.add_missing_atoms_and_residues,  # and add annotation is_resolved (False for missing atoms)
+            Filter.remove_element_X,  # remove X element (including ASX->ASP, GLX->GLU) after add_missing_atoms_and_residues()
+        ]
+
+        if set(self.methods) & CRYSTALLIZATION_METHODS:
+            # AF3 SI 2.5.4 Crystallization aids are removed if the mmCIF method information indicates that crystallography was used.
+            pipeline_functions.append(
+                lambda aa: Filter.remove_crystallization_aids(aa, self.entity_poly_type)
+            )
+
+        for func in pipeline_functions:
+            atom_array = func(atom_array)
+            if len(atom_array) == 0:
+                # no atoms left
+                return bioassembly_dict
+
+        atom_array = AddAtomArrayAnnot.add_token_mol_type(
+            atom_array, self.entity_poly_type
+        )
+        atom_array = AddAtomArrayAnnot.add_centre_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_atom_mol_type_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_distogram_rep_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_plddt_m_rep_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_cano_seq_resname(atom_array)
+        atom_array = AddAtomArrayAnnot.add_tokatom_idx(atom_array)
+        atom_array = AddAtomArrayAnnot.add_modified_res_mask(atom_array)
+        assert (
+            atom_array.centre_atom_mask.sum()
+            == atom_array.distogram_rep_atom_mask.sum()
+        )
+
+        # expand created AtomArray by expand bioassembly
+        atom_array = self.expand_assembly(atom_array, assembly_id)
+
+        if len(atom_array) == 0:
+            # If no chains corresponding to the assembly_id remain in the AtomArray
+            # expand_assembly will return an empty AtomArray.
+            return bioassembly_dict
+
+        # reset the coords after expand assembly
+        atom_array.coord[~atom_array.is_resolved, :] = 0.0
+
+        # rename chain_ids from A A B to A0 A1 B0 and add asym_id_int, entity_id_int, sym_id_int
+        atom_array = AddAtomArrayAnnot.unique_chain_and_add_ids(atom_array)
+
+        # get chain id before remove chains
+        core_indices = self._get_core_indices(atom_array)
+        if core_indices is not None:
+            ori_chain_ids = np.unique(atom_array.chain_id[core_indices])
+        else:
+            ori_chain_ids = np.unique(atom_array.chain_id)
+
+        atom_array = AddAtomArrayAnnot.add_mol_id(atom_array)
+        atom_array = Filter.remove_unresolved_mols(atom_array)
+
+        # update core indices after remove unresolved mols
+        core_indices = np.where(np.isin(atom_array.chain_id, ori_chain_ids))[0]
+
+        # If the number of chains has already reached `max_chains_num`, but the token count hasn't reached `max_tokens_num`,
+        # chains will continue to be added until `max_tokens_num` is exceeded.
+        atom_array, _input_chains_num = Filter.too_many_chains_filter(
+            atom_array,
+            core_indices=core_indices,
+            max_chains_num=20,
+            max_tokens_num=5120,
+        )
+
+        if atom_array is None:
+            # The distance between the central atoms in any two chains is greater than 15 angstroms.
+            return bioassembly_dict
+
+        # update core indices after too_many_chains_filter
+        core_indices = np.where(np.isin(atom_array.chain_id, ori_chain_ids))[0]
+
+        atom_array, _removed_chain_ids = Filter.remove_clashing_chains(
+            atom_array, core_indices=core_indices
+        )
+
+        # remove asymmetric polymer ligand bonds (including protein-protein bond, like disulfide bond)
+        # apply to assembly atom array
+        atom_array = Filter.remove_asymmetric_polymer_ligand_bonds(
+            atom_array, self.entity_poly_type
+        )
+
+        # add_mol_id before applying the two filters below to ensure that covalent components are not removed as individual chains.
+        atom_array = AddAtomArrayAnnot.find_equiv_mol_and_assign_ids(
+            atom_array, self.entity_poly_type
+        )
+
+        # numerical encoding of (chain id, residue index)
+        atom_array = AddAtomArrayAnnot.add_ref_space_uid(atom_array)
+        atom_array = AddAtomArrayAnnot.add_ref_info_and_res_perm(atom_array)
+
+        # the number of protein chains in the assembly
+        prot_label_entity_ids = [
+            k for k, v in self.entity_poly_type.items() if "polypeptide" in v
+        ]
+        num_prot_chains = len(
+            np.unique(
+                atom_array.chain_id[
+                    np.isin(atom_array.label_entity_id, prot_label_entity_ids)
+                ]
+            )
+        )
+        bioassembly_dict["num_prot_chains"] = num_prot_chains
+
+        bioassembly_dict["atom_array"] = atom_array
+        bioassembly_dict["num_tokens"] = atom_array.centre_atom_mask.sum()
+        return bioassembly_dict
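+
+    # Minimal sketch of calling get_bioassembly (illustrative; "1abc.cif" is a hypothetical path,
+    # and the returned dict has atom_array=None when the assembly is skipped by the filters above):
+    #   parser = MMCIFParser("1abc.cif")
+    #   bio = parser.get_bioassembly(assembly_id="1", max_assembly_chains=1000)
+    #   if bio["atom_array"] is not None:
+    #       print(bio["pdb_id"], bio["num_tokens"], bio["num_prot_chains"])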
+
+    @staticmethod
+    def create_empty_annotation_like(
+        source_array: AtomArray, target_array: AtomArray
+    ) -> AtomArray:
+        """create empty annotation like source_array"""
+        # create empty annotation, atom array addition only keep common annotation
+        for k, v in source_array._annot.items():
+            if k not in target_array._annot:
+                target_array._annot[k] = np.zeros(len(target_array), dtype=v.dtype)
+        return target_array
+
+    @staticmethod
+    def find_non_ccd_leaving_atoms(
+        atom_array: AtomArray,
+        select_dict: dict[str, Any],
+        component: AtomArray,
+    ) -> list[str]:
+        """ "
+        handle mismatch bettween CCD and mmcif
+        some residue has bond in non-central atom (without leaving atoms in CCD)
+        and its neighbors should be removed like atom_array from mmcif.
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray object from mmcif.
+            select_dict (dict[str, Any]): entity_id, res_id, atom_name, ... of the central atom in atom_array.
+            component (AtomArray): CCD component AtomArray object.
+
+        Returns:
+            list[str]: list of atom_name to be removed.
+        """
+        # find non-CCD central atoms in atom_array
+        indices_in_atom_array = atom_select(atom_array, select_dict)
+
+        if len(indices_in_atom_array) == 0:
+            return []
+
+        if component.bonds is None:
+            return []
+
+        # atom_name not in CCD component, return []
+        atom_name = select_dict["atom_name"]
+        idx_in_comp = np.where(component.atom_name == atom_name)[0]
+        if len(idx_in_comp) == 0:
+            return []
+        idx_in_comp = idx_in_comp[0]
+
+        # find non-CCD leaving atoms in atom_array
+        remove_atom_names = []
+        for idx in indices_in_atom_array:
+            neighbor_idx, types = atom_array.bonds.get_bonds(idx)
+            ref_neighbor_idx, types = component.bonds.get_bonds(idx_in_comp)
+            # neighbor_atom only bond to central atom in CCD component
+            ref_neighbor_idx = [
+                i for i in ref_neighbor_idx if len(component.bonds.get_bonds(i)[0]) == 1
+            ]
+            removed_mask = ~np.isin(
+                component.atom_name[ref_neighbor_idx],
+                atom_array.atom_name[neighbor_idx],
+            )
+            remove_atom_names.append(
+                component.atom_name[ref_neighbor_idx][removed_mask].tolist()
+            )
+        max_id = np.argmax([len(names) for names in remove_atom_names])
+        return remove_atom_names[max_id]
+
+    def build_ref_chain_with_atom_array(self, atom_array: AtomArray) -> AtomArray:
+        """
+        build ref chain with atom_array and poly_res_names
+        """
+        # count inter residue bonds of each potential central atom for removing leaving atoms later
+        central_bond_count = Counter()  # (entity_id,res_id,atom_name) -> bond_count
+
+        # build reference entity atom array, including missing residues
+        poly_res_names = self.get_poly_res_names(atom_array)
+        entity_atom_array = {}
+        for entity_id, poly_type in self.entity_poly_type.items():
+            chain = struc.AtomArray(0)
+            for res_id, res_name in enumerate(poly_res_names[entity_id]):
+                # keep all leaving atoms, will remove leaving atoms later in this function
+                residue = ccd.get_component_atom_array(
+                    res_name, keep_leaving_atoms=True, keep_hydrogens=False
+                )
+                residue.res_id[:] = res_id + 1
+                chain += residue
+            res_starts = struc.get_residue_starts(chain, add_exclusive_stop=True)
+            inter_bonds = ccd._connect_inter_residue(chain, res_starts)
+
+            # filter out non-std polymer bonds
+            bond_mask = np.ones(len(inter_bonds._bonds), dtype=bool)
+            for b_idx, (atom_i, atom_j, b_type) in enumerate(inter_bonds._bonds):
+                idx_i = atom_select(
+                    atom_array,
+                    {
+                        "label_entity_id": entity_id,
+                        "res_id": chain.res_id[atom_i],
+                        "atom_name": chain.atom_name[atom_i],
+                    },
+                )
+                idx_j = atom_select(
+                    atom_array,
+                    {
+                        "label_entity_id": entity_id,
+                        "res_id": chain.res_id[atom_j],
+                        "atom_name": chain.atom_name[atom_j],
+                    },
+                )
+                for i in idx_i:
+                    for j in idx_j:
+                        # both i and j exist in the same chain but are not bonded in atom_array:
+                        # a non-std polymer bond, remove it from chain
+                        if atom_array.chain_id[i] == atom_array.chain_id[j]:
+                            bonds, types = atom_array.bonds.get_bonds(i)
+                            if j not in bonds:
+                                bond_mask[b_idx] = False
+                                break
+
+                if bond_mask[b_idx]:
+                    # keep this bond, add to central_bond_count
+                    central_atom_idx = (
+                        atom_i if chain.atom_name[atom_i] in ("C", "P") else atom_j
+                    )
+                    atom_key = (
+                        entity_id,
+                        chain.res_id[central_atom_idx],
+                        chain.atom_name[central_atom_idx],
+                    )
+                    # use ref chain bond count if no inter bond in atom_array.
+                    central_bond_count[atom_key] = 1
+
+            inter_bonds._bonds = inter_bonds._bonds[bond_mask]
+            chain.bonds = chain.bonds.merge(inter_bonds)
+
+            chain.hetero[:] = False
+            entity_atom_array[entity_id] = chain
+
+        # remove leaving atoms of residues based on atom_array
+
+        # count inter residue bonds from atom_array for removing leaving atoms later
+        inter_residue_bonds = get_inter_residue_bonds(atom_array)
+        for i in inter_residue_bonds.flat:
+            bonds, types = atom_array.bonds.get_bonds(i)
+            bond_count = (
+                (atom_array.res_id[bonds] != atom_array.res_id[i])
+                | (atom_array.chain_id[bonds] != atom_array.chain_id[i])
+            ).sum()
+            atom_key = (
+                atom_array.label_entity_id[i],
+                atom_array.res_id[i],
+                atom_array.atom_name[i],
+            )
+            # remove leaving atoms if the central atom has an inter-residue bond in any copy of an entity
+            central_bond_count[atom_key] = max(central_bond_count[atom_key], bond_count)
+
+        # remove leaving atoms for each central atom based on atom_array info,
+        # so that residues in the reference chain can be used directly.
+        for entity_id, chain in entity_atom_array.items():
+            keep_atom_mask = np.ones(len(chain), dtype=bool)
+            starts = struc.get_residue_starts(chain, add_exclusive_stop=True)
+            for start, stop in zip(starts[:-1], starts[1:]):
+                res_name = chain.res_name[start]
+                remove_atom_names = []
+                for i in range(start, stop):
+                    central_atom_name = chain.atom_name[i]
+                    atom_key = (entity_id, chain.res_id[i], central_atom_name)
+                    inter_bond_count = central_bond_count[atom_key]
+
+                    if inter_bond_count == 0:
+                        continue
+
+                    # the number of leaving groups to remove equals the number of inter-residue bonds (inter_bond_count)
+                    component = ccd.get_component_atom_array(
+                        res_name, keep_leaving_atoms=True
+                    )
+
+                    if component.central_to_leaving_groups is None:
+                        # The leaving atoms might be labeled wrongly. The residue remains as it is.
+                        break
+
+                    # central_to_leaving_groups:dict[str, list[list[str]]], central atom name to leaving atom groups (atom names).
+                    if central_atom_name in component.central_to_leaving_groups:
+                        leaving_groups = component.central_to_leaving_groups[
+                            central_atom_name
+                        ]
+                        # removed only when there are leaving atoms.
+                        if inter_bond_count >= len(leaving_groups):
+                            remove_groups = leaving_groups
+                        else:
+                            # subsample leaving atoms, keep resolved leaving atoms first
+                            exist_group = []
+                            not_exist_group = []
+                            for group in leaving_groups:
+                                for leaving_atom_name in group:
+                                    atom_idx = atom_select(
+                                        atom_array,
+                                        select_dict={
+                                            "label_entity_id": entity_id,
+                                            "res_id": chain.res_id[i],
+                                            "atom_name": leaving_atom_name,
+                                        },
+                                    )
+                                    if len(atom_idx) > 0:  # resolved
+                                        exist_group.append(group)
+                                        break
+                                else:
+                                    not_exist_group.append(group)
+                            if inter_bond_count <= len(not_exist_group):
+                                remove_groups = random.sample(
+                                    not_exist_group, inter_bond_count
+                                )
+                            else:
+                                remove_groups = not_exist_group + random.sample(
+                                    exist_group, inter_bond_count - len(not_exist_group)
+                                )
+                        names = [name for group in remove_groups for name in group]
+                        remove_atom_names.extend(names)
+
+                    else:
+                        # may have non-standard leaving atoms
+                        non_std_leaving_atoms = self.find_non_ccd_leaving_atoms(
+                            atom_array=atom_array,
+                            select_dict={
+                                "label_entity_id": entity_id,
+                                "res_id": chain.res_id[i],
+                                "atom_name": chain.atom_name[i],
+                            },
+                            component=component,
+                        )
+                        if len(non_std_leaving_atoms) > 0:
+                            remove_atom_names.extend(non_std_leaving_atoms)
+
+                # remove leaving atoms of this residue
+                remove_mask = np.isin(chain.atom_name[start:stop], remove_atom_names)
+                keep_atom_mask[np.arange(start, stop)[remove_mask]] = False
+
+            entity_atom_array[entity_id] = chain[keep_atom_mask]
+        return entity_atom_array
+
+    @staticmethod
+    def make_new_residue(
+        atom_array, res_start, res_stop, ref_chain=None
+    ) -> tuple[AtomArray, dict[int, int]]:
+        """
+        Make a new residue from atom_array[res_start:res_stop]; ref_chain is the reference chain.
+        1. only remove leaving atoms when the central atom is covalently bonded to another residue.
+        2. if ref_chain is provided, remove all atoms that do not match the residue in ref_chain.
+        """
+        res_id = atom_array.res_id[res_start]
+        res_name = atom_array.res_name[res_start]
+        ref_residue = ccd.get_component_atom_array(
+            res_name,
+            keep_leaving_atoms=True,
+            keep_hydrogens=False,
+        )
+        if ref_residue is None:  # only https://www.rcsb.org/ligand/UNL
+            return atom_array[res_start:res_stop]
+
+        if ref_residue.central_to_leaving_groups is None:
+            # ambiguous: one leaving group bonds to more than one central atom; keep the same atoms as the PDB entry.
+            return atom_array[res_start:res_stop]
+
+        if ref_chain is not None:
+            return ref_chain[ref_chain.res_id == res_id]
+
+        keep_atom_mask = np.ones(len(ref_residue), dtype=bool)
+
+        # remove leaving atoms when covalently bonded to another residue
+        for i in range(res_start, res_stop):
+            central_name = atom_array.atom_name[i]
+            old_atom_names = atom_array.atom_name[res_start:res_stop]
+            idx = np.where(old_atom_names == central_name)[0]
+            if len(idx) == 0:
+                # central atom is not resolved in atom_array; do not remove leaving atoms
+                continue
+            idx = idx[0] + res_start
+            bonds, types = atom_array.bonds.get_bonds(idx)
+            bond_count = (res_id != atom_array.res_id[bonds]).sum()
+            if bond_count == 0:
+                # central atom is not covalently bonded to another residue; do not remove leaving atoms
+                continue
+
+            if central_name in ref_residue.central_to_leaving_groups:
+                leaving_groups = ref_residue.central_to_leaving_groups[central_name]
+                # removed only when there are leaving atoms.
+                if bond_count >= len(leaving_groups):
+                    remove_groups = leaving_groups
+                else:
+                    # subsample leaving atoms, remove unresolved leaving atoms first
+                    exist_group = []
+                    not_exist_group = []
+                    for group in leaving_groups:
+                        for leaving_atom_name in group:
+                            atom_idx = atom_select(
+                                atom_array,
+                                select_dict={
+                                    "chain_id": atom_array.chain_id[i],
+                                    "res_id": atom_array.res_id[i],
+                                    "atom_name": leaving_atom_name,
+                                },
+                            )
+                            if len(atom_idx) > 0:  # resolved
+                                exist_group.append(group)
+                                break
+                        else:
+                            not_exist_group.append(group)
+
+                    # do not remove leaving atoms of B and BE if all leaving atoms exist in atom_array
+                    if central_name in ["B", "BE"]:
+                        if not not_exist_group:
+                            continue
+
+                    if bond_count <= len(not_exist_group):
+                        remove_groups = random.sample(not_exist_group, bond_count)
+                    else:
+                        remove_groups = not_exist_group + random.sample(
+                            exist_group, bond_count - len(not_exist_group)
+                        )
+            else:
+                leaving_atoms = MMCIFParser.find_non_ccd_leaving_atoms(
+                    atom_array=atom_array,
+                    select_dict={
+                        "chain_id": atom_array.chain_id[i],
+                        "res_id": atom_array.res_id[i],
+                        "atom_name": atom_array.atom_name[i],
+                    },
+                    component=ref_residue,
+                )
+                remove_groups = [leaving_atoms]
+
+            names = [name for group in remove_groups for name in group]
+            remove_mask = np.isin(ref_residue.atom_name, names)
+            keep_atom_mask &= ~remove_mask
+
+        return ref_residue[keep_atom_mask]
+
+    def add_missing_atoms_and_residues(self, atom_array: AtomArray) -> AtomArray:
+        """add missing atoms and residues based on CCD and mmcif info.
+
+        Args:
+            atom_array (AtomArray): structure with missing residues and atoms, from PDB entry.
+
+        Returns:
+            AtomArray: structure with missing residues and atoms added (missing atoms labeled is_resolved=False).
+        """
+        # build reference entity atom array, including missing residues
+        entity_atom_array = self.build_ref_chain_with_atom_array(atom_array)
+
+        # build new atom array and copy info from input atom array to it (new_array).
+        new_array = None
+        new_global_start = 0
+        o2n_amap = {}  # old to new atom map
+        chain_starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
+        res_starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True)
+        for c_start, c_stop in zip(chain_starts[:-1], chain_starts[1:]):
+            # get reference chain atom array
+            entity_id = atom_array.label_entity_id[c_start]
+            has_ref_chain = False
+            if entity_id in entity_atom_array:
+                has_ref_chain = True
+                ref_chain_array = entity_atom_array[entity_id].copy()
+                ref_chain_array = self.create_empty_annotation_like(
+                    atom_array, ref_chain_array
+                )
+
+            chain_array = None
+            c_res_starts = res_starts[(c_start <= res_starts) & (res_starts <= c_stop)]
+
+            # add missing residues
+            prev_res_id = 0
+            for r_start, r_stop in zip(c_res_starts[:-1], c_res_starts[1:]):
+                curr_res_id = atom_array.res_id[r_start]
+                if has_ref_chain and curr_res_id - prev_res_id > 1:
+                    # missing residue in head or middle, res_id is 1-based int.
+                    segment = ref_chain_array[
+                        (prev_res_id < ref_chain_array.res_id)
+                        & (ref_chain_array.res_id < curr_res_id)
+                    ]
+                    if chain_array is None:
+                        chain_array = segment
+                    else:
+                        chain_array += segment
+
+                new_global_start = 0 if new_array is None else len(new_array)
+                new_global_start += 0 if chain_array is None else len(chain_array)
+
+                # add missing atoms of existing residue
+                ref_chain = ref_chain_array if has_ref_chain else None
+                new_residue = self.make_new_residue(
+                    atom_array, r_start, r_stop, ref_chain
+                )
+
+                new_residue = self.create_empty_annotation_like(atom_array, new_residue)
+
+                # copy residue level info
+                residue_fields = ["res_id", "hetero", "label_seq_id", "auth_seq_id"]
+                for k in residue_fields:
+                    v = atom_array._annot[k][r_start]
+                    new_residue._annot[k][:] = v
+
+                # make o2n_amap: old to new atom map
+                name_to_index_new = {
+                    name: idx for idx, name in enumerate(new_residue.atom_name)
+                }
+                res_o2n_amap = {}
+                res_mismatch_idx = []
+                for old_idx in range(r_start, r_stop):
+                    old_name = atom_array.atom_name[old_idx]
+                    if old_name not in name_to_index_new:
+                        # AF3 SI 2.5.4 Filtering
+                        # For residues or small molecules with CCD codes, atoms outside of the CCD code’s defined set of atom names are removed.
+                        res_mismatch_idx.append(old_idx)
+                    else:
+                        new_idx = name_to_index_new[old_name]
+                        res_o2n_amap[old_idx] = new_global_start + new_idx
+                if len(res_o2n_amap) > len(res_mismatch_idx):
+                    # Match residues only if more than half of their resolved atoms are matched.
+                    # e.g. in 1gbt, GBS matches only 2/12 atoms, so it is not added to o2n_amap and all its atoms are marked is_resolved=False.
+                    o2n_amap.update(res_o2n_amap)
+
+                if chain_array is None:
+                    chain_array = new_residue
+                else:
+                    chain_array += new_residue
+
+                prev_res_id = curr_res_id
+
+            # missing residue in tail
+            if has_ref_chain:
+                last_res_id = ref_chain_array.res_id[-1]
+                if last_res_id > curr_res_id:
+                    chain_array += ref_chain_array[ref_chain_array.res_id > curr_res_id]
+
+            # copy chain level info
+            chain_fields = [
+                "chain_id",
+                "label_asym_id",
+                "label_entity_id",
+                "auth_asym_id",
+                # "asym_id_int",
+                # "entity_id_int",
+                # "sym_id_int",
+            ]
+            for k in chain_fields:
+                chain_array._annot[k][:] = atom_array._annot[k][c_start]
+
+            if new_array is None:
+                new_array = chain_array
+            else:
+                new_array += chain_array
+
+        # copy atom level info
+        old_idx = list(o2n_amap.keys())
+        new_idx = list(o2n_amap.values())
+        atom_fields = ["b_factor", "occupancy", "charge"]
+        for k in atom_fields:
+            if k not in atom_array._annot:
+                continue
+            new_array._annot[k][new_idx] = atom_array._annot[k][old_idx]
+
+        # add is_resolved annotation
+        is_resolved = np.zeros(len(new_array), dtype=bool)
+        is_resolved[new_idx] = True
+        new_array.set_annotation("is_resolved", is_resolved)
+
+        # copy coord
+        new_array.coord[:] = 0.0
+        new_array.coord[new_idx] = atom_array.coord[old_idx]
+        # copy bonds
+        old_bonds = atom_array.bonds.as_array()  # *n x 3* np.ndarray (i,j,bond_type)
+
+        # some non-leaving atoms are not in the new_array due to atom name mismatch, e.g. 4msw TYF
+        # only keep bonds of matching atoms
+        old_bonds = old_bonds[
+            np.isin(old_bonds[:, 0], old_idx) & np.isin(old_bonds[:, 1], old_idx)
+        ]
+
+        old_bonds[:, 0] = [o2n_amap[i] for i in old_bonds[:, 0]]
+        old_bonds[:, 1] = [o2n_amap[i] for i in old_bonds[:, 1]]
+        new_bonds = struc.BondList(len(new_array), old_bonds)
+        if new_array.bonds is None:
+            new_array.bonds = new_bonds
+        else:
+            new_array.bonds = new_array.bonds.merge(new_bonds)
+
+        # add peptide bonds and nucleic acid bonds based on CCD type
+        new_array = ccd.add_inter_residue_bonds(
+            new_array, exclude_struct_conn_pairs=True, remove_far_inter_chain_pairs=True
+        )
+        return new_array
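+
+    # Contract sketch for add_missing_atoms_and_residues (illustrative; "1abc.cif" is hypothetical):
+    # atoms/residues added from the CCD reference get coord (0, 0, 0) and is_resolved=False,
+    # while resolved atoms keep their original coordinates and annotations.
+    #   parser = MMCIFParser("1abc.cif")
+    #   arr = parser.mse_to_met(parser.get_structure())
+    #   full = parser.add_missing_atoms_and_residues(arr)
+    #   missing = full[~full.is_resolved]   # all-zero coordinates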
+
+    def make_chain_indices(
+        self, atom_array: AtomArray, pdb_cluster_file: Union[str, Path] = None
+    ) -> list:
+        """
+        Make chain indices.
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray object.
+            pdb_cluster_file (Union[str, Path]): cluster info txt file.
+        """
+        if pdb_cluster_file is None:
+            pdb_cluster_dict = {}
+        else:
+            pdb_cluster_dict = parse_pdb_cluster_file_to_dict(pdb_cluster_file)
+        poly_res_names = self.get_poly_res_names(atom_array)
+        starts = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
+        chain_indices_list = []
+
+        is_centre_atom_and_is_resolved = (
+            atom_array.is_resolved & atom_array.centre_atom_mask.astype(bool)
+        )
+        for start, stop in zip(starts[:-1], starts[1:]):
+            chain_id = atom_array.chain_id[start]
+            entity_id = atom_array.label_entity_id[start]
+
+            # skip if centre atoms within a chain are all unresolved, e.g. 1zc8
+            if ~np.any(is_centre_atom_and_is_resolved[start:stop]):
+                continue
+
+            # AF3 SI 2.5.1 Weighted PDB dataset
+            entity_type = self.entity_poly_type.get(entity_id, "non-poly")
+
+            res_names = poly_res_names.get(entity_id, None)
+            if res_names is None:
+                chain_atoms = atom_array[start:stop]
+                res_ids, res_names = struc.get_residues(chain_atoms)
+
+            if "polypeptide" in entity_type:
+                mol_type = "prot"
+                sequence = ccd.res_names_to_sequence(res_names)
+                if len(sequence) < 10:
+                    cluster_id = sequence
+                else:
+                    pdb_entity = f"{self.pdb_id}_{entity_id}"
+                    if pdb_entity in pdb_cluster_dict:
+                        cluster_id, _ = pdb_cluster_dict[pdb_entity]
+                    elif entity_type == "polypeptide(D)":
+                        cluster_id = sequence
+                    elif sequence == "X" * len(sequence):
+                        chain_atoms = atom_array[start:stop]
+                        res_ids, res_names = struc.get_residues(chain_atoms)
+                        if np.all(res_names == "UNK"):
+                            cluster_id = "poly_UNK"
+                        else:
+                            cluster_id = "_".join(res_names)
+                    else:
+                        cluster_id = "NotInClusterTxt"
+
+            elif "ribonucleotide" in entity_type:
+                mol_type = "nuc"
+                cluster_id = ccd.res_names_to_sequence(res_names)
+            else:
+                mol_type = "ligand"
+                cluster_id = "_".join(res_names)
+
+            chain_dict = {
+                "entity_id": entity_id,  # str
+                "chain_id": chain_id,
+                "mol_type": mol_type,
+                "cluster_id": cluster_id,
+            }
+            chain_indices_list.append(chain_dict)
+        return chain_indices_list
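+
+    # Shape of one chain_indices_list entry (values are illustrative, not from a real PDB entry):
+    #   {"entity_id": "1", "chain_id": "A0", "mol_type": "prot", "cluster_id": "NotInClusterTxt"}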
+
+    def make_interface_indices(
+        self, atom_array: AtomArray, chain_indices_list: list
+    ) -> list:
+        """make interface indices
+        As described in SI 2.5.1, interfaces defined as pairs of chains with minimum heavy atom
+        (i.e. non-hydrogen) separation less than 5 Ã…
+        Args:
+            atom_array (AtomArray): _description_
+            chain_indices_list (List): _description_
+        """
+
+        chain_indices_dict = {i["chain_id"]: i for i in chain_indices_list}
+        interface_indices_dict = {}
+
+        cell_list = struc.CellList(
+            atom_array, cell_size=5, selection=atom_array.is_resolved
+        )
+        for chain_i, chain_i_dict in chain_indices_dict.items():
+            chain_mask = atom_array.chain_id == chain_i
+            coord = atom_array.coord[chain_mask & atom_array.is_resolved]
+            neighbors_indices_2d = cell_list.get_atoms(
+                coord, radius=5
+            )  # shape:(n_coord, max_n_neighbors), padding with -1
+            neighbors_indices = np.unique(neighbors_indices_2d)
+            neighbors_indices = neighbors_indices[neighbors_indices != -1]
+
+            chain_j_list = np.unique(atom_array.chain_id[neighbors_indices])
+            for chain_j in chain_j_list:
+                if chain_i == chain_j:
+                    continue
+
+                # skip if centre atoms within a chain are all unresolved, e.g. 1zc8
+                if chain_j not in chain_indices_dict:
+                    continue
+
+                interface_id = "_".join(sorted([chain_i, chain_j]))
+                if interface_id in interface_indices_dict:
+                    continue
+                chain_j_dict = chain_indices_dict[chain_j]
+                interface_dict = {}
+                # chain_id --> chain_1_id
+                # mol_type --> mol_1_type
+                # entity_id --> entity_1_id
+                # cluster_id --> cluster_1_id
+                interface_dict.update(
+                    {k.replace("_", "_1_"): v for k, v in chain_i_dict.items()}
+                )
+                interface_dict.update(
+                    {k.replace("_", "_2_"): v for k, v in chain_j_dict.items()}
+                )
+                interface_indices_dict[interface_id] = interface_dict
+        return list(interface_indices_dict.values())
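+
+    # Each interface entry merges the two chain dicts with "_1_"/"_2_" keys (illustrative values):
+    #   {"entity_1_id": "1", "chain_1_id": "A0", "mol_1_type": "prot", "cluster_1_id": "NotInClusterTxt",
+    #    "entity_2_id": "2", "chain_2_id": "B0", "mol_2_type": "ligand", "cluster_2_id": "NAG"}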
+
+    @staticmethod
+    def add_sub_mol_type(
+        atom_array: AtomArray,
+        indices_dict: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Add a "sub_mol_[i]_type" field to indices_dict.
+        It includes the following mol_types and sub_mol_types:
+
+        prot
+            - prot
+            - glycosylation_prot
+            - modified_prot
+
+        nuc
+            - dna
+            - rna
+            - modified_dna
+            - modified_rna
+            - dna_rna_hybrid
+
+        ligand
+            - bonded_ligand
+            - non_bonded_ligand
+
+        excluded_ligand
+            - excluded_ligand
+
+        glycans
+            - glycans
+
+        ions
+            - ions
+
+        Args:
+            atom_array (AtomArray): Biotite AtomArray object of bioassembly.
+            indices_dict (dict[str, Any]): A dict of chain or interface indices info.
+
+        Returns:
+            dict[str, Any]: A dict of chain or interface indices info with "sub_mol_[i]_type" field.
+        """
+        polymer_lig_bonds = get_ligand_polymer_bond_mask(atom_array)
+        if len(polymer_lig_bonds) == 0:
+            lig_polymer_bond_chain_id = []
+        else:
+            lig_polymer_bond_chain_id = atom_array.chain_id[
+                np.unique(polymer_lig_bonds[:, :2])
+            ]
+
+        for i in ["1", "2"]:
+            if indices_dict[f"entity_{i}_id"] == "":
+                indices_dict[f"sub_mol_{i}_type"] = ""
+                continue
+            entity_type = indices_dict[f"mol_{i}_type"]
+            mol_id = atom_array.mol_id[
+                atom_array.label_entity_id == indices_dict[f"entity_{i}_id"]
+            ][0]
+            mol_all_res_name = atom_array.res_name[atom_array.mol_id == mol_id]
+            chain_all_mol_type = atom_array.mol_type[
+                atom_array.chain_id == indices_dict[f"chain_{i}_id"]
+            ]
+            chain_all_res_name = atom_array.res_name[
+                atom_array.chain_id == indices_dict[f"chain_{i}_id"]
+            ]
+
+            if entity_type == "ligand":
+                ccd_code = indices_dict[f"cluster_{i}_id"]
+                if ccd_code in GLYCANS:
+                    indices_dict[f"sub_mol_{i}_type"] = "glycans"
+
+                elif ccd_code in LIGAND_EXCLUSION:
+                    indices_dict[f"sub_mol_{i}_type"] = "excluded_ligand"
+
+                elif indices_dict[f"chain_{i}_id"] in lig_polymer_bond_chain_id:
+                    indices_dict[f"sub_mol_{i}_type"] = "bonded_ligand"
+                else:
+                    indices_dict[f"sub_mol_{i}_type"] = "non_bonded_ligand"
+
+            elif entity_type == "prot":
+                # glycosylation
+                if np.any(np.isin([mol_all_res_name], list(GLYCANS))):
+                    indices_dict[f"sub_mol_{i}_type"] = "glycosylation_prot"
+
+                if ~np.all(np.isin(chain_all_res_name, list(PRO_STD_RESIDUES.keys()))):
+                    indices_dict[f"sub_mol_{i}_type"] = "modified_prot"
+
+            elif entity_type == "nuc":
+                if np.all(chain_all_mol_type == "dna"):
+                    if np.any(
+                        np.isin(chain_all_res_name, list(DNA_STD_RESIDUES.keys()))
+                    ):
+                        indices_dict[f"sub_mol_{i}_type"] = "dna"
+                    else:
+                        indices_dict[f"sub_mol_{i}_type"] = "modified_dna"
+
+                elif np.all(chain_all_mol_type == "rna"):
+                    if np.any(
+                        np.isin(chain_all_res_name, list(RNA_STD_RESIDUES.keys()))
+                    ):
+                        indices_dict[f"sub_mol_{i}_type"] = "rna"
+                    else:
+                        indices_dict[f"sub_mol_{i}_type"] = "modified_rna"
+                else:
+                    indices_dict[f"sub_mol_{i}_type"] = "dna_rna_hybrid"
+
+            else:
+                indices_dict[f"sub_mol_{i}_type"] = [f"mol_{i}_type"]
+
+            if indices_dict.get(f"sub_mol_{i}_type") is None:
+                indices_dict[f"sub_mol_{i}_type"] = indices_dict[f"mol_{i}_type"]
+        return indices_dict
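+
+    # Illustrative outcome of add_sub_mol_type (assumed example, not taken from a real entry):
+    # a ligand chain whose cluster_id is a glycan CCD code (e.g. "NAG") gets
+    # sub_mol_{i}_type = "glycans"; a plain protein chain falls back to its mol type "prot".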
+
+    @staticmethod
+    def add_eval_type(indices_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        Differentiate DNA and RNA within the generic nucleic-acid ("nuc") type.
+
+        Args:
+            indices_dict (dict[str, Any]): A dict of chain or interface indices info.
+
+        Returns:
+            dict[str, Any]: A dict of chain or interface indices info with "eval_type" field.
+        """
+        if indices_dict["mol_type_group"] not in ["intra_nuc", "nuc_prot"]:
+            eval_type = indices_dict["mol_type_group"]
+        elif "dna_rna_hybrid" in [
+            indices_dict["sub_mol_1_type"],
+            indices_dict["sub_mol_2_type"],
+        ]:
+            eval_type = indices_dict["mol_type_group"]
+        else:
+            if indices_dict["mol_type_group"] == "intra_nuc":
+                nuc_type = str(indices_dict["sub_mol_1_type"]).split("_")[-1]
+                eval_type = f"intra_{nuc_type}"
+            else:
+                nuc_type1 = str(indices_dict["sub_mol_1_type"]).split("_")[-1]
+                nuc_type2 = str(indices_dict["sub_mol_2_type"]).split("_")[-1]
+                if "dna" in [nuc_type1, nuc_type2]:
+                    eval_type = "dna_prot"
+                else:
+                    eval_type = "rna_prot"
+        indices_dict["eval_type"] = eval_type
+        return indices_dict
+
+    def make_indices(
+        self,
+        bioassembly_dict: dict[str, Any],
+        pdb_cluster_file: Union[str, Path] = None,
+    ) -> list:
+        """generate indices of chains and interfaces for sampling data
+
+        Args:
+            bioassembly_dict (dict): dict from MMCIFParser.get_bioassembly().
+            pdb_cluster_file (Union[str, Path]): PDB cluster file. Defaults to None.
+        Returns:
+            list[dict[str, str]]: sample_indices_list
+        """
+        atom_array = bioassembly_dict["atom_array"]
+        if atom_array is None:
+            print(
+                f"Warning: make_indices() input atom_array is None, return empty list (PDB Code:{bioassembly_dict['pdb_id']})"
+            )
+            return []
+        chain_indices_list = self.make_chain_indices(atom_array, pdb_cluster_file)
+        interface_indices_list = self.make_interface_indices(
+            atom_array, chain_indices_list
+        )
+        meta_dict = {
+            "pdb_id": bioassembly_dict["pdb_id"],
+            "assembly_id": bioassembly_dict["assembly_id"],
+            "release_date": self.release_date,
+            "num_tokens": bioassembly_dict["num_tokens"],
+            "num_prot_chains": bioassembly_dict["num_prot_chains"],
+            "resolution": self.resolution,
+        }
+        sample_indices_list = []
+        for chain_dict in chain_indices_list:
+            chain_dict_out = {k.replace("_", "_1_"): v for k, v in chain_dict.items()}
+            chain_dict_out.update(
+                {k.replace("_", "_2_"): "" for k, v in chain_dict.items()}
+            )
+            chain_dict_out["cluster_id"] = chain_dict["cluster_id"]
+            chain_dict_out.update(meta_dict)
+            chain_dict_out["type"] = "chain"
+            sample_indices_list.append(chain_dict_out)
+
+        for interface_dict in interface_indices_list:
+            cluster_ids = [
+                interface_dict["cluster_1_id"],
+                interface_dict["cluster_2_id"],
+            ]
+            interface_dict["cluster_id"] = ":".join(sorted(cluster_ids))
+            interface_dict.update(meta_dict)
+            interface_dict["type"] = "interface"
+            sample_indices_list.append(interface_dict)
+
+        for indices in sample_indices_list:
+            for i in ["1", "2"]:
+                chain_id = indices[f"chain_{i}_id"]
+                if chain_id == "":
+                    continue
+                chain_atom_num = np.sum([atom_array.chain_id == chain_id])
+                if chain_atom_num == 1:
+                    indices[f"mol_{i}_type"] = "ions"
+
+            if indices["type"] == "chain":
+                indices["mol_type_group"] = f'intra_{indices["mol_1_type"]}'
+            else:
+                indices["mol_type_group"] = "_".join(
+                    sorted([indices["mol_1_type"], indices["mol_2_type"]])
+                )
+            indices = self.add_sub_mol_type(atom_array, indices)
+            indices = self.add_eval_type(indices)
+        return sample_indices_list
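+
+    # End-to-end sketch for building sampling indices (illustrative; file paths are hypothetical):
+    #   parser = MMCIFParser("1abc.cif")
+    #   bio = parser.get_bioassembly(assembly_id="1")
+    #   indices = parser.make_indices(bio, pdb_cluster_file=None)
+    #   chains = [x for x in indices if x["type"] == "chain"]
+    #   interfaces = [x for x in indices if x["type"] == "interface"]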
+
+
+class DistillationMMCIFParser(MMCIFParser):
+
+    def get_structure_dict(self) -> dict[str, Any]:
+        """
+        Get an AtomArray from a CIF file of distillation data.
+
+        Returns:
+            Dict[str, Any]: a dict of asymmetric unit structure info.
+        """
+        # create an AtomArray of the first model from mmcif atom_site (asymmetric unit)
+        atom_array = self.get_structure()
+
+        # convert MSE to MET to be consistent with MMCIFParser.get_poly_res_names()
+        atom_array = self.mse_to_met(atom_array)
+
+        structure_dict = {
+            "pdb_id": self.pdb_id,
+            "atom_array": None,
+            "assembly_id": None,
+            "sequences": self.get_sequences(atom_array),
+            "entity_poly_type": self.entity_poly_type,
+            "num_tokens": -1,
+            "num_prot_chains": -1,
+        }
+
+        pipeline_functions = [
+            self.fix_arginine,
+            self.add_missing_atoms_and_residues,  # add UNK
+        ]
+
+        for func in pipeline_functions:
+            atom_array = func(atom_array)
+            if len(atom_array) == 0:
+                # no atoms left
+                return structure_dict
+
+        atom_array = AddAtomArrayAnnot.add_token_mol_type(
+            atom_array, self.entity_poly_type
+        )
+        atom_array = AddAtomArrayAnnot.add_centre_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_atom_mol_type_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_distogram_rep_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_plddt_m_rep_atom_mask(atom_array)
+        atom_array = AddAtomArrayAnnot.add_cano_seq_resname(atom_array)
+        atom_array = AddAtomArrayAnnot.add_tokatom_idx(atom_array)
+        atom_array = AddAtomArrayAnnot.add_modified_res_mask(atom_array)
+        assert (
+            atom_array.centre_atom_mask.sum()
+            == atom_array.distogram_rep_atom_mask.sum()
+        )
+
+        # rename chain_ids from A A B to A0 A1 B0 and add asym_id_int, entity_id_int, sym_id_int
+        atom_array = AddAtomArrayAnnot.unique_chain_and_add_ids(atom_array)
+        atom_array = AddAtomArrayAnnot.find_equiv_mol_and_assign_ids(
+            atom_array, self.entity_poly_type
+        )
+
+        # numerical encoding of (chain id, residue index)
+        atom_array = AddAtomArrayAnnot.add_ref_space_uid(atom_array)
+        atom_array = AddAtomArrayAnnot.add_ref_info_and_res_perm(atom_array)
+
+        # the number of protein chains in the structure
+        prot_label_entity_ids = [
+            k for k, v in self.entity_poly_type.items() if "polypeptide" in v
+        ]
+        num_prot_chains = len(
+            np.unique(
+                atom_array.chain_id[
+                    np.isin(atom_array.label_entity_id, prot_label_entity_ids)
+                ]
+            )
+        )
+        structure_dict["num_prot_chains"] = num_prot_chains
+        structure_dict["atom_array"] = atom_array
+        structure_dict["num_tokens"] = atom_array.centre_atom_mask.sum()
+        return structure_dict
+
 
 class AddAtomArrayAnnot(object):
     """
@@ -951,7 +2268,7 @@ class AddAtomArrayAnnot(object):
         """
         Unique chain ID and add asym_id, entity_id, sym_id.
         Adds a number to the chain ID to make chain IDs in the assembly unique.
-        Example: [A, B, A, B, C] ==> [A0, B0, A1, B1, C0]
+        Example: [A, B, A, B, C] -> [A, B, A.1, B.1, C]
 
         Args:
             atom_array (AtomArray): Biotite AtomArray object.
@@ -962,49 +2279,26 @@ class AddAtomArrayAnnot(object):
                 - entity_id_int: np.array(int)
                 - sym_id_int: np.array(int)
         """
-        entity_id_uniq = np.sort(np.unique(atom_array.label_entity_id))
-        entity_id_dict = {e: i for i, e in enumerate(entity_id_uniq)}
-        asym_ids = np.zeros(len(atom_array), dtype=int)
-        entity_ids = np.zeros(len(atom_array), dtype=int)
-        sym_ids = np.zeros(len(atom_array), dtype=int)
-        chain_ids = np.zeros(len(atom_array), dtype="U4")
-        counter = Counter()
-        start_indices = struc.get_chain_starts(atom_array, add_exclusive_stop=True)
-        for i in range(len(start_indices) - 1):
-            start_i = start_indices[i]
-            stop_i = start_indices[i + 1]
-            asym_ids[start_i:stop_i] = i
-
-            entity_id = atom_array.label_entity_id[start_i]
-            entity_ids[start_i:stop_i] = entity_id_dict[entity_id]
+        chain_ids = np.zeros(len(atom_array), dtype="<U8")
+        chain_starts = get_chain_starts(atom_array, add_exclusive_stop=True)
 
-            sym_ids[start_i:stop_i] = counter[entity_id]
-            counter[entity_id] += 1
-            new_chain_id = f"{atom_array.chain_id[start_i]}{sym_ids[start_i]}"
-            chain_ids[start_i:stop_i] = new_chain_id
+        chain_counter = Counter()
+        for start, stop in zip(chain_starts[:-1], chain_starts[1:]):
+            ori_chain_id = atom_array.chain_id[start]
+            cnt = chain_counter[ori_chain_id]
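+            # first occurrence keeps the original chain id; later duplicates get a ".<count>" suffix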
+            if cnt == 0:
+                new_chain_id = ori_chain_id
+            else:
+                new_chain_id = f"{ori_chain_id}.{chain_counter[ori_chain_id]}"
 
-        atom_array.set_annotation("asym_id_int", asym_ids)
-        atom_array.set_annotation("entity_id_int", entity_ids)
-        atom_array.set_annotation("sym_id_int", sym_ids)
-        atom_array.chain_id = chain_ids
-        return atom_array
+            chain_ids[start:stop] = new_chain_id
+            chain_counter[ori_chain_id] += 1
 
-    @staticmethod
-    def add_int_id(atom_array):
-        """
-        Unique chain ID and add asym_id, entity_id, sym_id.
-        Adds a number to the chain ID to make chain IDs in the assembly unique.
-        Example: [A, B, A, B, C] ==> [A0, B0, A1, B1, C0]
+        assert "" not in chain_ids
+        # reset chain id
+        atom_array.del_annotation("chain_id")
+        atom_array.set_annotation("chain_id", chain_ids)
 
-        Args:
-            atom_array (AtomArray): Biotite AtomArray object.
-
-        Returns:
-            AtomArray: Biotite AtomArray object with new annotations:
-                - asym_id_int: np.array(int)
-                - entity_id_int: np.array(int)
-                - sym_id_int: np.array(int)
-        """
         entity_id_uniq = np.sort(np.unique(atom_array.label_entity_id))
         entity_id_dict = {e: i for i, e in enumerate(entity_id_uniq)}
         asym_ids = np.zeros(len(atom_array), dtype=int)
diff --git a/protenix/protenix/data/utils.py b/protenix/protenix/data/utils.py
index 18f98aa2d76605c1099b5d00f6e530d50ca55bb1..b484e704dcbe75bb291d11cf49464868f1b9f370 100644
--- a/protenix/protenix/data/utils.py
+++ b/protenix/protenix/data/utils.py
@@ -14,6 +14,7 @@
 
 import argparse
 import copy
+import functools
 import os
 import re
 from collections import defaultdict
@@ -59,6 +60,27 @@ def int_to_letters(n: int) -> str:
     return result
 
 
+def get_inter_residue_bonds(atom_array: AtomArray) -> np.ndarray:
+    """get inter residue bonds by checking chain_id and res_id
+
+    Args:
+        atom_array (AtomArray): Biotite AtomArray, must have chain_id and res_id
+
+    Returns:
+        np.ndarray: inter residue bonds, shape = (n,2)
+    """
+    if atom_array.bonds is None:
+        return np.zeros((0, 2), dtype=int)
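+    # BondList._bonds is an (n, 3) int array of (atom_i, atom_j, bond_type)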
+    idx_i = atom_array.bonds._bonds[:, 0]
+    idx_j = atom_array.bonds._bonds[:, 1]
+    chain_id_diff = atom_array.chain_id[idx_i] != atom_array.chain_id[idx_j]
+    res_id_diff = atom_array.res_id[idx_i] != atom_array.res_id[idx_j]
+    diff_mask = chain_id_diff | res_id_diff
+    inter_residue_bonds = atom_array.bonds._bonds[diff_mask]
+    inter_residue_bonds = inter_residue_bonds[:, :2]  # remove bond type
+    return inter_residue_bonds
+
+
 def get_starts_by(
     atom_array: AtomArray, by_annot: str, add_exclusive_stop=False
 ) -> np.ndarray:
@@ -88,6 +110,26 @@ def get_starts_by(
         return np.concatenate(([0], starts))
 
 
+def atom_select(atom_array: AtomArray, select_dict: dict, as_mask=False) -> np.ndarray:
+    """return index of atom_array that match select_dict
+
+    Args:
+        atom_array (AtomArray): Biotite AtomArray
+        select_dict (dict): select dict, eg: {'element': 'C'}
+        as_mask (bool, optional): return mask of atom_array. Defaults to False.
+
+    Returns:
+        np.ndarray: index of atom_array that match select_dict
+    """
+    mask = np.ones(len(atom_array), dtype=bool)
+    for k, v in select_dict.items():
+        mask = mask & (getattr(atom_array, k) == v)
+    if as_mask:
+        return mask
+    else:
+        return np.where(mask)[0]
+
+
 def get_ligand_polymer_bond_mask(
     atom_array: AtomArray, lig_include_ions=False
 ) -> np.ndarray:
@@ -140,6 +182,38 @@ def get_ligand_polymer_bond_mask(
     return lig_polymer_bonds
 
 
+@functools.lru_cache
+def parse_pdb_cluster_file_to_dict(
+    cluster_file: str, remove_uniprot: bool = True
+) -> dict[str, tuple]:
+    """parse PDB cluster file, and return a pandas dataframe
+    example cluster file:
+    https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt
+
+    Args:
+        cluster_file (str): cluster_file path
+        remove_uniprot (bool): drop computed-model entries (AF_/MA_). Defaults to True.
+    Returns:
+        dict(str, tuple(str, int)): {pdb_id}_{entity_id} --> (cluster_id, cluster_size)
+    """
+    pdb_cluster_dict = {}
+    with open(cluster_file) as f:
+        for line in f:
+            pdb_clusters = []
+            for ids in line.strip().split():
+                if remove_uniprot:
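+                    # skip computed-structure models (AlphaFold DB "AF_" / ModelArchive "MA_")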
+                    if ids.startswith("AF_") or ids.startswith("MA_"):
+                        continue
+                pdb_clusters.append(ids)
+            cluster_size = len(pdb_clusters)
+            if cluster_size == 0:
+                continue
+            # use first member as cluster id.
+            cluster_id = f"pdb_cluster_{pdb_clusters[0]}"
+            for ids in pdb_clusters:
+                pdb_cluster_dict[ids.lower()] = (cluster_id, cluster_size)
+    return pdb_cluster_dict
+
+
 def get_clean_data(atom_array: AtomArray) -> AtomArray:
     """
     Removes unresolved atoms from the AtomArray.
@@ -230,6 +304,21 @@ class CIFWriter:
         self.atom_array = atom_array
         self.entity_poly_type = entity_poly_type
 
+    def _get_entity_block(self):
+        if self.entity_poly_type is None:
+            return {}
+        entity_ids_in_atom_array = np.sort(np.unique(self.atom_array.label_entity_id))
+        entity_block_dict = defaultdict(list)
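+        # entities listed in entity_poly_type are polymers; everything else is written as non-polymer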
+        for entity_id in entity_ids_in_atom_array:
+            if entity_id not in self.entity_poly_type:
+                entity_type = "non-polymer"
+            else:
+                entity_type = "polymer"
+            entity_block_dict["id"].append(entity_id)
+            entity_block_dict["pdbx_description"].append(".")
+            entity_block_dict["type"].append(entity_type)
+        return pdbx.CIFCategory(entity_block_dict)
+    
     def _get_entity_poly_and_entity_poly_seq_block(self):
         entity_poly = defaultdict(list)
         for entity_id, entity_type in self.entity_poly_type.items():
@@ -247,6 +336,9 @@ class CIFWriter:
             entity_poly["entity_id"].append(entity_id)
             entity_poly["pdbx_strand_id"].append(label_asym_ids_str)
             entity_poly["type"].append(entity_type)
+            
+        if not entity_poly:
+            return {}
 
         entity_poly_seq = defaultdict(list)
         for entity_id, label_asym_ids_str in zip(
@@ -305,6 +397,7 @@ class CIFWriter:
 
         block_dict = {"entry": pdbx.CIFCategory({"id": entry_id})}
         if self.entity_poly_type:
+            block_dict["entity"] = self._get_entity_block()
             block_dict.update(self._get_entity_poly_and_entity_poly_seq_block())
 
         block = pdbx.CIFBlock(block_dict)
@@ -317,6 +410,11 @@ class CIFWriter:
         pdbx.set_structure(cif, self.atom_array, include_bonds=include_bonds)
         block = cif.block
         atom_site = block.get("atom_site")
+
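+        # fill in a default occupancy of 1.0 when the AtomArray provides none, so atom_site stays complete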
+        occ = atom_site.get("occupancy")
+        if occ is None:
+            atom_site["occupancy"] = np.ones(len(self.atom_array), dtype=float)
+
         atom_site["label_entity_id"] = self.atom_array.label_entity_id
         cif.write(output_path)
 
@@ -678,7 +776,11 @@ def pdb_to_cif(input_fname: str, output_fname: str, entry_id: str = None):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--pdb_file", type=str, required=True, help="The pdb file to parse")
-    parser.add_argument("--cif_file", type=str, required=True, help="The cif file path to generate")
+    parser.add_argument(
+        "--pdb_file", type=str, required=True, help="The pdb file to parse"
+    )
+    parser.add_argument(
+        "--cif_file", type=str, required=True, help="The cif file path to generate"
+    )
     args = parser.parse_args()
-    pdb_to_cif(args.pdb_file, args.cif_file)
\ No newline at end of file
+    pdb_to_cif(args.pdb_file, args.cif_file)
diff --git a/protenix/protenix/model/modules/pairformer.py b/protenix/protenix/model/modules/pairformer.py
index ad40f4b683c4e5699eb4e81ad5cc69b4cc3294b6..4ee3054887081798b39d5bef134eea880fa97e74 100644
--- a/protenix/protenix/model/modules/pairformer.py
+++ b/protenix/protenix/model/modules/pairformer.py
@@ -141,7 +141,7 @@ class PairformerBlock(nn.Module):
             z = z.transpose(-2, -3).contiguous()
             z += self.tri_att_end(
                 z,
-                mask=pair_mask.tranpose(-1, -2) if pair_mask is not None else None,
+                mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
@@ -184,7 +184,7 @@ class PairformerBlock(nn.Module):
             z = z + self.dropout_row(
                 self.tri_att_end(
                     z,
-                    mask=pair_mask.tranpose(-1, -2) if pair_mask is not None else None,
+                    mask=pair_mask.transpose(-1, -2) if pair_mask is not None else None,
                     use_memory_efficient_kernel=use_memory_efficient_kernel,
                     use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                     use_lma=use_lma,
diff --git a/protenix/protenix/openfold_local/model/primitives.py b/protenix/protenix/openfold_local/model/primitives.py
index 1d0fb68d34cd1acf5cd85e7197884ab2c9bdd1f6..f1ad20a0756a2c069600087b8168208c95bfb199 100644
--- a/protenix/protenix/openfold_local/model/primitives.py
+++ b/protenix/protenix/openfold_local/model/primitives.py
@@ -544,13 +544,6 @@ class Attention(nn.Module):
 
         if use_memory_efficient_kernel:
             raise Exception(f"use_memory_efficient_kernel=True not supported!!!")
-            if len(biases) > 2:
-                raise ValueError(
-                    "If use_memory_efficient_kernel is True, you may only "
-                    "provide up to two bias terms"
-                )
-            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
-            o = o.transpose(-2, -3)
         elif use_deepspeed_evo_attention:
             if len(biases) > 2:
                 raise ValueError(
diff --git a/protenix/protenix/openfold_local/np/residue_constants.py b/protenix/protenix/openfold_local/np/residue_constants.py
index dfafab7381a4ebc5d80a9084dff19a4a7c4c9eeb..9aff5ffa69ef0de297ca9a3529d693d2b80ecf39 100644
--- a/protenix/protenix/openfold_local/np/residue_constants.py
+++ b/protenix/protenix/openfold_local/np/residue_constants.py
@@ -22,7 +22,7 @@ from importlib import resources
 from typing import List, Mapping, Tuple
 
 import numpy as np
-import tree
+import optree
 
 # Distance from one CA to next CA [trans configuration: omega = 180].
 ca_ca = 3.80209737096
@@ -1072,7 +1072,7 @@ chi_atom_2_one_hot = chi_angle_atom(2)
 
 # An array like chi_angles_atoms but using indices rather than names.
 chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
-chi_angles_atom_indices = tree.map_structure(
+chi_angles_atom_indices = optree.tree_map(
     lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
 )
 chi_angles_atom_indices = np.array(
diff --git a/protenix/protenix/openfold_local/utils/checkpointing.py b/protenix/protenix/openfold_local/utils/checkpointing.py
index 149fdec69a5f57a5db1ea96e87f26e49198ddd41..7ba7dfb51e06925f8da3b376f58109ffee04016f 100644
--- a/protenix/protenix/openfold_local/utils/checkpointing.py
+++ b/protenix/protenix/openfold_local/utils/checkpointing.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import importlib
 from typing import Any, List, Callable, Optional
 
@@ -33,7 +34,7 @@ def get_checkpoint_fn():
     if deepspeed_is_configured:
         checkpoint = deepspeed.checkpointing.checkpoint
     else:
-        checkpoint = torch.utils.checkpoint.checkpoint
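+        # PyTorch recommends non-reentrant activation checkpointing (use_reentrant=False)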
+        checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
 
     return checkpoint
 
diff --git a/protenix/protenix/openfold_local/utils/chunk_utils.py b/protenix/protenix/openfold_local/utils/chunk_utils.py
index 3a7dc885961c3347280ff078314a2fe6338a508f..777ab68a4c126a8b0eedf521e163f8ad98961b76 100644
--- a/protenix/protenix/openfold_local/utils/chunk_utils.py
+++ b/protenix/protenix/openfold_local/utils/chunk_utils.py
@@ -14,28 +14,15 @@
 import logging
 import math
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
+import optree
 import torch
-
 from protenix.openfold_local.utils.tensor_utils import tensor_tree_map, tree_map
 
 
 def _fetch_dims(tree):
-    shapes = []
-    tree_type = type(tree)
-    if tree_type is dict:
-        for v in tree.values():
-            shapes.extend(_fetch_dims(v))
-    elif tree_type is list or tree_type is tuple:
-        for t in tree:
-            shapes.extend(_fetch_dims(t))
-    elif tree_type is torch.Tensor:
-        shapes.append(tree.shape)
-    else:
-        raise ValueError("Not supported")
-
-    return shapes
+    return optree.tree_flatten(optree.tree_map(lambda x: x.shape, tree))[0]
 
 
 @torch.jit.ignore
diff --git a/protenix/protenix/web_service/colab_request_parser.py b/protenix/protenix/web_service/colab_request_parser.py
index a7cc4f7f359c5f92937d9945b96607bf59fb832d..937ba9df13e59e12c0666e2eaf04bba030982c3e 100644
--- a/protenix/protenix/web_service/colab_request_parser.py
+++ b/protenix/protenix/web_service/colab_request_parser.py
@@ -32,7 +32,9 @@ from protenix.data.json_to_feature import SampleDictToFeatures
 from protenix.web_service.colab_request_utils import run_mmseqs2_service
 from protenix.web_service.dependency_url import URL
 
-MMSEQS_SERVICE_HOST_URL = os.getenv("MMSEQS_SERVICE_HOST_URL", "http://101.126.11.40:80")
+MMSEQS_SERVICE_HOST_URL = os.getenv(
+    "MMSEQS_SERVICE_HOST_URL", "https://protenix-server.com/api/msa"
+)
 MAX_ATOM_NUM = 60000
 MAX_TOKEN_NUM = 5000
 DATA_CACHE_DIR = "/af3-dev/release_data/"
@@ -76,7 +78,9 @@ class TooLargeComplexError(Exception):
 
 
 class RequestParser(object):
-    def __init__(self, request_json_path: str, request_dir: str, email: str = "") -> None:
+    def __init__(
+        self, request_json_path: str, request_dir: str, email: str = ""
+    ) -> None:
         with open(request_json_path, "r") as f:
             self.request = json.load(f)
         self.request_dir = request_dir
@@ -207,7 +211,10 @@ class RequestParser(object):
 
     @staticmethod
     def msa_search(
-        seqs_pending_msa: Sequence[str], tmp_fasta_fpath: str, msa_res_dir: str, email: str = ""
+        seqs_pending_msa: Sequence[str],
+        tmp_fasta_fpath: str,
+        msa_res_dir: str,
+        email: str = "",
     ) -> None:
         lines = []
         for idx, seq in enumerate(seqs_pending_msa):
@@ -229,7 +236,7 @@ class RequestParser(object):
                 use_templates=False,
                 host_url=MMSEQS_SERVICE_HOST_URL,
                 user_agent="colabfold/1.5.5",
-                email=email
+                email=email,
             )
         except Exception as e:
             error_message = f"MMSEQS2 failed with the following error message:\n{traceback.format_exc()}"
@@ -250,13 +257,20 @@ class RequestParser(object):
         def read_a3m(a3m_file: str) -> Tuple[List[str], List[str]]:
             heads = []
             seqs = []
+            # Record the row index. The index before this index is the MSA of Uniref30 DB,
+            # and the index after this index is the MSA of ColabfoldDB.
+            uniref_index = 0
             with open(a3m_file, "r") as infile:
-                for line in infile:
+                for idx, line in enumerate(infile):
                     if line.startswith(">"):
                         heads.append(line)
+                        if idx == 0:
+                            query_name = line
+                        elif idx > 0 and line == query_name:
+                            uniref_index = idx
                     else:
                         seqs.append(line)
-            return heads, seqs
+            return heads, seqs, uniref_index
 
         def make_pairing_and_non_pairing_msa(
             query_seq: str,
@@ -265,18 +279,22 @@ class RequestParser(object):
             uniref_to_ncbi_taxid: Mapping[str, str],
         ) -> List[str]:
 
-            heads, msa_seqs = read_a3m(raw_a3m_path)
+            heads, msa_seqs, uniref_index = read_a3m(raw_a3m_path)
             uniref100_lines = [">query\n", f"{query_seq}\n"]
             other_lines = [">query\n", f"{query_seq}\n"]
 
-            for head, msa_seq in zip(heads, msa_seqs):
+            for idx, (head, msa_seq) in enumerate(zip(heads, msa_seqs)):
                 if msa_seq.rstrip("\n") == query_seq:
                     continue
 
-                if "UniRef" in head:
-                    uniref_id = head.split("\t")[0][1:]
-                    ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id, None)
-                    if ncbi_taxid is not None:
+                uniref_id = head.split("\t")[0][1:]
+                ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id, None)
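+                # uniref_index is a raw a3m line index (two lines per entry); idx < uniref_index // 2 keeps only the UniRef30 hits here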
+                if (ncbi_taxid is not None) and (idx < (uniref_index // 2)):
+                    if not uniref_id.startswith("UniRef100_"):
+                        head = head.replace(
+                            uniref_id, f"UniRef100_{uniref_id}_{ncbi_taxid}/"
+                        )
+                    else:
                         head = head.replace(uniref_id, f"{uniref_id}_{ncbi_taxid}/")
                     uniref100_lines.extend([head, msa_seq])
                 else:
@@ -294,7 +312,7 @@ class RequestParser(object):
             seq_dir: str,
             raw_a3m_path: str,
         ):
-            heads, msa_seqs = read_a3m(raw_a3m_path)
+            heads, msa_seqs, _ = read_a3m(raw_a3m_path)
             other_lines = [">query\n", f"{query_seq}\n"]
             for head, msa_seq in zip(heads, msa_seqs):
                 if msa_seq.rstrip("\n") == query_seq:
@@ -390,9 +408,18 @@ class RequestParser(object):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--request_json_path", type=str, required=True, help="Path to the request JSON file.")
-    parser.add_argument("--request_dir", type=str, required=True, help="Path to the request directory.")
-    parser.add_argument("--email", type=str, required=False, default="", help="Your email address.")
+    parser.add_argument(
+        "--request_json_path",
+        type=str,
+        required=True,
+        help="Path to the request JSON file.",
+    )
+    parser.add_argument(
+        "--request_dir", type=str, required=True, help="Path to the request directory."
+    )
+    parser.add_argument(
+        "--email", type=str, required=False, default="", help="Your email address."
+    )
 
     args = parser.parse_args()
     parser = RequestParser(
diff --git a/protenix/protenix/web_service/colab_request_utils.py b/protenix/protenix/web_service/colab_request_utils.py
index 5090bbcafe34c62f0faacf00cb79425a382dc12e..fd798a4e5aed3f0becf581c6125313392602a59f 100644
--- a/protenix/protenix/web_service/colab_request_utils.py
+++ b/protenix/protenix/web_service/colab_request_utils.py
@@ -14,16 +14,14 @@
 
 import logging
 import os
-import random
 import tarfile
 import time
 from typing import List, Tuple
 
+import requests
 from requests.auth import HTTPBasicAuth
 from tqdm import tqdm
 
-import requests
-
 TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} estimate remaining: {remaining}]"
 logger = logging.getLogger(__name__)
 
@@ -42,7 +40,7 @@ def run_mmseqs2_service(
     pairing_strategy="greedy",
     host_url="https://api.colabfold.com",
     user_agent: str = "",
-    email: str = ""
+    email: str = "",
 ) -> Tuple[List[str], List[str]]:
     submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"
 
diff --git a/protenix/protenix/web_service/prediction_visualization.py b/protenix/protenix/web_service/prediction_visualization.py
index c24e0b6cc2557813e2cd8d34feb10fd53566e185..37a3ab517d0b9d3091cc60826248aa0913743829 100644
--- a/protenix/protenix/web_service/prediction_visualization.py
+++ b/protenix/protenix/web_service/prediction_visualization.py
@@ -282,63 +282,73 @@ def plot_best_confidence_measure(prediction_fpath: str, output_dir: str = None):
     ):
         print("no full confidence data found, skip ploting.")
         return
-    full_confidence = pred_loader.full_confidence_data[:1]
-    summary_confidence = pred_loader.summary_confidence_data[:1]
-    atom_plddts = [d["atom_plddt"] for d in full_confidence]
-    token_pdes = [d["token_pair_pde"] for d in full_confidence]
-    token_paes = [d["token_pair_pae"] for d in full_confidence]
-    summary_keys = ["plddt", "gpde", "ptm", "iptm"]
-
-    fig, axes = plt.subplots(
-        nrows=3,
-        ncols=1,
-        figsize=(10, 20),
-        gridspec_kw={"height_ratios": [1.5, 2, 2]},
-    )
-    for i in range(3):
-        summary_text = ", ".join(
-            [
-                f"{k}: {v:.4f}"
-                for k, v in summary_confidence[0].items()
-                if k in summary_keys
-            ]
+    data_len = len(pred_loader.full_confidence_data)
+    for data_idx in range(data_len):
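+        # plot one confidence figure per predicted sample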
+        full_confidence = pred_loader.full_confidence_data[data_idx : data_idx + 1]
+        summary_confidence = pred_loader.summary_confidence_data[
+            data_idx : data_idx + 1
+        ]
+        atom_plddts = [d["atom_plddt"] for d in full_confidence]
+        token_pdes = [d["token_pair_pde"] for d in full_confidence]
+        token_paes = [d["token_pair_pae"] for d in full_confidence]
+        summary_keys = ["plddt", "gpde", "ptm", "iptm"]
+
+        fig, axes = plt.subplots(
+            nrows=3,
+            ncols=1,
+            figsize=(10, 20),
+            gridspec_kw={"height_ratios": [1.5, 2, 2]},
         )
-        if i == 0:
-            axes[i].plot(atom_plddts[0], color="k")
-            axes[i].text(
-                0.5,
-                1.07,
-                summary_text,
-                ha="center",
-                va="center",
-                transform=axes[i].transAxes,
-                fontsize=10,
+        for i in range(3):
+            summary_text = ", ".join(
+                [
+                    f"{k}: {v:.4f}"
+                    for k, v in summary_confidence[0].items()
+                    if k in summary_keys
+                ]
             )
-            axes[i].set_xlabel("Atom ID", fontsize=12)
-            axes[i].set_ylabel("pLDDT", fontsize=12)
-            axes[i].set_ylim([0, 1])
-            axes[i].spines[["right", "top"]].set_visible(False)
+            if i == 0:
+                axes[i].plot(atom_plddts[0], color="k")
+                axes[i].text(
+                    0.5,
+                    1.07,
+                    summary_text,
+                    ha="center",
+                    va="center",
+                    transform=axes[i].transAxes,
+                    fontsize=10,
+                )
+                axes[i].set_xlabel("Atom ID", fontsize=12)
+                axes[i].set_ylabel("pLDDT", fontsize=12)
+                axes[i].set_ylim([0, 1])
+                axes[i].spines[["right", "top"]].set_visible(False)
 
-        else:
-            data_to_plot = token_pdes[0] if i == 1 else token_paes[0]
-            cax = axes[i].matshow(data_to_plot, origin="lower")
-            axes[i].xaxis.tick_bottom()
-            axes[i].set_aspect("equal")
-            axes[i].set_xlabel("Scored Residue", fontsize=12)
-            axes[i].set_ylabel("Aligned Residue", fontsize=12)
-            axes[i].xaxis.set_major_locator(MaxNLocator(3))  # Max 5 ticks on the x-axis
-            axes[i].yaxis.set_major_locator(MaxNLocator(3))  # Max 5 ticks on the y-axis
-            color_bar = fig.colorbar(
-                cax, ax=axes[i], orientation="vertical", pad=0.1, shrink=0.6
-            )
-            cbar_label = (
-                "Predicted Distance Error" if i == 1 else "Predicted Aligned Error"
-            )
-            color_bar.set_label(cbar_label)
+            else:
+                data_to_plot = token_pdes[0] if i == 1 else token_paes[0]
+                cax = axes[i].matshow(data_to_plot, origin="lower")
+                axes[i].xaxis.tick_bottom()
+                axes[i].set_aspect("equal")
+                axes[i].set_xlabel("Scored Residue", fontsize=12)
+                axes[i].set_ylabel("Aligned Residue", fontsize=12)
+                axes[i].xaxis.set_major_locator(
+                    MaxNLocator(3)
+                )  # at most 3 ticks on the x-axis
+                axes[i].yaxis.set_major_locator(
+                    MaxNLocator(3)
+                )  # at most 3 ticks on the y-axis
+                color_bar = fig.colorbar(
+                    cax, ax=axes[i], orientation="vertical", pad=0.1, shrink=0.6
+                )
+                cbar_label = (
+                    "Predicted Distance Error" if i == 1 else "Predicted Aligned Error"
+                )
+                color_bar.set_label(cbar_label)
 
-    out_png_fpath = os.path.join(output_dir, "best_sample_confidence.png")
-    plt.savefig(out_png_fpath)
-    print(f"save best sample confidence plot successfully to {out_png_fpath}")
+        out_png_fpath = os.path.join(
+            output_dir, f"best_sample_confidence_{data_idx}.png"
+        )
+        plt.savefig(out_png_fpath)
+        plt.close(fig)
+    print(f"saved {data_len} best sample confidence plots to {output_dir}")
 
 
 if __name__ == "__main__":
diff --git a/protenix/runner/batch_inference.py b/protenix/runner/batch_inference.py
index acebea432b90423bf5d1c41afd1de3b0acd3c88f..8dc16d2b83c8ad578b9dd979a258a0b36fb1205f 100644
--- a/protenix/runner/batch_inference.py
+++ b/protenix/runner/batch_inference.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import json
 import logging
 import os
@@ -23,18 +24,18 @@ from typing import List, Optional, Union
 import click
 import tqdm
 from Bio import SeqIO
+from rdkit import Chem
+
 from configs.configs_base import configs as configs_base
 from configs.configs_data import data_configs
 from configs.configs_inference import inference_configs
-from rdkit import Chem
-from runner.inference import InferenceRunner, download_infercence_cache, infer_predict
-from runner.msa_search import contain_msa_res, msa_search, msa_search_update
-
 from protenix.config import parse_configs
 from protenix.data.json_maker import cif_to_input_json
 from protenix.data.json_parser import lig_file_to_atom_info
 from protenix.data.utils import pdb_to_cif
 from protenix.utils.logger import get_logger
+from runner.inference import InferenceRunner, download_infercence_cache, infer_predict
+from runner.msa_search import msa_search, update_infer_json
 
 logger = get_logger(__name__)
 
@@ -49,9 +50,7 @@ def init_logging():
     )
 
 
-def generate_infer_jsons(
-    protein_msa_res: dict, ligand_file: str, seeds: List[int] = [101]
-) -> List[str]:
+def generate_infer_jsons(protein_msa_res: dict, ligand_file: str) -> List[str]:
     protein_chains = []
     if len(protein_msa_res) <= 0:
         raise RuntimeError(f"invalid `protein_msa_res` data in {protein_msa_res}")
@@ -135,7 +134,6 @@ def generate_infer_jsons(
         infer_json_files.append(json_file_name)
 
     for smi_ligand_file in smi_ligand_files:
-        one_infer_seq = protein_chains[:]
         with open(smi_ligand_file, "r") as f:
             smile_list = f.readlines()
         one_infer_seq = protein_chains[:]
@@ -161,13 +159,18 @@ def generate_infer_jsons(
     return infer_json_files
 
 
-def get_default_runner(seeds: Optional[list] = None) -> InferenceRunner:
+def get_default_runner(
+    seeds: Optional[tuple] = None,
+    n_cycle: int = 10,
+    n_step: int = 200,
+    n_sample: int = 5,
+) -> InferenceRunner:
     configs_base["use_deepspeed_evo_attention"] = (
-        os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
+        os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true"
     )
-    configs_base["model"]["N_cycle"] = 10
-    configs_base["sample_diffusion"]["N_sample"] = 5
-    configs_base["sample_diffusion"]["N_step"] = 200
+    configs_base["model"]["N_cycle"] = n_cycle
+    configs_base["sample_diffusion"]["N_sample"] = n_sample
+    configs_base["sample_diffusion"]["N_step"] = n_step
     configs = {**configs_base, **{"data": data_configs}, **inference_configs}
     configs = parse_configs(
         configs=configs,
@@ -183,7 +186,10 @@ def inference_jsons(
     json_file: str,
     out_dir: str = "./output",
     use_msa_server: bool = False,
-    seeds: list = [101],
+    seeds: tuple = (101,),
+    n_cycle: int = 10,
+    n_step: int = 200,
+    n_sample: int = 5,
 ) -> None:
     """
     infer_json: json file or directory, will run infer with these jsons
@@ -210,66 +216,13 @@ def inference_jsons(
     infer_errors = {}
     inference_configs["dump_dir"] = out_dir
     inference_configs["input_json_path"] = infer_jsons[0]
-    runner = get_default_runner(seeds)
+    runner = get_default_runner(seeds, n_cycle, n_step, n_sample)
     configs = runner.configs
     for idx, infer_json in enumerate(tqdm.tqdm(infer_jsons)):
         try:
-            if use_msa_server:
-                infer_json = msa_search_update(
-                    infer_json, os.path.join(out_dir, f"msa_res_{idx}")
-                )
-            elif not contain_msa_res(infer_json):
-                raise RuntimeError(f"can not find msa for {infer_json}")
-            configs["input_json_path"] = infer_json
-            if not contain_msa_res(infer_json):
-                raise RuntimeError(
-                    f"`{infer_json}` has no msa result for `proteinChain`, please add first."
-                )
-            infer_predict(runner, configs)
-        except Exception as exc:
-            infer_errors[infer_json] = str(exc)
-    if len(infer_errors) > 0:
-        logger.warning(f"run inference failed: {infer_errors}")
-
-
-def batch_inference(
-    protein_msa_res: dict,
-    ligand_file: str,
-    out_dir: str = "./output",
-    seeds: List[int] = [101],
-) -> None:
-    """
-    ligand_file: ligand file or directory, should be in sdf format or smi with smlies list;
-    protein_msa_res: the msa result for `protein`, like:
-        {  "MGHHHHHHHHHHSSGH": {
-                "precomputed_msa_dir": "/path/to/msa_pairing/result/msa/1",
-                "pairing_db": "uniref100"
-            },
-            "MAEVIRSSAFWRSFPIFEEFDSE": {
-                "precomputed_msa_dir": "/path/to/msa_pairing/result/msa/2",
-                "pairing_db": "uniref100"
-            }
-        }
-    out_dir: the infer outout dir, default is `./output`
-    """
-
-    infer_jsons = generate_infer_jsons(protein_msa_res, ligand_file, seeds)
-    logger.info(f"will infer with {len(infer_jsons)} jsons")
-    if len(infer_jsons) == 0:
-        return
-
-    infer_errors = {}
-    inference_configs["dump_dir"] = out_dir
-    inference_configs["input_json_path"] = infer_jsons[0]
-    runner = get_default_runner(seeds=seeds)
-    configs = runner.configs
-    for infer_json in tqdm.tqdm(infer_jsons):
-        try:
-            configs["input_json_path"] = infer_json
-            if not contain_msa_res(infer_json):
-                raise RuntimeError(
-                    f"`{infer_json}` has no msa result for `proteinChain`, please add first."
-                )
+            configs["input_json_path"] = update_infer_json(
+                infer_json, out_dir=out_dir, use_msa_server=use_msa_server
+            )
             infer_predict(runner, configs)
         except Exception as exc:
             infer_errors[infer_json] = str(exc)
@@ -288,8 +241,11 @@ def protenix_cli():
 @click.option(
     "--seeds", type=str, default="101", help="the inference seed, split by comma"
 )
+@click.option("--cycle", type=int, default=10, help="pairformer cycle number")
+@click.option("--step", type=int, default=200, help="diffusion step")
+@click.option("--sample", type=int, default=5, help="sample number")
 @click.option("--use_msa_server", is_flag=True, help="do msa search or not")
-def predict(input, out_dir, seeds, use_msa_server):
+def predict(input, out_dir, seeds, cycle, step, sample, use_msa_server):
     """
     predict: Run predictions with protenix.
     :param input, out_dir, use_msa_server
@@ -297,10 +253,18 @@ def predict(input, out_dir, seeds, use_msa_server):
     """
     init_logging()
     logger.info(
-        f"run infer with input={input}, out_dir={out_dir}, use_msa_server={use_msa_server}"
+        f"run infer with input={input}, out_dir={out_dir}, cycle={cycle}, step={step}, sample={sample}, use_msa_server={use_msa_server}"
     )
     seeds = list(map(int, seeds.split(",")))
-    inference_jsons(input, out_dir, use_msa_server, seeds=seeds)
+    inference_jsons(
+        input,
+        out_dir,
+        use_msa_server,
+        seeds=seeds,
+        n_cycle=cycle,
+        n_step=step,
+        n_sample=sample,
+    )
 
 
 @click.command()
@@ -395,11 +359,9 @@ def msa(input, out_dir) -> Union[str, dict]:
     :return:
     """
     init_logging()
-    out_dir = os.path.join(out_dir, uuid.uuid4().hex)
-    os.makedirs(out_dir, exist_ok=True)
     logger.info(f"run msa with input={input}, out_dir={out_dir}")
     if input.endswith(".json"):
-        msa_input_json = msa_search_update(input, out_dir)
+        msa_input_json = update_infer_json(input, out_dir, use_msa_server=True)
         logger.info(f"msa results have been update to {msa_input_json}")
         return msa_input_json
     elif input.endswith(".fasta"):
@@ -424,22 +386,5 @@ protenix_cli.add_command(tojson)
 protenix_cli.add_command(msa)
 
 
-def test_batch_inference():
-    ligands_dir = "../examples/ligands"
-    protein_msa_res = {
-        "MASWSHPQFEKGGTHVAETSAPTRSEPDTRVLTLPGTASAPEFRLIDIDGLLNNRATTDVRDLGSGRLNAWGNSFPAAELPAPGSLITVAGIPFTWANAHARGDNIRCEGQVVDIPPGQYDWIYLLAASERRSEDTIWAHYDDGHADPLRVGISDFLDGTPAFGELSAFRTSRMHYPHHVQEGLPTTMWLTRVGMPRHGVARSLRLPRSVAMHVFALTLRTAAAVRLAEGATT": {
-            "precomputed_msa_dir": "../examples/7wux/msa/1",
-            "pairing_db": "uniref100",
-        },
-        "MGSSHHHHHHSQDPNSTTTAPPVELWTRDLGSCLHGTLATALIRDGHDPVTVLGAPWEFRRRPGAWSSEEYFFFAEPDSLAGRLALYHPFESTWHRSDGDGVDDLREALAAGVLPIAAVDNFHLPFRPAFHDVHAAHLLVVYRITETEVYVSDAQPPAFQGAIPLADFLASWGSLNPPDDADVFFSASPSGRRWLRTRMTGPVPEPDRHWVGRVIRENVARYRQEPPADTQTGLPGLRRYLDELCALTPGTNAASEALSELYVISWNIQAQSGLHAEFLRAHSVKWRIPELAEAAAGVDAVAHGWTGVRMTGAHSRVWQRHRPAELRGHATALVRRLEAALDLLELAADAVS": {
-            "precomputed_msa_dir": "../examples/7wux/msa/2",
-            "pairing_db": "uniref100",
-        },
-    }
-    out_dir = "./infer_output"
-    batch_inference(protein_msa_res, ligands_dir, out_dir=out_dir)
-
-
 if __name__ == "__main__":
-    init_logging()
-    test_batch_inference()
+    predict()
diff --git a/protenix/runner/dumper.py b/protenix/runner/dumper.py
index 06ac2e57381c8857971f978d301f79f2fdb868cc..e6e42c2878319a2054755db7d1954f4529dc8fc3 100644
--- a/protenix/runner/dumper.py
+++ b/protenix/runner/dumper.py
@@ -44,9 +44,15 @@ def get_clean_full_confidence(full_confidence_dict: dict) -> dict:
 
 
 class DataDumper:
-    def __init__(self, base_dir, need_atom_confidence: bool = False):
+    def __init__(
+        self,
+        base_dir,
+        need_atom_confidence: bool = False,
+        sorted_by_ranking_score: bool = True,
+    ) -> None:
         self.base_dir = base_dir
         self.need_atom_confidence = need_atom_confidence
+        self.sorted_by_ranking_score = sorted_by_ranking_score
 
     def dump(
         self,
@@ -156,7 +162,7 @@ class DataDumper:
         N_sample = pred_coordinates.shape[0]
         if sorted_indices is None:
             sorted_indices = range(N_sample)  # do not rank the output file
-        for rank, idx in enumerate(sorted_indices):
+        for idx, rank in enumerate(sorted_indices):
             output_fpath = os.path.join(
                 prediction_save_dir,
                 f"{sample_name}_seed_{seed}_sample_{rank}.cif",
@@ -173,15 +179,20 @@ class DataDumper:
                 pdb_id=sample_name,
             )
 
-    def _get_ranker_indices(self, data: dict, sorted_by_ranking_score: bool = True):
+    def _get_ranker_indices(self, data: dict):
         N_sample = len(data["summary_confidence"])
-        sorted_indices = range(N_sample)
-        if sorted_by_ranking_score:
-            sorted_indices = sorted(
-                range(N_sample),
-                key=lambda i: data["summary_confidence"][i]["ranking_score"],
-                reverse=True,
+        if self.sorted_by_ranking_score:
+            value = torch.tensor(
+                [
+                    data["summary_confidence"][i]["ranking_score"]
+                    for i in range(N_sample)
+                ]
             )
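+            # double argsort converts scores into each sample's rank (0 = highest ranking_score)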
+            sorted_indices = [
+                i for i in torch.argsort(torch.argsort(value, descending=True))
+            ]
+        else:
+            sorted_indices = [i for i in range(N_sample)]
         return sorted_indices
 
     def _save_confidence(
@@ -200,7 +211,7 @@ class DataDumper:
                 )
         if sorted_indices is None:
             sorted_indices = range(N_sample)
-        for rank, idx in enumerate(sorted_indices):
+        for idx, rank in enumerate(sorted_indices):
             output_fpath = os.path.join(
                 prediction_save_dir,
                 f"{sample_name}_seed_{seed}_summary_confidence_sample_{rank}.json",
diff --git a/protenix/runner/inference.py b/protenix/runner/inference.py
index 3c77b7eb2f6268ff9ffcca5d74ccc9e67be1e203..5b3e1b172126c4a26d9581cbef31e9f24005f119 100644
--- a/protenix/runner/inference.py
+++ b/protenix/runner/inference.py
@@ -23,10 +23,11 @@ from typing import Any, Mapping
 
 import torch
 import torch.distributed as dist
-
 from configs.configs_base import configs as configs_base
 from configs.configs_data import data_configs
 from configs.configs_inference import inference_configs
+from runner.dumper import DataDumper
+
 from protenix.config import parse_configs, parse_sys_args
 from protenix.data.infer_data_pipeline import get_inference_dataloader
 from protenix.model.protenix import Protenix
@@ -34,8 +35,6 @@ from protenix.utils.distributed import DIST_WRAPPER
 from protenix.utils.seed import seed_everything
 from protenix.utils.torch_utils import to_device
 from protenix.web_service.dependency_url import URL
-from runner.dumper import DataDumper
-from runner.msa_search import contain_msa_res, msa_search_update
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +46,10 @@ class InferenceRunner(object):
         self.init_basics()
         self.init_model()
         self.load_checkpoint()
-        self.init_dumper(need_atom_confidence=configs.need_atom_confidence)
+        self.init_dumper(
+            need_atom_confidence=configs.need_atom_confidence,
+            sorted_by_ranking_score=configs.sorted_by_ranking_score,
+        )
 
     def init_env(self) -> None:
         self.print(
@@ -112,14 +114,18 @@ class InferenceRunner(object):
             }
         self.model.load_state_dict(
             state_dict=checkpoint["model"],
-            strict=True,
+            strict=self.configs.load_strict,
         )
         self.model.eval()
         self.print(f"Finish loading checkpoint.")
 
-    def init_dumper(self, need_atom_confidence: bool = False):
+    def init_dumper(
+        self, need_atom_confidence: bool = False, sorted_by_ranking_score: bool = True
+    ):
         self.dumper = DataDumper(
-            base_dir=self.dump_dir, need_atom_confidence=need_atom_confidence
+            base_dir=self.dump_dir,
+            need_atom_confidence=need_atom_confidence,
+            sorted_by_ranking_score=sorted_by_ranking_score,
         )
 
     # Adapted from runner.train.Trainer.evaluate
@@ -157,30 +163,29 @@ class InferenceRunner(object):
 
 
 def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> None:
-    current_file_path = os.path.abspath(__file__)
-    current_directory = os.path.dirname(current_file_path)
-    code_directory = os.path.dirname(current_directory)
-
-    data_cache_dir = os.path.join(code_directory, "release_data/ccd_cache")
-    os.makedirs(data_cache_dir, exist_ok=True)
-    for cache_name, fname in [
-        ("ccd_components_file", "components.v20240608.cif"),
-        ("ccd_components_rdkit_mol_file", "components.v20240608.cif.rdkit_mol.pkl"),
-    ]:
-        if not opexists(cache_path := os.path.abspath(opjoin(data_cache_dir, fname))):
+
+    for cache_name in ("ccd_components_file", "ccd_components_rdkit_mol_file"):
+        cur_cache_fpath = configs["data"][cache_name]
+        if not opexists(cur_cache_fpath):
+            os.makedirs(os.path.dirname(cur_cache_fpath), exist_ok=True)
             tos_url = URL[cache_name]
-            logger.info(f"Downloading data cache from\n {tos_url}...")
-            urllib.request.urlretrieve(tos_url, cache_path)
+            assert os.path.basename(tos_url) == os.path.basename(cur_cache_fpath), (
+                f"{cache_name} file name is incorrect, `{tos_url}` and "
+                f"`{cur_cache_fpath}`. Please check and try again."
+            )
+            logger.info(
+                f"Downloading data cache from\n {tos_url}... to {cur_cache_fpath}"
+            )
+            urllib.request.urlretrieve(tos_url, cur_cache_fpath)
 
     checkpoint_path = configs.load_checkpoint_path
 
     if not opexists(checkpoint_path):
-        checkpoint_path = os.path.join(
-            code_directory, f"release_data/checkpoint/model_{model_version}.pt"
-        )
         os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
         tos_url = URL[f"model_{model_version}"]
-        logger.info(f"Downloading model checkpoint from\n {tos_url}...")
+        logger.info(
+            f"Downloading model checkpoint from\n {tos_url}... to {checkpoint_path}"
+        )
         urllib.request.urlretrieve(tos_url, checkpoint_path)
         try:
             ckpt = torch.load(checkpoint_path)
@@ -191,7 +196,6 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
                 "Download model checkpoint failed, please download by yourself with "
                 f"wget {tos_url} -O {checkpoint_path}"
             )
-        configs.load_checkpoint_path = checkpoint_path
 
 
 def update_inference_configs(configs: Any, N_token: int):
@@ -211,38 +215,31 @@ def update_inference_configs(configs: Any, N_token: int):
 
 
 def infer_predict(runner: InferenceRunner, configs: Any) -> None:
-    # update msa result if not contains precomputed msa dir
-    if not contain_msa_res(configs.input_json_path):
-        logger.info(
-            f"{configs.input_json_path} dose not contain precomputed msa dir, now searching it."
-        )
-        configs.input_json_path = msa_search_update(
-            configs.input_json_path, configs.dump_dir
-        )
-        logger.info(
-            f"msa searching completed, new input json is {configs.input_json_path}"
-        )
     # Data
     logger.info(f"Loading data from\n{configs.input_json_path}")
-    dataloader = get_inference_dataloader(configs=configs)
+    try:
+        dataloader = get_inference_dataloader(configs=configs)
+    except Exception as e:
+        error_message = f"{e}:\n{traceback.format_exc()}"
+        logger.info(error_message)
+        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
+            f.write(error_message)
+        return
 
     num_data = len(dataloader.dataset)
     for seed in configs.seeds:
-        seed_everything(seed=seed, deterministic=False)
+        seed_everything(seed=seed, deterministic=configs.deterministic)
         for batch in dataloader:
             try:
                 data, atom_array, data_error_message = batch[0]
+                sample_name = data["sample_name"]
 
                 if len(data_error_message) > 0:
                     logger.info(data_error_message)
-                    with open(
-                        opjoin(runner.error_dir, f"{data['sample_name']}.txt"),
-                        "w",
-                    ) as f:
+                    with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                         f.write(data_error_message)
                     continue
 
-                sample_name = data["sample_name"]
                 logger.info(
                     (
                         f"[Rank {DIST_WRAPPER.rank} ({data['sample_index'] + 1}/{num_data})] {sample_name}: "
@@ -271,15 +268,10 @@ def infer_predict(runner: InferenceRunner, configs: Any) -> None:
                 error_message = f"[Rank {DIST_WRAPPER.rank}]{data['sample_name']} {e}:\n{traceback.format_exc()}"
                 logger.info(error_message)
                 # Save error info
-                if opexists(
-                    error_path := opjoin(runner.error_dir, f"{sample_name}.txt")
-                ):
-                    os.remove(error_path)
-                with open(error_path, "w") as f:
+                with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                     f.write(error_message)
                 if hasattr(torch.cuda, "empty_cache"):
                     torch.cuda.empty_cache()
-                raise RuntimeError(f"run infer failed: {str(e)}")
 
 
 def main(configs: Any) -> None:
@@ -297,7 +289,7 @@ def run() -> None:
         filemode="w",
     )
     configs_base["use_deepspeed_evo_attention"] = (
-        os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
+        os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true"
     )
     configs = {**configs_base, **{"data": data_configs}, **inference_configs}
     configs = parse_configs(
diff --git a/protenix/runner/msa_search.py b/protenix/runner/msa_search.py
index 4c1008c95d2871470aa11b71783a313bd7c7442f..d352a052dc7c2f29cb4389d430b197f96873cb97 100644
--- a/protenix/runner/msa_search.py
+++ b/protenix/runner/msa_search.py
@@ -22,33 +22,18 @@ from protenix.web_service.colab_request_parser import RequestParser
 logger = get_logger(__name__)
 
 
-def contain_msa_res(json_file: str) -> bool:
-    """
-    check the json_path data has msa result or not.
-    """
-    if not os.path.exists(json_file):
-        raise RuntimeError(f"`{json_file}` not exists.")
-    with open(json_file, "r") as f:
-        json_data = json.load(f)
-    for seq in json_data:
-        for sequence in seq["sequences"]:
-            if "proteinChain" in sequence.keys():
-                proteinChain = sequence["proteinChain"]
-                if "msa" not in proteinChain.keys() or len(proteinChain["msa"]) == 0:
-                    return False
-    return True
-
-
-def update_msa_res(seq: dict, protein_msa_res: dict) -> dict:
-    for sequence in seq["sequences"]:
+def need_msa_search(json_data: dict) -> bool:
+    need_msa = json_data.get("use_msa", True)
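+    # an explicit "use_msa": false in the input json disables msa searching altogether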
+    # TODO: add esm check
+    if not need_msa:
+        return need_msa
+    need_msa = False
+    for sequence in json_data["sequences"]:
         if "proteinChain" in sequence.keys():
-            sequence["proteinChain"]["msa"] = {
-                "precomputed_msa_dir": protein_msa_res[
-                    sequence["proteinChain"]["sequence"]
-                ],
-                "pairing_db": "uniref100",
-            }
-    return seq
+            proteinChain = sequence["proteinChain"]
+            if "msa" not in proteinChain.keys() or len(proteinChain["msa"]) == 0:
+                need_msa = True
+    return need_msa
 
 
 def msa_search(seqs: Sequence[str], msa_res_dir: str) -> Sequence[str]:
@@ -69,32 +54,68 @@ def msa_search(seqs: Sequence[str], msa_res_dir: str) -> Sequence[str]:
     return msa_res_subdirs
 
 
-def msa_search_update(json_file: str, out_dir: str) -> str:
+def update_seq_msa(infer_seq: dict, msa_res_dir: str) -> dict:
+    protein_seqs = []
+    for sequence in infer_seq["sequences"]:
+        if "proteinChain" in sequence.keys():
+            protein_seqs.append(sequence["proteinChain"]["sequence"])
+    if len(protein_seqs) > 0:
+        protein_seqs = sorted(protein_seqs)
+        msa_res_subdirs = msa_search(protein_seqs, msa_res_dir)
+        assert len(msa_res_subdirs) == len(protein_seqs), "msa search failed"
+        protein_msa_res = dict(zip(protein_seqs, msa_res_subdirs))
+        for sequence in infer_seq["sequences"]:
+            if "proteinChain" in sequence.keys():
+                sequence["proteinChain"]["msa"] = {
+                    "precomputed_msa_dir": protein_msa_res[
+                        sequence["proteinChain"]["sequence"]
+                    ],
+                    "pairing_db": "uniref100",
+                }
+    return infer_seq
+
+
+def update_infer_json(
+    json_file: str, out_dir: str, use_msa_server: bool = False
+) -> str:
     """
-    do msa search with mmseqs from json input and update it.
+    Update a json file for inference.
+    For every entry that is missing msa results, run msa searching when
+    use_msa_server is True; otherwise raise an error.
+    Entries that already contain msa results are left unchanged.
     """
-    assert os.path.exists(json_file), f"input file {json_file} not exists."
-    if contain_msa_res(json_file):
-        logger.warning(f"{json_file} has already msa result, skip.")
-        return json_file
+    if not os.path.exists(json_file):
+        raise RuntimeError(f"`{json_file}` not exists.")
     with open(json_file, "r") as f:
-        input_json_data = json.load(f)
-    for seq_idx, seq in enumerate(input_json_data):
-        protein_seqs = []
-        for sequence in seq["sequences"]:
-            if "proteinChain" in sequence.keys():
-                protein_seqs.append(sequence["proteinChain"]["sequence"])
-        if len(protein_seqs) > 0:
-            protein_seqs = sorted(protein_seqs)
-            msa_res_subdirs = msa_search(
-                protein_seqs, os.path.join(out_dir, f"msa_seq_{seq_idx}")
-            )
-            assert len(msa_res_subdirs) == len(msa_res_subdirs), "msa search failed"
-            update_msa_res(seq, dict(zip(protein_seqs, msa_res_subdirs)))
-    msa_input_json = os.path.join(
-        os.path.dirname(json_file),
-        f"{os.path.splitext(os.path.basename(json_file))[0]}-add-msa.json",
-    )
-    with open(msa_input_json, "w") as f:
-        json.dump(input_json_data, f, indent=4)
-    return msa_input_json
+        json_data = json.load(f)
+
+    actual_updated = False
+    for seq_idx, infer_data in enumerate(json_data):
+        if need_msa_search(infer_data):
+            actual_updated = True
+            if use_msa_server:
+                seq_name = infer_data.get("name", f"seq_{seq_idx}")
+                logger.info(
+                    f"starting to update msa result for seq {seq_idx} in {json_file}"
+                )
+                update_seq_msa(
+                    infer_data,
+                    os.path.join(out_dir, seq_name, "msa_res", f"msa_seq_{seq_idx}"),
+                )
+            else:
+                raise RuntimeError(
+                    f"infer seq {seq_idx} in `{json_file}` has no msa result, please add first."
+                )
+    if actual_updated:
+        updated_json = os.path.join(
+            os.path.dirname(os.path.abspath(json_file)),
+            f"{os.path.splitext(os.path.basename(json_file))[0]}-add-msa.json",
+        )
+        with open(updated_json, "w") as f:
+            json.dump(json_data, f, indent=4)
+        logger.info(f"update msa result success and save to {updated_json}")
+        return updated_json
+    else:
+        logger.info(f"do not need to update msa result, so return itself {json_file}")
+        return json_file
\ No newline at end of file
diff --git a/protenix/runner/train.py b/protenix/runner/train.py
index d96f6bf8ca500aea227169ea77c7dbe6fb06e0c9..5241e9fcc712423536d28982f0d15048ec2b09b3 100644
--- a/protenix/runner/train.py
+++ b/protenix/runner/train.py
@@ -552,7 +552,7 @@ def main():
         filemode="w",
     )
     configs_base["use_deepspeed_evo_attention"] = (
-        os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
+        os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true"
     )
     configs = {**configs_base, **{"data": data_configs}}
     configs = parse_configs(
diff --git a/run_all.py b/run_all.py
index 356b3228add482370f52ea14c9e1f43ea52f16c0..376990f80e9dd6c9ac97ae4893120438a72593f4 100644
--- a/run_all.py
+++ b/run_all.py
@@ -37,30 +37,29 @@ class Runner:
             print(command)
             os.system(command)
 
-    def run_chai(self, config_names, input_dir, output_dir, seeds, num_diffusion_samples, msa_pqt_dirs):
-        for config_name, msa_pqt_dir in zip(config_names, msa_pqt_dirs):
-            fasta_path = f'{input_dir}/{config_name}.fasta'
-            for seed in seeds:
-                command = (
-                    f'CHAI_DOWNLOADS_DIR=/ai/share/workspace/zhangcb/Chai/ckpts python run_chai.py {fasta_path} {msa_pqt_dir} '
-                    f' {output_dir}/{config_name}/seed_{seed} {num_diffusion_samples} {seed} '
-                )
-                print(command)
-                os.system(command)
+    def run_chai(self, config_names, input_dir, output_dir, seeds, num_diffusion_samples, msa_dir):
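+        # Launch run_chai.py once for all configs; seeds are passed as a single
+        # comma-separated argument and iterated inside run_chai.py.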
+        str_seeds = ','.join([str(seed) for seed in seeds])
+        command = (
+            'CHAI_DOWNLOADS_DIR=/ai/share/workspace/zhangcb/Chai/ckpts '
+            f' /ai/share/workspace/zhangcb/conda_env/chai/bin/python run_chai.py {input_dir} {msa_dir} '
+            f' {output_dir} {num_diffusion_samples} {str_seeds} '
+        )
+        print(command)
+        os.system(command)
 
     def run_protenix(self, json_path, output_dir, seeds, num_diffusion_samples):
-        for seed in seeds:
-            command = (
-                'LAYERNORM_TYPE=fast_layernorm /ai/share/workspace/zhangcb/conda_env/protenix/bin/python protenix/inference.py '
-                f' --input_json_path {json_path} '
-                f' --dump_dir {output_dir} '
-                f' --seeds {seed} '
-                ' --model.N_cycle 10 '
-                f' --sample_diffusion.N_sample {num_diffusion_samples} '
-                ' --sample_diffusion.N_step 200 '
-            )
-            print(command)
-            os.system(command)
+        str_seeds = ','.join([str(seed) for seed in seeds])
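+        # One inference launch covers every seed via a comma-separated --seeds value.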
+        command = (
+            'LAYERNORM_TYPE=fast_layernorm /ai/share/workspace/zhangcb/conda_env/protenix/bin/python protenix/inference.py '
+            f' --input_json_path {json_path} '
+            f' --dump_dir {output_dir} '
+            f' --seeds {str_seeds} '
+            ' --model.N_cycle 10 '
+            f' --sample_diffusion.N_sample {num_diffusion_samples} '
+            ' --sample_diffusion.N_step 200 '
+        )
+        print(command)
+        os.system(command)
     
     def run(
             self, *, 
@@ -79,7 +78,6 @@ class Runner:
         config_names = [inputs['name'] for inputs in inputs_list]
         config_dir = f'{output_dir}/configs'
         struct_dir = f'{output_dir}/preds'
-        msa_pqt_dirs = []
         print('[Searching for MSA]')
         if not os.path.exists(f'{output_dir}/msa/raw/{inputs_list[0]["name"]}.a3m'):
             msa_batch_size = 200
@@ -97,16 +95,18 @@ class Runner:
         print('[Making configs]')
         for inputs in inputs_list:
             msa_proc_dir = f'{output_dir}/msa/processed/{inputs["name"]}'
-            get_configs(
-                name=inputs['name'],
-                config_dir=config_dir,
-                msa_raw_file=f'{output_dir}/msa/raw/{inputs["name"]}.a3m',
-                msa_proc_dir=msa_proc_dir,
-                seqs=inputs['seqs'], seq_cnts=inputs['seq_cnts'],
-                ccd_codes=inputs['ccd_codes'], smis=inputs['smis'],
-                seeds=seeds, tos_tag=tos_tag
-            )
-            msa_pqt_dirs.append(f'{msa_proc_dir}/protenix/')
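+            # Skip config generation when the af3 json for this input already exists.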
+            if not os.path.exists(f'{config_dir}/af3/{inputs["name"]}.json'):
+                get_configs(
+                    name=inputs['name'],
+                    config_dir=config_dir,
+                    msa_raw_file=f'{output_dir}/msa/raw/{inputs["name"]}.a3m',
+                    msa_proc_dir=msa_proc_dir,
+                    seqs=inputs['seqs'], seq_cnts=inputs['seq_cnts'],
+                    ccd_codes=inputs['ccd_codes'], smis=inputs['smis'],
+                    seeds=seeds, tos_tag=tos_tag
+                )
+            else:
+                print('Skip making config for', inputs['name'])
         #
         print('[Predicting]')
         if not disable_af3:
@@ -117,7 +117,7 @@ class Runner:
             self.run_boltz(f'{config_dir}/boltz', f'{struct_dir}/boltz', seeds, num_diffusion_samples)
         if not disable_chai:
             print('\n\t[Chai-1]')
-            self.run_chai(config_names, f'{config_dir}/chai', f'{struct_dir}/chai', seeds, num_diffusion_samples, msa_pqt_dirs)
+            self.run_chai(config_names, f'{config_dir}/chai', f'{struct_dir}/chai', seeds, num_diffusion_samples, f'{output_dir}/msa/processed/')
         if not disable_protenix:
             print('\n\t[Protenix]')
             self.run_protenix(f'{config_dir}/protenix.json', f'{struct_dir}/protenix', seeds, num_diffusion_samples)
@@ -174,4 +174,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/run_alphafold.py b/run_alphafold.py
index 23720baa9dd088830a520ff1139aacdf7741b87c..7aeadb55d4511d2ed5ca8390022e46df39d78bd6 100755
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -572,7 +572,10 @@ def process_fold_input(
     print('Skipping data pipeline...')
   else:
     print('Running data pipeline...')
-    fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)
+    try:
+      fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)
+    except Exception as e:
+      print(f'Data pipeline failed: {e}')
+      return 1
 
   print(f'Output directory: {output_dir}')
   print(f'Writing model input JSON to {output_dir}')
diff --git a/run_chai.py b/run_chai.py
index 443f0f0e690fbe11546afb07adaa296621c6c310..89078d195671879e072dff935b7bd7687d28027b 100644
--- a/run_chai.py
+++ b/run_chai.py
@@ -1,27 +1,34 @@
+import os
 import sys
+from glob import glob
 from pathlib import Path
 import numpy as np
 from chai_lab.chai1 import run_inference
 
 
 def main():
-    fasta_path, msa_dir, output_dir, num_diffn_samples, seed = sys.argv[1 :]
+    fasta_dir, msa_dir, output_dir, num_diffn_samples, str_seeds = sys.argv[1 :]
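+    # argv: <fasta_dir> <msa_dir> <output_dir> <num_diffn_samples> <comma-separated seeds>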
+    fasta_paths = glob(f'{fasta_dir}/*.fasta')
     # Generate structure
-    candidates = run_inference(
-        fasta_file=Path(fasta_path),
-        output_dir=Path(output_dir),
-        # 'default' setup
-        num_trunk_recycles=3,
-        num_diffn_timesteps=200,
-        num_diffn_samples=int(num_diffn_samples),
-        seed=(seed),
-        device="cuda:0",
-        use_esm_embeddings=True,
-        # See example .aligned.pqt files in this directory
-        msa_directory=Path(msa_dir),
-        # Exclusive with msa_directory; can be used for MMseqs2 server MSA generation
-        use_msa_server=False,
-    )
+    for fasta_path in fasta_paths:
+        config_name = os.path.basename(fasta_path).split('.')[0]
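+        # The fasta filename stem is the config name; its MSAs are expected
+        # under {msa_dir}/{config_name}/chai/.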
+        for seed in str_seeds.split(','):
+            print(fasta_path, seed)
+            candidates = run_inference(
+                fasta_file=Path(fasta_path),
+                output_dir=Path(f'{output_dir}/{config_name}/seed_{seed}'),
+                # 'default' setup
+                num_trunk_recycles=3,
+                num_diffn_timesteps=200,
+                num_diffn_samples=int(num_diffn_samples),
+                seed=int(seed),
+                device="cuda:0",
+                use_esm_embeddings=True,
+                # See example .aligned.pqt files in this directory
+                msa_directory=Path(f'{msa_dir}/{config_name}/chai/'),
+                # Exclusive with msa_directory; can be used for MMseqs2 server MSA generation
+                use_msa_server=False,
+            )
     # cif_paths = candidates.cif_paths
     # scores = [rd.aggregate_score for rd in candidates.ranking_data]