# eval_demoBdb.py (2.47 KB)
# Committed by chenbin zhang.
import ast
import os
import sys
from multiprocessing import Pool

import pandas as pd
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from evaluate_DTI import evaluate  # noqa: E402  (requires the sys.path tweak above)


def func(args):
    """Pool worker: score a single prediction / ground-truth pair.

    ``args`` is a task tuple whose final two entries are the ground-truth
    CIF path and the predicted CIF path (in that order); any leading
    entries are bookkeeping for the caller and are ignored here.

    Returns whatever ``evaluate`` returns for the pair.
    """
    reference_path, predicted_path = args[-2:]
    scores = evaluate(reference_path, predicted_path)
    return scores


def _load_case_index(index_tsv):
    """Map each ``case_ind`` in *index_tsv* to ``(cif_names, sim2train)``.

    The TSV is read with ``sep='\\t'`` and must have ``case_ind``,
    ``cif_names`` and ``sim2train`` columns. ``cif_names`` holds a
    Python-literal string (presumably a list of CIF names such as
    ``"['1abc', '2def']"`` — confirm against the TSV). It is parsed with
    ``ast.literal_eval`` so that no arbitrary code in the file is executed.
    """
    table = pd.read_csv(index_tsv, sep='\t')
    return {
        case_ind: (ast.literal_eval(cif_names), sim2train)
        for case_ind, cif_names, sim2train in zip(
            table['case_ind'], table['cif_names'], table['sim2train']
        )
    }


def _build_tasks(pred_dir, cif_dir, case_index):
    """Return one ``(case_ind, cif_name, sim2train, gt_cif_path, pred_cif_path)`` task per ground-truth CIF.

    Case indices are parsed from the ``case_<ind>`` entries of *pred_dir*.
    They are sorted so the resulting CSV has a deterministic row order,
    because ``os.listdir`` order is arbitrary.
    """
    case_inds = sorted(int(entry.split('_')[1]) for entry in os.listdir(pred_dir))
    tasks = []
    for case_ind in case_inds:
        # Only the model_0 prediction of each case is scored.
        pred_cif_path = f'{pred_dir}/case_{case_ind}/case_{case_ind}_model_0.cif'
        cif_names, sim2train = case_index[case_ind]
        for cif_name in cif_names:
            gt_cif_path = f'{cif_dir}/{cif_name}.cif'
            tasks.append((case_ind, cif_name, sim2train, gt_cif_path, pred_cif_path))
    return tasks


def main(
    *,
    cif_dir='/ai/share/workspace/zhangcb/Datasets/BindingDB/demo_set/complexes/structures_raw',
    index_tsv='/ai/share/workspace/zhangcb/Datasets/BindingDB/demo_set/case_ind2cif_names.tsv',
    pred_root='/ai/share/workspace/zhangcb/Boltz/exp_infer/output/BindingDB/demo_set',
    exp_tags=('train_native',),
    output_dir='rmsd_scores',
    processes=112,
):
    """Score Boltz predictions against the BindingDB demo-set ground truth.

    For each experiment tag, every predicted ``case_<ind>_model_0.cif`` is
    compared with every ground-truth CIF listed for that case. The comparison
    runs ``evaluate`` in a process pool. Per-pair results are written to
    ``<output_dir>/<exp_tag>.csv`` with the columns ``case_ind``,
    ``gt_cif_name``, ``sim2train``, ``global_CA_rmsd``, ``pocket_rmsd`` and
    ``ligand_rmsd``.

    All arguments are keyword-only and default to the previously hard-coded
    values, so a bare ``main()`` call behaves as before.

    Args:
        cif_dir: Directory holding the ground-truth ``<cif_name>.cif`` files.
        index_tsv: TSV mapping ``case_ind`` to ``cif_names`` and ``sim2train``.
        pred_root: Parent directory of the per-experiment prediction outputs.
        exp_tags: Experiment tags to evaluate. ``'train_finetuned_epoch_120'``
            is another run referenced by this script.
        output_dir: Directory for the per-experiment CSVs; created if missing.
        processes: Number of worker processes in the evaluation pool.
    """
    # Parse the index once, outside the loop, under its own name, so it can be
    # reused for every tag (rebinding the DataFrame name broke multi-tag runs).
    case_index = _load_case_index(index_tsv)
    os.makedirs(output_dir, exist_ok=True)
    columns = ['case_ind', 'gt_cif_name', 'sim2train', 'global_CA_rmsd', 'pocket_rmsd', 'ligand_rmsd']

    for exp_tag in exp_tags:
        # NOTE(review): both tags referenced here contain 'train', so phase is
        # 'train' for each — confirm this substring test is intended.
        phase = 'train' if 'train' in exp_tag else 'test'
        pred_dir = f'{pred_root}/{exp_tag}/boltz_results_{phase}_set/predictions'
        save_csv_name = os.path.join(output_dir, f'{exp_tag}.csv')

        tasks = _build_tasks(pred_dir, cif_dir, case_index)
        print('#args =', len(tasks))
        with Pool(processes=processes) as pool:
            results = list(tqdm(pool.imap(func, tasks), total=len(tasks), desc='Processing'))

        # pool.imap preserves input order, so results line up with tasks.
        rows = [
            (case_ind, cif_name, sim2train, global_CA_rmsd, pocket_rmsd, ligand_rmsd)
            for (case_ind, cif_name, sim2train, _, _), (global_CA_rmsd, pocket_rmsd, ligand_rmsd)
            in zip(tasks, results)
        ]
        pd.DataFrame(rows, columns=columns).to_csv(save_csv_name, index=False)


# Run the evaluation only when this file is executed as a script, not on import
# (multiprocessing workers may re-import this module).
if __name__ == "__main__":
    main()