Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import sys
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from evaluate_DTI import evaluate
def func(args):
    """Pool worker: score one ground-truth / prediction pair.

    ``args`` is a job tuple whose last two items are the ground-truth CIF path
    and the predicted CIF path; any leading items are ignored here. Returns
    whatever ``evaluate`` returns for that pair.
    """
    # Leading fields are bookkeeping carried for the caller; only the two
    # trailing paths matter to the scorer.
    *_bookkeeping, reference_path, predicted_path = args
    scores = evaluate(reference_path, predicted_path)
    return scores
def main():
    """Score Boltz predictions against ground-truth complexes and save a CSV.

    For each selected experiment tag, every ``case_<ind>`` entry found in the
    prediction directory is paired with each ground-truth CIF listed for that
    case in ``case_ind2cif_names.tsv``. The pairs are scored by ``func`` in a
    process pool, and the three values returned per pair are stored as
    ``global_CA_rmsd``, ``pocket_rmsd`` and ``ligand_rmsd`` in
    ``rmsd_scores/<exp_tag>.csv``.

    Raises:
        KeyError: if a predicted case index is absent from the TSV.
    """
    cif_dir = '/ai/share/workspace/zhangcb/Datasets/BindingDB/demo_set/complexes/structures_raw'
    case_table = pd.read_csv('/ai/share/workspace/zhangcb/Datasets/BindingDB/demo_set/case_ind2cif_names.tsv', sep='\t')

    # Build the lookup once, before the loop. Rebinding the DataFrame's name to
    # a dict inside the loop would make `.iterrows()` fail on any second tag.
    # NOTE(review): eval() executes whatever text is in the 'cif_names' column.
    # It is kept to preserve the current parsing; if the cells are plain list
    # literals, ast.literal_eval would be the safer choice -- confirm.
    case_ind2cif_names = {
        row['case_ind']: (eval(row['cif_names']), row['sim2train'])
        for _, row in case_table.iterrows()
    }

    columns = ['case_ind', 'gt_cif_name', 'sim2train', 'global_CA_rmsd', 'pocket_rmsd', 'ligand_rmsd']

    # Only the first tag is evaluated; widen the slice to score more experiments.
    for exp_tag in ['train_native', 'train_finetuned_epoch_120'][: 1]:
        phase = 'train' if 'train' in exp_tag else 'test'
        pred_dir = f'/ai/share/workspace/zhangcb/Boltz/exp_infer/output/BindingDB/demo_set/{exp_tag}/boltz_results_{phase}_set/predictions'
        save_csv_name = f'rmsd_scores/{exp_tag}.csv'
        # Create the output folder up front so a missing directory cannot
        # discard the results after all the scoring has already been done.
        os.makedirs(os.path.dirname(save_csv_name), exist_ok=True)

        # Entries in the prediction directory are parsed as 'case_<ind>'.
        case_inds = [int(filename.split('_')[1]) for filename in os.listdir(pred_dir)]

        # One job per (predicted case, ground-truth CIF) pair.
        args = []
        for case_ind in case_inds:
            pred_cif_path = f'{pred_dir}/case_{case_ind}/case_{case_ind}_model_0.cif'
            cif_names, sim2train = case_ind2cif_names[case_ind]
            for cif_name in cif_names:
                gt_cif_path = f'{cif_dir}/{cif_name}.cif'
                args.append((case_ind, cif_name, sim2train, gt_cif_path, pred_cif_path))
        print('#args =', len(args))

        # imap yields results in submission order, so they can be zipped back
        # onto `args` below.
        with Pool(processes=112) as pool:
            results = list(tqdm(pool.imap(func, args), total=len(args), desc='Processing'))

        rows = [
            (case_ind, cif_name, sim2train, global_CA_rmsd, pocket_rmsd, ligand_rmsd)
            for (case_ind, cif_name, sim2train, _, _), (global_CA_rmsd, pocket_rmsd, ligand_rmsd)
            in zip(args, results)
        ]
        pd.DataFrame(rows, columns=columns).to_csv(save_csv_name, index=False)
# Run the evaluation only when executed as a script; the guard also keeps
# multiprocessing workers from re-running main() when they import this module.
if __name__ == '__main__':
    main()