How about RNA inverse folding?

Learning to design RNA

How can we improve an RNA inverse folding model?

This is a very challenging problem. (Don't ask; if you do ask, the honest answer is that I don't know much yet.)

Since the paper gRNAde: Geometric Deep Learning for 3D RNA Inverse Design lays out its details and reasoning very clearly, and the authors have released very concrete source code, I start analysing the problem from gRNAde and other state-of-the-art models. The authors' code is open-sourced at: https://github.com/chaitjo/geometric-rna-design

Server setup

A few tricks are worth noting along the way. The rented domestic server differs from the original UK environment, so use the data disk rather than the system disk as much as possible. The biggest pitfall in this project is the authors' environment-variable setup.

Python 3.10.12 and CUDA 11.8, numpy <2.0

  • Clone through a mirror site for faster downloads:
git clone https://ghfast.top/github.com/chaitjo/geometric-rna-design
cd /root/autodl-tmp/geometric-rna-design
  • Also, build the environment on the data disk:
mamba create -p /root/autodl-tmp/rna python=3.10

To activate the environment, note that the full prefix path is needed:

mamba activate /root/autodl-tmp/rna
  • Configure the environment variables for the rented server. The commands below only take effect in the current shell; it is better to write them into ~/.bashrc or to use the authors' .env file (see the short sketch after the export block), and remember to source ~/.bashrc.
export PROJECT_PATH='/root/autodl-tmp/geometric-rna-design/'
export ETERNAFOLD='/root/autodl-tmp/geometric-rna-design/tools/EternaFold'
export X3DNA='/root/autodl-tmp/geometric-rna-design/tools/x3dna-v2.4'
export PATH="/root/autodl-tmp/geometric-rna-design/tools/x3dna-v2.4/bin:$PATH"
export PATH="/root/autodl-tmp/cdhit:$PATH"
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
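
To make these variables persistent, one minimal sketch (assuming bash, and that the .env file at the repo root contains the same export lines) is to source it from ~/.bashrc:

# Persist the environment variables across shells (sketch; paths as above)
echo 'source /root/autodl-tmp/geometric-rna-design/.env' >> ~/.bashrc
source ~/.bashrc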

Correspondingly, the UK-server version goes into the .env file as follows.

export PROJECT_PATH='/home/remote1/geometric-rna-design/'
export DATA_PATH='/home/remote1/geometric-rna-design/data/'
export WANDB_PROJECT='rna'
export WANDB_ENTITY='wenxy59-sun-yat-sen-university'
export WANDB_DIR='/home/remote1/geometric-rna-design/'
export ETERNAFOLD='/home/remote1/geometric-rna-design/tools/EternaFold'
export X3DNA='/home/remote1/geometric-rna-design/tools/x3dna-v2.4'
export PATH="/home/remote1/geometric-rna-design/tools/x3dna-v2.4/bin:$PATH"
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

Notes on training the model

Data preprocessing

Note that before running the main script, the data must be processed to generate the process.pt file, along with the corresponding das_split.pt file. This step requires installing the USalign and qTMclust tools; in general, installing the latter automatically installs the former as well.

git clone https://github.com/pylelab/USalign.git
cd USalign
g++ -O3 -o qTMclust qTMclust.cpp -lm

After that, verify the installation with USalign -h and qTMclust -h, and double-check the paths. The relative paths used at the corresponding places in the training scripts may throw errors; remember to modify them, e.g. in src/data/clustering_units.py.
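
For example, a quick check could look like this (the build directory path is illustrative, not prescribed by the repo):

# Put the freshly built binaries on PATH (illustrative path) and verify
export PATH="/root/autodl-tmp/USalign:$PATH"
USalign -h | head -n 3
qTMclust -h | head -n 3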

Data processing completed successfully

Data processing turns the 14,000+ raw entries into 3,910 usable ones; incomplete raw data is screened out during processing, so the modest final count is expected.

Problematic data entries

Then generate the split file with the code in the notebook.
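
As a rough sketch of what that step produces (the exact keys and generation logic follow the repo's notebook; the index lists below are placeholders), the split file is essentially a serialized record of which processed entries go to train/validation/test:

# Hypothetical sketch: das_split.pt as serialized index lists over the processed dataset
import torch

train_idx, val_idx, test_idx = [0, 1, 2], [3], [4]  # placeholder indices
torch.save((train_idx, val_idx, test_idx), "data/das_split.pt")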

Using the model and running experiments

The authors provide a .py entry script, which can be launched from the command line as follows:

python gRNAde.py \
    --pdb_filepath data/raw/6J6G_1_L-E.pdb \
    --output_filepath tutorial/lnc/po/114.fasta \
    --split das \
    --max_num_conformers 1 \
    --n_samples 16 \
    --temperature 0.5

The result looks like this:

Structure and sequence prediction with gRNAde

On long RNA sequences (100+ nt) the model's performance degrades: recovery stays high, but the secondary-structure self-consistency score (SC score) shows a clear, roughly linear decline.

Visual comparison of the four metrics
SC score (left) and recovery (right) as a function of sequence length

Another generative paradigm: RiboDiffusion

For RiboDiffusion, the deployment is otherwise the same, but the environment is set up as follows:

conda create -n rna2 python=3.10 -y
conda activate rna2
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install absl-py==0.15.0
pip install biopython==1.80
pip install dm_tree==0.1.7
pip install fair-esm==2.0.0
pip install ml_collections==0.1.1
pip install numpy==1.24.3
pip install "scipy>=1.10.0"
pip install tqdm==4.64.1
pip install torch-cluster==1.6.1+pt113cu116 -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
pip install torch-scatter==2.1.1+pt113cu116 -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
pip install torch-geometric==2.3.1
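
A quick sanity check of the installation (a minimal sketch):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import torch_geometric, torch_scatter, torch_cluster; print('PyG extensions OK')"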

Running from the provided script, the output shows some issues, for example:

RiboDiffusion output
RiboDiffusion recovery as a function of sequence length, on a log scale (top) and a linear scale (bottom)

The original launch script is:

import torch
from tqdm import tqdm
import numpy as np
import random
from models import *
from utils import *
from diffusion import NoiseScheduleVP
from sampling import get_sampling_fn
from datasets import utils as du
import functools
import tree
from configs.inference_ribodiffusion import get_config
config = get_config()
# Choose checkpoint name
checkpoint_path = './ckpts/exp_inf.pth'
# checkpoint_path = './ckpts/exp_inf_large.pth'
config.eval.sampling_steps = 100
# config.eval.sampling_steps = 100
NUM_TO_LETTER = np.array(['A', 'G', 'C', 'U'])

def get_optimizer(config, params):
  """Return a flax optimizer object based on `config`."""
  if config.optim.optimizer == 'Adam':
      optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, weight_decay=config.optim.weight_decay)
  elif config.optim.optimizer == 'AdamW':
      optimizer = torch.optim.AdamW(params, lr=config.optim.lr, amsgrad=True, weight_decay=1e-12)
  else:
      raise NotImplementedError(f'Optimizer {config.optim.optimizer} not supported yet!')
  return optimizer
# Initialize model
model = create_model(config)
ema = ExponentialMovingAverage(model.parameters(), decay=config.model.ema_decay)
params = model.parameters()
optimizer = get_optimizer(config, model.parameters())
state = dict(optimizer=optimizer, model=model, ema=ema, step=0)

model_size = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
print('model size: {:.1f}MB'.format(model_size))

# Load checkpoint
state = restore_checkpoint(checkpoint_path, state, device=config.device)
ema.copy_to(model.parameters())

# Initialize noise scheduler
noise_scheduler = NoiseScheduleVP(config.sde.schedule, continuous_beta_0=config.sde.continuous_beta_0,
                                  continuous_beta_1=config.sde.continuous_beta_1)
# Obtain data scalar and inverse scalar
inverse_scaler = get_data_inverse_scaler(config)

# Setup sampling function
test_sampling_fn = get_sampling_fn(config, noise_scheduler, config.eval.sampling_steps, inverse_scaler)
pdb2data = functools.partial(du.PDBtoData, num_posenc=config.data.num_posenc, num_rbf=config.data.num_rbf, knn_num=config.data.knn_num)
# Run inference on a single PDB file
pdb_file= '/home/remote1/geometric-rna-design/data/raw/1FIR_1_A.pdb'
pdb_id = pdb_file.replace('.pdb', '')
if '/' in pdb_id:
    pdb_id = pdb_id.split('/')[-1]

config.eval.dynamic_threshold=True
config.eval.cond_scale=0.4
config.eval.n_samples=16
test_sampling_fn = get_sampling_fn(config, noise_scheduler, config.eval.sampling_steps, inverse_scaler)
struct_data = pdb2data(pdb_file)
struct_data = tree.map_structure(lambda x:x.unsqueeze(0).repeat_interleave(config.eval.n_samples, dim=0).to(config.device), struct_data)
samples = test_sampling_fn(model, struct_data)
print(f'PDB ID: {pdb_id}')
native_seq = ''.join(list(NUM_TO_LETTER[struct_data['seq'][0].cpu().numpy()]))
print(f'Native sequence: {native_seq}')
for i in range(len(samples)):
    # native_seq = ''.join(list(NUM_TO_LETTER[struct_data['seq'].squeeze(0).cpu().numpy()]))
    # print(f'Native sequence: {native_seq}')
    designed_seq = ''.join(list(NUM_TO_LETTER[samples[i].cpu().numpy()]))
    print(f'Generated sequence {i+1}: {designed_seq}')
    recovery_ = samples[i].eq(struct_data['seq'][0]).float().mean().item()
    print(f'Recovery rate {i+1}: {recovery_:.4f}')

The automated (batch) version is:

import torch
from tqdm import tqdm
import numpy as np
import random
import os
import sys
from datetime import datetime

# Add the gRNAde path to import evaluation tools
sys.path.append('/home/remote1/geometric-rna-design/src')
sys.path.append('/home/remote1/geometric-rna-design/')

from models import *
from utils import *
from diffusion import NoiseScheduleVP
from sampling import get_sampling_fn
from datasets import utils as du
import functools
import tree
from configs.inference_ribodiffusion import get_config

# Import gRNAde evaluation functions
from src.evaluator import (
    self_consistency_score_eternafold,
    edit_distance
)
from src.data.data_utils import pdb_to_tensor, get_c4p_coords
from src.constants import NUM_TO_LETTER, PROJECT_PATH

# Import BioPython for sequence handling
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

config = get_config()

# Choose checkpoint name
checkpoint_path = './ckpts/exp_inf.pth'
config.eval.sampling_steps = 100

def get_optimizer(config, params):
    """Return a flax optimizer object based on `config`."""
    if config.optim.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), 
                              eps=config.optim.eps, weight_decay=config.optim.weight_decay)
    elif config.optim.optimizer == 'AdamW':
        optimizer = torch.optim.AdamW(params, lr=config.optim.lr, amsgrad=True, weight_decay=1e-12)
    else:
        raise NotImplementedError(f'Optimizer {config.optim.optimizer} not supported yet!')
    return optimizer

def extract_sec_struct_from_pdb(pdb_file):
    """
    Extract secondary structure from PDB file using gRNAde's data utilities.
    """
    try:
        sequence, coords, sec_struct, sasa = pdb_to_tensor(
            pdb_file, 
            return_sec_struct=True, 
            return_sasa=True,
            keep_insertions=False
        )
        return [sec_struct] if sec_struct else None
    except Exception as e:
        print(f"Warning: Could not extract secondary structure from {pdb_file}: {e}")
        return None

def prepare_raw_data_for_grnade_eval(pdb_file, native_seq):
    """
    Prepare raw data structure compatible with gRNAde evaluator.
    """
    try:
        sequence, coords, sec_struct, sasa = pdb_to_tensor(
            pdb_file, 
            return_sec_struct=True, 
            return_sasa=True,
            keep_insertions=False
        )
        
        raw_data = {
            'sequence': native_seq,
            'coords_list': [coords] if coords is not None else [],
            'sec_struct_list': [sec_struct] if sec_struct else ["."] * len(native_seq),
            'sasa_list': [sasa] if sasa is not None else [np.ones(len(native_seq))]
        }
        
        return raw_data
    except Exception as e:
        print(f"Warning: Could not prepare raw data from {pdb_file}: {e}")
        # Return minimal raw data structure
        return {
            'sequence': native_seq,
            'coords_list': [],
            'sec_struct_list': ["."] * len(native_seq),
            'sasa_list': [np.ones(len(native_seq))]
        }

def evaluate_ribodiffusion_with_grnade_sc(samples, native_seq, pdb_file, pdb_id, 
                                         output_dir=None, save_results=True):
    """
    Evaluate RiboDiffusion samples using gRNAde's self-consistency evaluation.
    """
    # Prepare raw data for gRNAde evaluator
    raw_data = prepare_raw_data_for_grnade_eval(pdb_file, native_seq)
    
    # Convert RiboDiffusion samples to numpy arrays
    sample_arrays = []
    for sample in samples:
        if isinstance(sample, torch.Tensor):
            sample_arrays.append(sample.cpu().numpy())
        else:
            sample_arrays.append(np.array(sample))
    
    # Create mask for coordinates (assume all positions are valid for now)
    mask_coords = np.ones(len(native_seq), dtype=bool)
    
    # Calculate basic metrics
    results = {
        'pdb_id': pdb_id,
        'native_seq': native_seq,
        'samples': [],
        'recovery_rates': [],
        'edit_distances': [],
        'sc_scores_eternafold': []
    }
    
    # Convert native sequence to numerical for recovery calculation
    letter_to_num = {letter: idx for idx, letter in enumerate(NUM_TO_LETTER)}
    native_array = np.array([letter_to_num[char] for char in native_seq])
    
    print(f"\nEvaluating {len(samples)} samples for {pdb_id}:")
    
    for i, sample_array in enumerate(sample_arrays):
        designed_seq = ''.join([NUM_TO_LETTER[num] for num in sample_array])
        results['samples'].append(designed_seq)
        
        # Calculate recovery rate
        recovery = (sample_array == native_array).mean()
        results['recovery_rates'].append(recovery)
        
        # Calculate edit distance using gRNAde's function
        edit_dist = edit_distance(designed_seq, native_seq)
        results['edit_distances'].append(edit_dist)
        
        print(f'Sample {i+1}: Recovery={recovery:.4f}, Edit_dist={edit_dist}')
    
    # Calculate self-consistency scores using gRNAde's EternaFold evaluator
    print("\nCalculating self-consistency scores with EternaFold...")
    try:
        sc_scores = self_consistency_score_eternafold(
            sample_arrays,
            raw_data['sec_struct_list'],
            mask_coords
        )
        results['sc_scores_eternafold'] = sc_scores.tolist()
        
        for i, sc_score in enumerate(sc_scores):
            print(f'Sample {i+1}: SC_score={sc_score:.4f}')
            
    except Exception as e:
        print(f"Warning: Could not calculate SC scores: {e}")
        print("This might be due to EternaFold not being properly installed or configured.")
        results['sc_scores_eternafold'] = [0.0] * len(samples)
    
    # Save results
    if save_results and output_dir:
        os.makedirs(output_dir, exist_ok=True)
        
        # Save as FASTA file compatible with gRNAde format
        sequences = []
        
        # First record: input sequence with metadata
        sequences.append(SeqRecord(
            Seq(native_seq), 
            id="input_sequence,",
            description=f"pdb_id={pdb_id}, ribodiffusion_evaluation"
        ))
        
        # Remaining records: designed sequences with metrics
        for i, (seq, recovery, edit_dist, sc_score) in enumerate(zip(
            results['samples'], 
            results['recovery_rates'], 
            results['edit_distances'],
            results['sc_scores_eternafold']
        )):
            sequences.append(SeqRecord(
                Seq(seq), 
                id=f"sample={i},",
                description=f"recovery={recovery:.4f}, edit_dist={edit_dist}, sc_score={sc_score:.4f}"
            ))
        
        # Save FASTA
        fasta_path = os.path.join(output_dir, f"{pdb_id}_ribodiffusion_designs.fasta")
        SeqIO.write(sequences, fasta_path, "fasta")
        print(f"\nResults saved to: {fasta_path}")
    
    # Print summary statistics
    print(f"\n{'='*50}")
    print(f"Summary for {pdb_id}:")
    print(f"Native sequence length: {len(native_seq)}")
    print(f"Number of samples: {len(samples)}")
    print(f"Mean Recovery: {np.mean(results['recovery_rates']):.4f} ± {np.std(results['recovery_rates']):.4f}")
    print(f"Mean Edit Distance: {np.mean(results['edit_distances']):.2f} ± {np.std(results['edit_distances']):.2f}")
    
    if results['sc_scores_eternafold'][0] != 0.0:
        print(f"Mean SC Score (EternaFold): {np.mean(results['sc_scores_eternafold']):.4f} ± {np.std(results['sc_scores_eternafold']):.4f}")
    else:
        print("SC Scores not calculated (EternaFold unavailable)")
    print(f"{'='*50}")
    
    return results

# Initialize model
model = create_model(config)
ema = ExponentialMovingAverage(model.parameters(), decay=config.model.ema_decay)
params = model.parameters()
optimizer = get_optimizer(config, model.parameters())
state = dict(optimizer=optimizer, model=model, ema=ema, step=0)

model_size = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
print('Model size: {:.1f}MB'.format(model_size))

# Load checkpoint
state = restore_checkpoint(checkpoint_path, state, device=config.device)
ema.copy_to(model.parameters())

# Initialize noise scheduler
noise_scheduler = NoiseScheduleVP(config.sde.schedule, 
                                 continuous_beta_0=config.sde.continuous_beta_0,
                                 continuous_beta_1=config.sde.continuous_beta_1)

# Obtain data scalar and inverse scalar
inverse_scaler = get_data_inverse_scaler(config)

# Setup sampling function
test_sampling_fn = get_sampling_fn(config, noise_scheduler, config.eval.sampling_steps, inverse_scaler)
pdb2data = functools.partial(du.PDBtoData, num_posenc=config.data.num_posenc, 
                            num_rbf=config.data.num_rbf, knn_num=config.data.knn_num)

# Run inference
pdb_file = '/home/remote1/geometric-rna-design/data/raw/7PIC_1_5.pdb'
pdb_id = os.path.basename(pdb_file).replace('.pdb', '')

# Configure sampling
config.eval.dynamic_threshold = True
config.eval.cond_scale = 0.4
config.eval.n_samples = 16

# Generate samples
print(f'Processing PDB: {pdb_file}')
test_sampling_fn = get_sampling_fn(config, noise_scheduler, config.eval.sampling_steps, inverse_scaler)
struct_data = pdb2data(pdb_file)
struct_data = tree.map_structure(
    lambda x: x.unsqueeze(0).repeat_interleave(config.eval.n_samples, dim=0).to(config.device), 
    struct_data
)
samples = test_sampling_fn(model, struct_data)

# Get native sequence
native_seq = ''.join(list(NUM_TO_LETTER[struct_data['seq'][0].cpu().numpy()]))
print(f'Native sequence: {native_seq}')

# Create output directory
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"ribodiffusion_grnade_eval_{current_time}"

# Evaluate with gRNAde's self-consistency framework
results = evaluate_ribodiffusion_with_grnade_sc(
    samples=samples,
    native_seq=native_seq,
    pdb_file=pdb_file,
    pdb_id=pdb_id,
    output_dir=output_dir,
    save_results=True
)

Tests were run on 13,000 data entries:

Automated testing
Test results, raw and log-transformed

A first improvement: attention

Since a multi-head attention mechanism should help capture long sequences, I started by adding a simple MultiheadAttention layer and retraining. The released code also has bugs: in particular, the mask_coords masking logic in this function is problematic and often throws dimension-mismatch errors against the samples. My modified evaluator.py is pasted below; to keep GPU memory from blowing up, the predictions also need to be processed in small batches.

import os
import copy
import shutil
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb

import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_matthews_corrcoef

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

from MDAnalysis.analysis.align import rotation_matrix
from MDAnalysis.analysis.rms import rmsd as get_rmsd

from src.data.data_utils import pdb_to_tensor, get_c4p_coords
from src.data.sec_struct_utils import (
    predict_sec_struct,
    dotbracket_to_paired,
    dotbracket_to_adjacency
)
from src.constants import (
    NUM_TO_LETTER, 
    PROJECT_PATH,
    RMSD_THRESHOLD,
    TM_THRESHOLD,
    GDT_THRESHOLD
)


def evaluate(
        model, 
        dataset, 
        n_samples, 
        temperature, 
        device, 
        model_name="eval",
        metrics=[
            'recovery', 'perplexity', 'sc_score_eternafold', 
            'sc_score_ribonanzanet', 'sc_score_rhofold'
        ],
        save_designs=False
    ):
    """
    Run evaluation suite for trained RNA inverse folding model on a dataset.

    The following metrics can be computed along with metadata per sample per residue:
    1. (recovery) Sequence recovery per residue (taking mean gives per sample recovery)
    2. (perplexity) Perplexity per sample
    3. (sc_score_eternafold) Secondary structure self-consistency score per sample, 
        using EternaFold for secondary structure prediction and computing MCC between
        the predicted and groundtruth 2D structures as adjacency matrices.
    4. (sc_score_ribonanzanet) Chemical modification self-consistency score per sample,
        using RibonanzaNet for chemical modification prediction of the groundtruth and
        designed sequences, and measuring MAE between them.
    5. (sc_score_rhofold) Tertiary structure self-consistency scores per sample,
        using RhoFold for tertiary structure prediction and measuring RMSD, TM-score,
        and GDT_TS between the predicted and groundtruth C4' 3D coordinates.
    6. (rmsd_within_thresh) Percentage of samples with RMSD within threshold (<=2.0A)
    7. (tm_within_thresh) Percentage of samples with TM-score within threshold (>=0.45)
    8. (gddt_within_thresh) Percentage of samples with GDT_TS within threshold (>=0.50)

    Args:
        model: trained RNA inverse folding model
        dataset: dataset to evaluate on
        n_samples: number of predicted samples/sequences per data point 
        temperature: sampling temperature
        device: device to run evaluation on
        model_name: name of model/dataset for plotting (default: 'eval')
        metrics: list of metrics to compute
        save_designs: whether to save designs as fasta with metrics
    
    Returns: Dictionary with the following keys:
        df: DataFrame with metrics and metadata per residue per sample for analysis and plotting
        samples_list: list of tensors of shape (n_samples, seq_len) per data point 
        recovery_list: list of mean recovery per data point
        perplexity_list: list of mean perplexity per data point
        sc_score_eternafold_list: list of 2D self-consistency scores per data point
        sc_score_ribonanzanet_list: list of 1D self-consistency scores per data point
        sc_score_rmsd_list: list of 3D self-consistency RMSDs per data point
        sc_score_tm_list: list of 3D self-consistency TM-scores per data point
        sc_score_gddt_list: list of 3D self-consistency GDTs per data point
        rmsd_within_thresh_list: list of % scRMSDs within threshold per data point
        tm_within_thresh_list: list of % scTMs within threshold per data point
        gddt_within_thresh_list: list of % scGDDTs within threshold per data point
    """
    assert 'recovery' in metrics, 'Sequence recovery must be computed for evaluation'

    #######################################################################
    # Optionally initialise other models used for self-consistency scoring
    #######################################################################

    if 'sc_score_ribonanzanet' in metrics:
        from tools.ribonanzanet.network import RibonanzaNet
        
        # Initialise RibonanzaNet for self-consistency score
        ribonanza_net = RibonanzaNet(
            os.path.join(PROJECT_PATH, 'tools/ribonanzanet/config.yaml'),
            os.path.join(PROJECT_PATH, 'tools/ribonanzanet/ribonanzanet.pt'),
            device
        )
        # Transfer model to device in eval mode
        ribonanza_net = ribonanza_net.to(device)
        ribonanza_net.eval()
    
    if 'sc_score_rhofold' in metrics:
        from tools.rhofold.rf import RhoFold
        from tools.rhofold.config import rhofold_config
        
        # Initialise RhoFold for 3D self-consistency score
        rhofold = RhoFold(rhofold_config, device)
        rhofold_path = os.path.join(PROJECT_PATH, "tools/rhofold/model_20221010_params.pt")
        print(f"Loading RhoFold checkpoint: {rhofold_path}")
        rhofold.load_state_dict(torch.load(rhofold_path, map_location=torch.device('cpu'))['model'])
        # Transfer model to device in eval mode
        rhofold = rhofold.to(device)
        rhofold.eval()
        current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")

    ####################################################
    # Evaluation loop over each data point sequentially
    ####################################################

    # per sample metric lists for storing evaluation results
    samples_list = []               # list of tensors of shape (n_samples, seq_len) per data point 
    recovery_list = []              # list of mean recovery per data point
    perplexity_list = []            # list of mean perplexity per data point
    sc_score_ribonanzanet_list = [] # list of 1D self-consistency scores per data point
    sc_score_eternafold_list = []   # list of 2D self-consistency scores per data point
    sc_score_rmsd_list = []         # list of 3D self-consistency RMSDs per data point
    rmsd_within_thresh_list = []    # list of % scRMSDs within threshold per data point
    sc_score_tm_list = []           # list of 3D self-consistency TM-scores per data point
    tm_within_thresh_list = []      # list of % scTMs within threshold per data point
    sc_score_gddt_list = []         # list of 3D self-consistency GDTs per data point
    gddt_within_thresh_list = []    # list of % scGDDTs within threshold per data point

    # DataFrame to store metrics and metadata per residue per sample for analysis and plotting
    df = pd.DataFrame(columns=['idx', 'recovery', 'sasa', 'paired', 'rmsds', 'model_name'])

    model.eval()
    if device.type == 'xpu':
        import intel_extension_for_pytorch as ipex
        model = ipex.optimize(model)
        if 'sc_score_ribonanzanet' in metrics:
            ribonanza_net = ipex.optimize(ribonanza_net)
        if 'sc_score_rhofold' in metrics:
            rhofold = ipex.optimize(rhofold)
    
    with torch.no_grad():
        for idx, raw_data in tqdm(
            enumerate(dataset.data_list),
            total=len(dataset.data_list)
        ):
            # featurise raw data
            data = dataset.featurizer(raw_data).to(device)

            # sample n_samples from model for single data point: n_samples x seq_len
            samples, logits = model.sample(data, n_samples, temperature, return_logits=True)
            samples_list.append(samples.cpu().numpy())
            
            # perplexity per sample: n_samples x 1
            n_nodes = logits.shape[1]
            perplexity = torch.exp(F.cross_entropy(
                logits.view(n_samples * n_nodes, model.out_dim), 
                samples.view(n_samples * n_nodes).long(), 
                reduction="none"
            ).view(n_samples, n_nodes).mean(dim=1)).cpu().numpy()
            perplexity_list.append(perplexity.mean())

            ###########
            # Metadata
            ###########

            # per residue average SASA: seq_len x 1
            mask_coords = data.mask_coords.cpu().numpy()
            sasa = np.mean(raw_data['sasa_list'], axis=0)[mask_coords]

            # per residue indicator for paired/unpaired: seq_len x 1
            paired = np.mean(
                [dotbracket_to_paired(sec_struct) for sec_struct in raw_data['sec_struct_list']], axis=0
            )[mask_coords]

            # per residue average RMSD: seq_len x 1
            if len(raw_data["coords_list"]) == 1:
                rmsds = np.zeros_like(sasa)
            else:
                rmsds = []
                for i in range(len(raw_data["coords_list"])):
                    for j in range(i+1, len(raw_data["coords_list"])):
                        coords_i = get_c4p_coords(raw_data["coords_list"][i])
                        coords_j = get_c4p_coords(raw_data["coords_list"][j])
                        rmsds.append(torch.sqrt(torch.sum((coords_i - coords_j)**2, dim=1)).cpu().numpy())
                rmsds = np.stack(rmsds).mean(axis=0)[mask_coords]

            ##########
            # Metrics
            ##########

            # sequence recovery per residue across all samples: n_samples x seq_len 
            recovery = samples.eq(data.seq).float().cpu().numpy()
            recovery_list.append(recovery.mean())

            # update per residue per sample dataframe
            df = pd.concat([
                df, 
                pd.DataFrame({
                    'idx': [idx] * len(recovery.mean(axis=0)),
                    'recovery': recovery.mean(axis=0),
                    'sasa': sasa,
                    'paired': paired,
                    'rmsds': rmsds,
                    'model_name': [model_name] * len(recovery.mean(axis=0))
                })
            ], ignore_index=True)

            # global 2D self consistency score per sample: n_samples x 1
            if 'sc_score_eternafold' in metrics:
                sc_score_eternafold, pred_sec_structs = self_consistency_score_eternafold(
                    samples.cpu().numpy(), 
                    raw_data['sec_struct_list'], 
                    mask_coords,
                    return_sec_structs = True
                )
                sc_score_eternafold_list.append(sc_score_eternafold.mean())

            # global 1D self consistency score per sample: n_samples x 1
            if 'sc_score_ribonanzanet' in metrics:
                sc_score_ribonanzanet, pred_chem_mods = self_consistency_score_ribonanzanet(
                    samples.cpu().numpy(), 
                    raw_data['sequence'],
                    mask_coords, 
                    ribonanza_net,
                    return_chem_mods = True
                )
                sc_score_ribonanzanet_list.append(sc_score_ribonanzanet.mean())
            
            # global 3D self consistency scores per sample: n_samples x 1, each
            if 'sc_score_rhofold' in metrics:
                try:
                    output_dir = os.path.join(
                        wandb.run.dir, f"designs_{model_name}/{current_datetime}/sample{idx}/")
                except AttributeError:
                    output_dir = os.path.join(
                        PROJECT_PATH, f"designs_{model_name}/{current_datetime}/sample{idx}/")

                sc_score_rmsd, sc_score_tm, sc_score_gdt = self_consistency_score_rhofold(
                    samples.cpu().numpy(), 
                    raw_data,
                    mask_coords,
                    rhofold,
                    output_dir,
                    save_designs = save_designs
                )
                sc_score_rmsd_list.append(sc_score_rmsd.mean())
                sc_score_tm_list.append(sc_score_tm.mean())
                sc_score_gddt_list.append(sc_score_gdt.mean())

                rmsd_within_thresh_list.append((sc_score_rmsd <= RMSD_THRESHOLD).sum() / n_samples)
                tm_within_thresh_list.append((sc_score_tm >= TM_THRESHOLD).sum() / n_samples)
                gddt_within_thresh_list.append((sc_score_gdt >= GDT_THRESHOLD).sum() / n_samples)

                if save_designs:
                    # collate designed sequences in fasta format
                    sequences = [SeqRecord(
                        Seq(raw_data["sequence"]), id=f"input_sequence,", 
                        description=f"pdb_id={raw_data['id_list'][0]} rfam={raw_data['rfam_list'][0]} eq_class={raw_data['eq_class_list'][0]} cluster={raw_data['cluster_structsim0.45']}"
                    )]
                    for idx, zipped in enumerate(zip(
                        samples.cpu().numpy(),
                        perplexity,
                        recovery.mean(axis=1),
                        sc_score_eternafold,
                        pred_sec_structs,
                        sc_score_ribonanzanet,
                        pred_chem_mods,
                        sc_score_rmsd,
                        sc_score_tm,
                        sc_score_gdt
                    )):
                        seq, perp, rec, sc, pred_ss, sc_ribo, pred_cm, sc_rmsd, sc_tm, sc_gdt = zipped
                        seq = "".join([NUM_TO_LETTER[num] for num in seq])
                        edit_dist = edit_distance(seq, raw_data['sequence'])
                        sequences.append(SeqRecord(
                            Seq(seq), id=f"sample={idx},",
                            description=f"temperature={temperature} perplexity={perp:.4f} recovery={rec:.4f} edit_dist={edit_dist} sc_score={sc:.4f} sc_score_ribonanzanet={sc_ribo:.4f} sc_score_rmsd={sc_rmsd:.4f} sc_score_tm={sc_tm:.4f} sc_score_gdt={sc_gdt:.4f}"
                        ))
                    # write all designed sequences to output filepath
                    SeqIO.write(sequences, os.path.join(output_dir, "all_designs.fasta"), "fasta")

    out = {
        'df': df,
        'samples_list': samples_list,
        'recovery_list': recovery_list,
        'perplexity_list': perplexity_list
    }
    if 'sc_score_eternafold' in metrics:
        out['sc_score_eternafold'] = sc_score_eternafold_list
    if 'sc_score_ribonanzanet' in metrics:
        out['sc_score_ribonanzanet'] = sc_score_ribonanzanet_list
    if 'sc_score_rhofold' in metrics:
        out['sc_score_rmsd'] = sc_score_rmsd_list
        out['sc_score_tm'] = sc_score_tm_list
        out['sc_score_gddt'] = sc_score_gddt_list
        out['rmsd_within_thresh'] = rmsd_within_thresh_list
        out['tm_within_thresh'] = tm_within_thresh_list
        out['gddt_within_thresh'] = gddt_within_thresh_list
    # =========================================================
    # Free the self-consistency models and GPU memory after evaluation
    import gc

    if 'ribonanza_net' in locals():
        del ribonanza_net
    if 'rhofold' in locals():
        del rhofold

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # =========================================================
    return out


def self_consistency_score_eternafold(
        samples, 
        true_sec_struct_list, 
        mask_coords,
        n_samples_ss = 1,
        num_to_letter = NUM_TO_LETTER,
        return_sec_structs = False
    ):
    """
    Compute self consistency score for an RNA, given its true secondary structure(s)
    and a list of designed sequences. 
    EternaFold is used to 'forward fold' the designs.
    
    Args:
        samples: designed sequences of shape (n_samples, seq_len)
        true_sec_struct_list: list of true secondary structures (n_true_ss, seq_len)
        mask_coords: mask for missing sequence coordinates to be ignored during evaluation
        n_samples_ss: number of predicted secondary structures per designed sample
        num_to_letter: lookup table mapping integers to nucleotides
        return_sec_structs: whether to return the predicted secondary structures
    
    Workflow:
        
        Input: For a given RNA molecule, we are given:
        - Designed sequences of shape (n_samples, seq_len)
        - True secondary structure(s) of shape (n_true_ss, seq_len)
        
        For each designed sequence:
        - Predict n_sample_ss secondary structures using EternaFold
        - For each pair of true and predicted secondary structures:
            - Compute MCC score between their adjacency matrix representations
        - Take the average MCC score across all n_sample_ss predicted structures
        
        Take the average MCC score across all n_samples designed sequences
    """
    
    n_true_ss = len(true_sec_struct_list)
    sequence_length = mask_coords.sum()
    # map all entries from dotbracket to numerical representation
    true_sec_struct_list = np.array([dotbracket_to_adjacency(ss) for ss in true_sec_struct_list])
    # mask out missing sequence coordinates
    true_sec_struct_list = true_sec_struct_list[:, mask_coords][:, :, mask_coords]
    # reshape to (n_true_ss * n_samples_ss, seq_len, seq_len)
    true_sec_struct_list = torch.tensor(
        true_sec_struct_list
    ).unsqueeze(1).repeat(1, n_samples_ss, 1, 1).reshape(-1, sequence_length, sequence_length)

    mcc_scores = []
    pred_sec_structs = []
    for _sample in samples:
        # convert sample to string
        pred_seq = ''.join([num_to_letter[num] for num in _sample])
        # predict secondary structure(s) for each sample
        pred_sec_struct_list = predict_sec_struct(pred_seq, n_samples=n_samples_ss)
        if return_sec_structs:
            pred_sec_structs.append(copy.copy(pred_sec_struct_list))
        # map all entries from dotbracket to numerical representation
        pred_sec_struct_list = np.array([dotbracket_to_adjacency(ss) for ss in pred_sec_struct_list])
        # reshape to (n_samples_ss * n_true_ss, seq_len, seq_len)
        pred_sec_struct_list = torch.tensor(
            pred_sec_struct_list
        ).unsqueeze(0).repeat(n_true_ss, 1, 1, 1).reshape(-1, sequence_length, sequence_length)

        # compute mean MCC score between pairs of true and predicted secondary structures
        mcc_scores.append(
            binary_matthews_corrcoef(
                pred_sec_struct_list,
                true_sec_struct_list,
            ).float().mean()
        )

    if return_sec_structs:
        return np.array(mcc_scores), pred_sec_structs
    else:
        return np.array(mcc_scores)


def self_consistency_score_ribonanzanet(
    samples,
    true_sequence,
    mask_seq,
    ribonanza_net,
    num_to_letter=NUM_TO_LETTER,
    return_chem_mods=False,
):
    """Compute self consistency score for an RNA, given the (predicted) chemical modifications for
    the original RNA and a list of designed sequences. RibonanzaNet is used to 'forward fold' the
    designs.

    Args:
        samples: designed sequences of shape (n_samples, seq_len)
        true_sequence: true RNA sequence used to predict chemical modifications
        mask_seq: mask for missing sequence coordinates to be ignored during evaluation
        ribonanza_net: RibonanzaNet model
        num_to_letter: lookup table mapping integers to nucleotides
        return_chem_mods: whether to return the predicted chemical modifications

    Workflow:

        Input: For a given RNA molecule, we are given:
        - Designed sequences of shape (n_samples, seq_len)
        - Predicted chemical modifications for original sequence,
          of shape (n_samples, seq_len, 2), predicted via RibonanzaNet, of which we take
          the index 0 from the last channal --> 2A3/SHAPE.

        For each designed sequence:
        - Predict chemical modifications using RibonanzaNet
        - Compute mean absolute error between prediction and chemical modifications for
          the original sequence

        Take the average mean absolute error across all n_samples designed sequences
    """
    # Compute original sequence's chemical modifications using RibonanzaNet
    true_sequence = np.array([char for char in true_sequence])
    true_sequence = "".join(true_sequence[mask_seq])
    true_chem_mod = ribonanza_net.predict(true_sequence).unsqueeze(0).cpu().numpy()[:,:,0]

    # 2. [Key modification] Process the model-generated samples in small batches
    _samples_char = np.array([[num_to_letter[num] for num in seq] for seq in samples])

    batch_size = 1  # small batch size; adjust to available GPU memory (e.g. 4, 8, 16)
    all_preds = []  # collect predictions from every batch

    for i in range(0, len(_samples_char), batch_size):
        # take one mini-batch
        batch_samples = _samples_char[i:i+batch_size]
        # predict on this mini-batch
        batch_pred = ribonanza_net.predict(batch_samples).cpu().numpy()[:,:,0]
        all_preds.append(batch_pred)

    # concatenate the results from all batches
    pred_chem_mod = np.concatenate(all_preds, axis=0)
   # _samples_char = np.array([[num_to_letter[num] for num in seq] for seq in samples])
   # _samples = np.array([[num_to_letter[num] for num in seq] for seq in samples])
   #pred_chem_mod = ribonanza_net.predict(_samples_char).cpu().numpy()[:,:,0]
   ##pred_chem_mod = ribonanza_net.predict(_samples[:, mask_seq]).cpu().numpy()[:,:,0]
    if return_chem_mods:
        return (np.abs(pred_chem_mod - true_chem_mod).mean(1)), pred_chem_mod
    else:
        return np.abs(pred_chem_mod - true_chem_mod).mean(1)


def self_consistency_score_ribonanzanet_sec_struct(
        samples, 
        true_sec_struct, 
        mask_coords, 
        ribonanza_net_ss,
        num_to_letter = NUM_TO_LETTER,
        return_sec_structs = False
    ):
    # map from dotbracket to numerical representation
    true_sec_struct = np.array(dotbracket_to_adjacency(true_sec_struct, keep_pseudoknots=True))
    # mask out missing sequence coordinates
    true_sec_struct = true_sec_struct[mask_coords][:, mask_coords]
    # (n_samples, seq_len, seq_len)
    true_sec_struct = torch.tensor(true_sec_struct)

    _samples = np.array([[num_to_letter[num] for num in seq] for seq in samples])
    _, pred_sec_structs = ribonanza_net_ss.predict(_samples)  # (n_samples, seq_len, seq_len)
    
    mcc_scores = []
    for pred_sec_struct in pred_sec_structs:
        # map from dotbracket to numerical representation
        pred_sec_struct = torch.tensor(dotbracket_to_adjacency(pred_sec_struct, keep_pseudoknots=True))
        # compute mean MCC score between pairs of true and predicted secondary structures
        mcc_scores.append(
            binary_matthews_corrcoef(
                pred_sec_struct,
                true_sec_struct,
            ).float().mean()
        )

    if return_sec_structs:
        return np.array(mcc_scores), pred_sec_structs
    else:
        return np.array(mcc_scores)


def self_consistency_score_rhofold(
        samples,
        true_raw_data,
        mask_coords,
        rhofold,
        output_dir,
        num_to_letter = NUM_TO_LETTER,
        save_designs = False,
        save_pdbs = False,
        use_relax = False,
    ):
    """
    Compute self consistency score for an RNA, given its true 3D structure(s)
    for the original RNA and a list of designed sequences.
    RhoFold is used to 'forward fold' the designs.

    Credit: adapted from Rishabh Anand

    Args:
        samples: designed sequences of shape (n_samples, seq_len)
        true_raw_data: Original RNA raw data with 3D structure(s) in `coords_list`
        mask_coords: mask for missing sequence coordinates to be ignored during evaluation
        rhofold: RhoFold model
        output_dir: directory to save designed sequences and structures
        num_to_letter: lookup table mapping integers to nucleotides
        save_designs: whether to save designs as fasta to output directory
        save_pdbs: whether to save PDBs of forward-folded designs to output directory
        use_relax: whether to perform Amber relaxation on designed structures

    Workflow:
            
        Input: For a given RNA molecule, we are given:
        - Designed sequences of shape (n_samples, seq_len)
        - True 3D structure(s) of shape (n_true_structs, seq_len, 3)
        
        For each designed sequence:
        - Predict the tertiary structure using RhoFold
        - For each pair of true and predicted 3D structures:
            - Compute RMSD, TM-score & GDT between their C4' coordinates
        
        Take the average self-consistency scores across all n_samples designed sequences

    Returns:
        sc_rmsds: array of RMSD scores per sample
        sc_tms: array of TM-score scores per sample
        sc_gddts: array of GDT scores per sample
    """
    os.makedirs(output_dir, exist_ok=True)

    # Collate designed sequences in fasta format
    # first record: input sequence and model metadata
    input_seq = SeqRecord(
        Seq(true_raw_data["sequence"]),
        id=f"input_sequence,",
        description=f"input_sequence"
    )
    # SeqIO.write(input_seq, os.path.join(output_dir, "input_seq.fasta"), "fasta")
    sequences = [input_seq]
    
    # remaining records: designed sequences and metrics
    sc_rmsds = []
    sc_tms = []
    sc_gddts = []
    for idx, seq in enumerate(samples):
        # Save designed sequence to fasta file (temporary)
        seq = SeqRecord(
            Seq("".join([num_to_letter[num] for num in seq])), 
            id=f"sample={idx},",
            description=f"sample={idx}"
        )
        sequences.append(seq)
        design_fasta_path = os.path.join(output_dir, f"design{idx}.fasta")
        SeqIO.write(seq, design_fasta_path, "fasta")
        
        # Forward fold designed sequence using RhoFold
        design_pdb_path = os.path.join(output_dir, f"design{idx}.pdb")
        rhofold.predict(design_fasta_path, design_pdb_path, use_relax)
        
        # Load C4' coordinates of designed structure
        _, coords, _, _ = pdb_to_tensor(
            design_pdb_path,
            return_sec_struct=False,
            return_sasa=False,
            keep_insertions=False,
        )
        coords = get_c4p_coords(coords)
        # zero-center coordinates
        coords = coords - coords.mean(dim=0)

        # Compute self-consistency between designed and groundtruth structures
        _sc_rmsds = []
        _sc_tms = []
        _sc_gddts = []
        for other_coords in true_raw_data["coords_list"]:
            _other = get_c4p_coords(other_coords)[mask_coords, :]
            # zero-center other coordinates
            _other = _other - _other.mean(dim=0)
            # globally align coordinates
            R_hat = rotation_matrix(
                _other,  # mobile set
                coords # reference set
            )[0]
            _other = _other @ R_hat.T
            # compute metrics
            _sc_rmsds.append(get_rmsd(
                coords, _other, superposition=True, center=True))
            _sc_tms.append(get_tmscore(coords, _other))
            _sc_gddts.append(get_gddt(coords, _other))

        sc_rmsds.append(np.mean(_sc_rmsds))
        sc_tms.append(np.mean(_sc_tms))
        sc_gddts.append(np.mean(_sc_gddts))

        # remove temporary files
        os.unlink(design_fasta_path)
        if save_pdbs is False:
            os.unlink(design_pdb_path)
    
    if save_designs is False:
        # remove output directory        
        shutil.rmtree(output_dir)
    else:
        # write all designed sequences to output filepath
        SeqIO.write(sequences, os.path.join(output_dir, "all_designs.fasta"), "fasta")

    return np.array(sc_rmsds), np.array(sc_tms), np.array(sc_gddts)


def get_tmscore(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Template Modelling score (TM-score). 
    
    Credit: Arian Jamasb, graphein (https://github.com/a-r-j/graphein)

    https://en.wikipedia.org/wiki/Template_modeling_score

    TM-score is a measure of similarity between two protein structures.
    The TM-score is intended as a more accurate measure of the global
    similarity of full-length protein structures than the often used RMSD
    measure. The TM-score indicates the similarity between two structures
    by a score between ``[0, 1]``, where 1 indicates a perfect match
    between two structures (thus the higher the better). Generally scores
    below 0.20 corresponds to randomly chosen unrelated proteins whereas
    structures with a score higher than 0.5 assume roughly the same fold.
    A quantitative study shows that proteins of TM-score = 0.5 have a
    posterior probability of 37% in the same CATH topology family and of
    13% in the same SCOP fold family. The probabilities increase rapidly
    when TM-score > 0.5. The TM-score is designed to be independent of
    protein lengths.
    
    We have adapted the implementation to RNA (TM-score threshold = 0.45).
    Requires aligned C4' coordinates as input.
    """
    l_target = y.shape[0]
    d0_l_target = 1.24 * np.power(l_target - 15, 1 / 3) - 1.8
    di = torch.pairwise_distance(y_hat, y)
    out = torch.sum(1 / (1 + (di / d0_l_target) ** 2)) / l_target
    if torch.isnan(out):
        return torch.tensor(0.0)
    return out


def get_gddt(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Global Distance Deviation Test metric (GDDT).

    Credit: Arian Jamasb, graphein (https://github.com/a-r-j/graphein)

    https://en.wikipedia.org/wiki/Global_distance_test

    The GDT score is calculated as the largest set of amino acid residues'
    alpha carbon atoms in the model structure falling within a defined
    distance cutoff of their position in the experimental structure, after
    iteratively superimposing the two structures. By the original design the
    GDT algorithm calculates 20 GDT scores, i.e. for each of 20 consecutive distance
    cutoffs (``0.5 Å, 1.0 Å, 1.5 Å, ... 10.0 Å``). For structure similarity assessment
    it is intended to use the GDT scores from several cutoff distances, and scores
    generally increase with increasing cutoff. A plateau in this increase may
    indicate an extreme divergence between the experimental and predicted structures,
    such that no additional atoms are included in any cutoff of a reasonable distance.
    The conventional GDT_TS total score in CASP is the average result of cutoffs at
    ``1``, ``2``, ``4``, and ``8`` Å.

    Random predictions give around 20; getting the gross topology right gets one to ~50; 
    accurate topology is usually around 70; and when all the little bits and pieces, 
    including side-chain conformations, are correct, GDT_TS begins to climb above 90.

    We have adapted the implementation to RNA.
    Requires aligned C4' coordinates as input.
    """
    # Get distance between points
    dist = torch.norm(y - y_hat, dim=1)

    # Return mean fraction of distances below cutoff for each cutoff (1, 2, 4, 8)
    count_1 = (dist < 1).sum() / dist.numel()
    count_2 = (dist < 2).sum() / dist.numel()
    count_4 = (dist < 4).sum() / dist.numel()
    count_8 = (dist < 8).sum() / dist.numel()
    out = torch.mean(torch.tensor([count_1, count_2, count_4, count_8]))
    if torch.isnan(out):
        return torch.tensor(0.0)
    return out


def edit_distance(s: str, t: str) -> int:
    """
    A Space efficient Dynamic Programming based Python3 program 
    to find minimum number operations to convert str1 to str2

    Source: https://www.geeksforgeeks.org/edit-distance-dp-5/
    """
    n = len(s)
    m = len(t)

    prev = [j for j in range(m+1)]
    curr = [0] * (m+1)

    for i in range(1, n+1):
        curr[0] = i
        for j in range(1, m+1):
            if s[i-1] == t[j-1]:
                curr[j] = prev[j-1]
            else:
                mn = min(1 + prev[j], 1 + curr[j-1])
                curr[j] = min(mn, 1 + prev[j-1])
        prev = curr.copy()

    return prev[m]
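
For reference, a hypothetical call to this modified evaluator (assuming a trained gRNAde model and a dataset object exposing data_list and a featurizer, as in the repo's training code) might look like:

# Hypothetical usage sketch; variable names are illustrative
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
out = evaluate(
    model, val_dataset, n_samples=16, temperature=0.1, device=device,
    model_name="val", metrics=['recovery', 'perplexity', 'sc_score_eternafold'],
    save_designs=False,
)
print(out['df'].groupby('idx')['recovery'].mean())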

Besides that, a simple attention layer has to be added; the models.py code is as follows.

################################################################
# Generalisation of Geometric Vector Perceptron, Jing et al.
# for explicit multi-state biomolecule representation learning.
# Original repository: https://github.com/drorlab/gvp-pytorch
################################################################

from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical
import torch_geometric

from src.layers import *


class AutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.

    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Simple self-attention on pooled scalar node features
        self.attn = nn.MultiheadAttention(
            embed_dim=self.node_h_dim[0], num_heads=4, dropout=drop_rate, batch_first=True
        )
        self.attn_ln = nn.LayerNorm(self.node_h_dim[0])
        
        # Decoder layers
        self.W_s = nn.Embedding(self.out_dim, self.out_dim)
        self.edge_h_dim = (self.edge_h_dim[0] + self.out_dim, self.edge_h_dim[1])
        self.decoder_layers = nn.ModuleList(
                GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                             activations=activations, vector_gate=True, 
                             drop_rate=drop_rate, autoregressive=True, norm_first=True) 
            for _ in range(num_layers))
        
        # Output
        self.W_out = GVP(self.node_h_dim, (self.out_dim, 0), activations=(None, None))
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq

        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)

        # Apply simple self-attention over nodes (sequence length = n_nodes)
        x = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.attn_ln(x + attn_out)
        h_V = (x.squeeze(0), h_V[1])

        encoder_embeddings = h_V
        
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
        
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
        
        logits = self.W_out(h_V)
        
        return logits
    
    @torch.no_grad()
    def sample(
            self, 
            batch, 
            n_samples, 
            temperature: Optional[float] = 0.1, 
            logit_bias: Optional[torch.Tensor] = None,
            return_logits: Optional[bool] = False
        ):
        '''
        Samples sequences autoregressively from the distribution
        learned by the model.

        Args:
            batch (torch_geometric.data.Data): mini-batch containing one
                RNA backbone to design sequences for
            n_samples (int): number of samples
            temperature (float): temperature to use in softmax over 
                the categorical distribution
            logit_bias (torch.Tensor): bias to add to logits during sampling
                to manually fix or control nucleotides in designed sequences,
                of shape [n_nodes, 4]
            return_logits (bool): whether to return logits or not
        
        Returns:
            seq (torch.Tensor): int tensor of shape [n_samples, n_nodes]
                                based on the residue-to-int mapping of
                                the original training data
            logits (torch.Tensor): logits of shape [n_samples, n_nodes, 4]
                                   (only if return_logits is True)
        ''' 
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
    
        device = edge_index.device
        num_nodes = h_V[0].shape[0]
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
        
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        
        # Pool multi-conformation features
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        
        # Apply simple self-attention over nodes (sequence length = n_nodes)
        x = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.attn_ln(x + attn_out)
        h_V = (x.squeeze(0), h_V[1])
        
        # Repeat features for sampling n_samples times
        h_V = (h_V[0].repeat(n_samples, 1),
            h_V[1].repeat(n_samples, 1, 1))
        h_E = (h_E[0].repeat(n_samples, 1),
            h_E[1].repeat(n_samples, 1, 1))
        
        # Expand edge index for autoregressive decoding
        edge_index = edge_index.expand(n_samples, -1, -1)
        offset = num_nodes * torch.arange(n_samples, device=device).view(-1, 1, 1)
        edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
        # This is akin to 'batching' (in PyG style) n_samples copies of the graph
        
        seq = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.int)
        h_S = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)
        logits = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)

        h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]

        # Decode one token at a time
        for i in range(num_nodes):
            
            h_S_ = h_S[edge_index[0]]
            h_S_[edge_index[0] >= edge_index[1]] = 0
            h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
                    
            edge_mask = edge_index[1] % num_nodes == i  # True for all edges where dst is node i
            edge_index_ = edge_index[:, edge_mask]  # subset all incoming edges to node i
            h_E_ = tuple_index(h_E_, edge_mask)
            node_mask = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.bool)
            node_mask[i::num_nodes] = True  # True for all nodes i and its repeats
            
            for j, layer in enumerate(self.decoder_layers):
                out = layer(h_V_cache[j], edge_index_, h_E_,
                        autoregressive_x=h_V_cache[0], node_mask=node_mask)
                
                out = tuple_index(out, node_mask)  # subset out to only node i and its repeats
                
                if j < len(self.decoder_layers)-1:
                    h_V_cache[j+1][0][i::num_nodes] = out[0]
                    h_V_cache[j+1][1][i::num_nodes] = out[1]
                
            lgts = self.W_out(out)
            # Add logit bias if provided to fix or bias positions
            if logit_bias is not None:
                lgts += logit_bias[i]
            # Sample from logits
            seq[i::num_nodes] = Categorical(logits=lgts / temperature).sample()
            h_S[i::num_nodes] = self.W_s(seq[i::num_nodes])
            logits[i::num_nodes] = lgts

        if return_logits:
            return seq.view(n_samples, num_nodes), logits.view(n_samples, num_nodes, self.out_dim)
        else:    
            return seq.view(n_samples, num_nodes)
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E


class NonAutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Non-Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.
    
    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Simple self-attention on pooled scalar node features
        self.attn = nn.MultiheadAttention(
            embed_dim=self.node_h_dim[0], num_heads=4, dropout=drop_rate, batch_first=True
        )
        self.attn_ln = nn.LayerNorm(self.node_h_dim[0])
        
        # Output
        self.W_out = torch.nn.Sequential(
            LayerNorm(self.node_h_dim),
            GVP(self.node_h_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True),
            GVP(self.node_h_dim, (self.out_dim, 0), 
                activations=(None, None))   
        )
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))

        # Apply simple self-attention over nodes (sequence length = n_nodes)
        x = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.attn_ln(x + attn_out)
        h_V = (x.squeeze(0), h_V[1])

        logits = self.W_out(h_V)  # (n_nodes, out_dim)
        
        return logits
    
    def sample(self, batch, n_samples, temperature=0.1, return_logits=False):
        
        with torch.no_grad():

            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            edge_index = batch.edge_index
        
            h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
            
            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            
            # Pool multi-conformation features
            # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
            h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))
            
            logits = self.W_out(h_V)  # (n_nodes, out_dim)
            probs = F.softmax(logits / temperature, dim=-1)
            seq = torch.multinomial(probs, n_samples, replacement=True)  # (n_nodes, n_samples)

            if return_logits:
                return seq.permute(1, 0).contiguous(), logits.unsqueeze(0).repeat(n_samples, 1, 1)
            else:
                return seq.permute(1, 0).contiguous()
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E

Since the local server generally cannot log in to wandb, while the author's original code builds many of its paths on top of wandb, I changed the save path in trainer.py:

    if device.type == 'xpu':
        import intel_extension_for_pytorch as ipex
        model, optimizer = ipex.optimize(model, optimizer=optimizer)
    # ======= everything above is identical to the author's code =======
    # Initialise save directory
    save_dir = os.path.join(os.path.dirname(__file__), "..", "mymodel")
    # save_dir = os.path.abspath(save_dir)  # optionally switch to an absolute path; checkpoints go to the mymodel folder under the project root
    os.makedirs(save_dir, exist_ok=True)

After that, replace the wandb.run.dir path with save_dir so that no files go missing after training; the pipeline selects the autoregressive model by default.
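A minimal sketch of what this replacement can look like (the surrounding trainer.py code is paraphrased, so the function name save_checkpoint and its arguments are illustrative rather than the author's exact code):

import os
import torch

def save_checkpoint(model, save_dir, epoch, is_best):
    # Hypothetical excerpt: write checkpoints under save_dir instead of
    # wandb.run.dir so they survive runs launched without wandb.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"checkpoint_{epoch}.h5"))
    if is_best:
        torch.save(model.state_dict(), os.path.join(save_dir, "best_checkpoint.h5"))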

Note that training is best run in the background; the command is as follows:

nohup python main.py --no_wandb > main.log 2>&1 &

You can then follow the log:

tail -f main.log

and keep an eye on GPU utilisation in good time, for example:

nvidia-smi

First test of the trained model

The trained model is saved as best_checkpoint.h5. To use it with the author's gRNAde.py script, only the checkpoint paths below (plus some cosmetic print strings) need to be changed. Since I do not use multi-state generation, I left the multi entries untouched; the all and das entries both need to point to the new checkpoint.

CHECKPOINT_PATH = {
    'all': {
        1: os.path.join(PROJECT_PATH, "mymodel/best_checkpoint.h5"),#修改为mymodel里的checkpoint
        2: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_2state_all.h5"),
        3: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_3state_all.h5"),
        5: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_5state_all.h5"),
    },
    'das': {
        1: os.path.join(PROJECT_PATH, "mymodel/best_checkpoint.h5"),#修改为mymodel里的checkpoint
        2: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_2state_das.h5"),
        3: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_3state_das.h5"),
        5: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_5state_das.h5"),
    },
    'multi': {
        1: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_1state_multi.h5"),
        2: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_2state_multi.h5"),
        3: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_3state_multi.h5"),
        5: os.path.join(PROJECT_PATH, "checkpoints/gRNAde_ARv1_5state_multi.h5"),
    }
}

Beyond that, testing each structure manually from the command line is far too tedious, so an automation script is needed. The logic: take a set of file names from data/raw as an index and save them in a .txt file; the script then wraps the original command line, stepping the input path through each <name>.pdb and writing the matching .fasta output. Note that the models.py file must stay in sync with the checkpoint being tested!

# Automated testing script
import os
import subprocess
import sys
from pathlib import Path

def read_test_index(index_file="test.txt"):
    """
    Read test_index.txt file to get the list of PDB files
    """
    try:
        with open(index_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        
        # Filter out empty lines and comment lines, remove newline characters
        pdb_files = []
        for line in lines:
            line = line.strip()
            if line and not line.startswith('#'):
                pdb_files.append(line)
        
        return pdb_files
    
    except FileNotFoundError:
        print(f"Error: File {index_file} not found")
        return []
    except Exception as e:
        print(f"Error reading file: {e}")
        return []

def extract_filename_without_extension(pdb_filename):
    """
    Extract the filename without extension from a PDB filename
    Example: 100D_1_A-B.pdb -> 100D_1_A-B
    """
    return Path(pdb_filename).stem

def run_command(pdb_file):
    """
    Execute the run.py command for the specified PDB file
    """
    # Extract filename (without extension) for output filename
    output_name = extract_filename_without_extension(pdb_file)
    
    # Build command
    cmd = [
        "python", "run.py",
        "--pdb_filepath", f"data/raw/{pdb_file}",
        "--output_filepath", f"testmyrna/{output_name}.fasta",
        "--split", "das",
        "--max_num_conformers", "1",
        "--n_samples", "16",
        "--temperature", "0.5"
    ]
    
    print(f"Processing: {pdb_file}")
    print(f"Executing command: {' '.join(cmd)}")
    
    try:
        # Execute command
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        print(f"✓ Successfully processed {pdb_file}")
        print(f"Output file: testmyrna/{output_name}.fasta")
        
        # Display part of the output if available
        if result.stdout:
            print("Standard output:")
            print(result.stdout[:200] + ("..." if len(result.stdout) > 200 else ""))
        
        return True
        
    except subprocess.CalledProcessError as e:
        print(f"✗ Failed to process {pdb_file}")
        print(f"Error code: {e.returncode}")
        if e.stderr:
            print(f"Error message: {e.stderr}")
        return False
    
    except Exception as e:
        print(f"✗ Unexpected error occurred while processing {pdb_file}: {e}")
        return False

def main():
    """
    Main function
    """
    print("=== RNA Test Script Started ===")
    
    # Check for required files and directories
    if not os.path.exists("run.py"):
        print("Error: run.py file not found in the current directory")
        sys.exit(1)
    
    if not os.path.exists("test.txt"):
        print("Error: test.txt file not found in the current directory")
        sys.exit(1)
    
    # Ensure output directory exists
    output_dir = "testmyrna"
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")
    
    # Read test index file
    pdb_files = read_test_index()
    
    if not pdb_files:
        print("Warning: test.txt file is empty or failed to read")
        sys.exit(1)
    
    print(f"Found {len(pdb_files)} PDB files to process:")
    for i, pdb_file in enumerate(pdb_files, 1):
        print(f"  {i}. {pdb_file}")
    
    print("\nStarting processing...")
    
    # Statistics
    success_count = 0
    failed_files = []
    
    # Process each PDB file
    for i, pdb_file in enumerate(pdb_files, 1):
        print(f"\n[{i}/{len(pdb_files)}] " + "="*50)
        
        # Check if input file exists
        input_path = f"data/raw/{pdb_file}"
        if not os.path.exists(input_path):
            print(f"Warning: Input file {input_path} does not exist, skipping...")
            failed_files.append(pdb_file)
            continue
        
        # Execute command
        if run_command(pdb_file):
            success_count += 1
        else:
            failed_files.append(pdb_file)
    
    # Print summary
    print("\n" + "="*60)
    print("=== Processing Completed ===")
    print(f"Total: {len(pdb_files)} files")
    print(f"Success: {success_count} files")
    print(f"Failed: {len(failed_files)} files")
    
    if failed_files:
        print("\nFailed files:")
        for failed_file in failed_files:
            print(f"  - {failed_file}")
    
    print(f"\nOutput files saved in: {output_dir}/")

if __name__ == "__main__":
    main()

Note that sec_struct_utils.py sometimes deletes its temporary files too eagerly, so I modified it. First, comment out the clean-up lines:

    output = subprocess.run(cmd, check=True, capture_output=True).stdout.decode("utf-8")

    # Delete temporary files -- these lines are commented out
    # if sequence is not None:
    #     os.remove(fasta_file_path)

    if n_samples > 1:
        return output.split("\n")[:-1]
    else:
        return [output.split("\n")[-2]]

Then change the temporary-file path:

    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    try:
        fasta_file_path = os.path.join(wandb.run.dir, f"temp_{current_datetime}.fasta")
    except AttributeError:
        fasta_file_path = os.path.join(PROJECT_PATH, "temp", f"temp_{current_datetime}.fasta")  # fall back to a temp/ folder under the project; I generally do not use wandb

Once the FASTA files have been generated into a folder, a new script is needed to post-process them. Since I sample 16 sequences per structure, the script computes the per-structure averages of perplexity, recovery, edit distance and SC Score, and writes them (together with the input length) as columns of a .txt file:

import os
import re
import glob
from pathlib import Path

def parse_fasta_file(fasta_path):
    """
    Parse a single FASTA file to extract the input sequence length and metrics for 16 samples
    
    Returns:
    - input_length: Length of the input sequence
    - avg_metrics: Average metrics for 16 samples {perplexity, recovery, edit_dist, sc_score}
    """
    try:
        with open(fasta_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # Split different sequence blocks
        sequences = content.strip().split('>')
        sequences = [seq.strip() for seq in sequences if seq.strip()]
        
        input_length = 0
        sample_metrics = []
        
        for seq_block in sequences:
            lines = seq_block.strip().split('\n')
            if not lines:
                continue
            
            header = lines[0]
            sequence_lines = lines[1:]
            
            # Process input sequence
            if 'input_sequence' in header:
                # Combine all sequence lines and calculate length
                sequence = ''.join(sequence_lines).replace(' ', '').replace('\n', '')
                input_length = len(sequence)
                print(f"  Input sequence length: {input_length}")
            
            # Process sample sequence
            elif 'sample=' in header:
                # Extract metrics using regular expressions
                perplexity_match = re.search(r'perplexity=([0-9.]+)', header)
                recovery_match = re.search(r'recovery=([0-9.]+)', header)
                edit_dist_match = re.search(r'edit_dist=([0-9.]+)', header)
                sc_score_match = re.search(r'sc_score=([0-9.]+)', header)
                
                if all([perplexity_match, recovery_match, edit_dist_match, sc_score_match]):
                    metrics = {
                        'perplexity': float(perplexity_match.group(1)),
                        'recovery': float(recovery_match.group(1)),
                        'edit_dist': float(edit_dist_match.group(1)),
                        'sc_score': float(sc_score_match.group(1))
                    }
                    sample_metrics.append(metrics)
        
        # Calculate averages
        if sample_metrics:
            avg_metrics = {
                'perplexity': sum(m['perplexity'] for m in sample_metrics) / len(sample_metrics),
                'recovery': sum(m['recovery'] for m in sample_metrics) / len(sample_metrics),
                'edit_dist': sum(m['edit_dist'] for m in sample_metrics) / len(sample_metrics),
                'sc_score': sum(m['sc_score'] for m in sample_metrics) / len(sample_metrics)
            }
            print(f"  Found {len(sample_metrics)} samples")
            print(f"  Average metrics: perplexity={avg_metrics['perplexity']:.4f}, recovery={avg_metrics['recovery']:.4f}, edit_dist={avg_metrics['edit_dist']:.4f}, sc_score={avg_metrics['sc_score']:.4f}")
        else:
            print("  Warning: No valid sample metrics found")
            avg_metrics = None
        
        return input_length, avg_metrics
        
    except Exception as e:
        print(f"  Error: Exception occurred while processing file: {e}")
        return 0, None

def process_all_fasta_files(input_dir="tout", output_file="data/plotdata/plot.txt"):
    """
    Process all FASTA files in the specified directory
    """
    print("=== FASTA File Processing Script ===")
    
    # Check input directory
    if not os.path.exists(input_dir):
        print(f"Error: Input directory does not exist: {input_dir}")
        return
    
    # Create output directory
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output directory: {output_dir}")
    
    # Find all FASTA files
    fasta_pattern = os.path.join(input_dir, "*.fasta")
    fasta_files = glob.glob(fasta_pattern)
    
    if not fasta_files:
        print(f"Warning: No .fasta files found in {input_dir}")
        return
    
    print(f"Found {len(fasta_files)} FASTA files")
    
    # Store processing results
    results = []
    processed_count = 0
    failed_count = 0
    
    # Process each FASTA file
    for i, fasta_file in enumerate(sorted(fasta_files), 1):
        filename = os.path.basename(fasta_file)
        print(f"\n[{i}/{len(fasta_files)}] Processing file: {filename}")
        
        input_length, avg_metrics = parse_fasta_file(fasta_file)
        
        if input_length > 0 and avg_metrics is not None:
            # Format result
            result_line = f"{input_length} {avg_metrics['perplexity']:.4f} {avg_metrics['recovery']:.4f} {avg_metrics['edit_dist']:.4f} {avg_metrics['sc_score']:.4f}"
            results.append(result_line)
            processed_count += 1
            print(f"  ✓ Processed successfully")
        else:
            print(f"  ✗ Processing failed")
            failed_count += 1
    
    # Save results
    if results:
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                for result in results:
                    f.write(result + '\n')
            
            print(f"\n=== Processing Completed ===")
            print(f"Total files: {len(fasta_files)}")
            print(f"Successfully processed: {processed_count}")
            print(f"Failed: {failed_count}")
            print(f"Results saved to: {output_file}")
            
            # Show preview of first few lines
            print(f"\nResults preview (first 5 lines):")
            for i, result in enumerate(results[:5]):
                print(f"  {result}")
            if len(results) > 5:
                print(f"  ... (total {len(results)} lines)")
                
        except Exception as e:
            print(f"Error saving file: {e}")
    else:
        print("No successfully processed data, unable to generate output file")

def validate_output_format(output_file):
    """
    Validate the format of the output file
    """
    if not os.path.exists(output_file):
        print("Output file does not exist")
        return
    
    print(f"\n=== Validate Output Format ===")
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        
        print(f"Total lines: {len(lines)}")
        
        for i, line in enumerate(lines[:3], 1):  # Check first 3 lines
            parts = line.strip().split()
            if len(parts) == 5:
                length = int(parts[0])
                perplexity = float(parts[1])
                recovery = float(parts[2])
                edit_dist = float(parts[3])
                sc_score = float(parts[4])
                print(f"Line {i}: length={length}, perplexity={perplexity}, recovery={recovery}, edit_dist={edit_dist}, sc_score={sc_score}")
            else:
                print(f"Line {i} format error: {line.strip()}")
    
    except Exception as e:
        print(f"Error during validation: {e}")

def main():
    """
    Main function
    """
    # Set input and output paths
    input_directory = "tout"
    output_filepath = "data/plotdata/plotgrna.txt"
    
    # Process all FASTA files
    process_all_fasta_files(input_directory, output_filepath)
    
    # Validate output format
    validate_output_format(output_filepath)

if __name__ == "__main__":
    main()

And the MATLAB plotting code:

% Script to visualize RNA prediction metrics from g.txt

% Load data
data = load('g.txt'); % Format: [length, perplexity, recovery, edit_dist, sc_score]
lengths = data(:, 1);
perplexity = data(:, 2);
recovery = data(:, 3);
edit_dist = data(:, 4);
sc_score = data(:, 5);

% Compute unique lengths and their average metrics
[unique_lengths, ~, idx] = unique(lengths);
n = length(unique_lengths);
avg_perplexity = zeros(n,1);
avg_recovery = zeros(n,1);
avg_edit_dist = zeros(n,1);
avg_sc_score = zeros(n,1);

for i = 1:n
    avg_perplexity(i) = mean(perplexity(idx == i));
    avg_recovery(i) = mean(recovery(idx == i));
    avg_edit_dist(i) = mean(edit_dist(idx == i));
    avg_sc_score(i) = mean(sc_score(idx == i));
end

% Prepare 2x2 subplot
figure('Name','RNA Prediction Metrics','NumberTitle','off');
metrics = {avg_perplexity, avg_recovery, avg_edit_dist, avg_sc_score};
titles = {
    'Variation of Perplexity with RNA Sequence Length', ...
    'Variation of Recovery Rate with RNA Sequence Length', ...
    'Variation of Edit Distance with RNA Sequence Length', ...
    'Variation of Structural Conservation Score with RNA Sequence Length'
};
ylabels = {'Average Perplexity', 'Average Recovery Rate', ...
           'Average Edit Distance', 'Average SC Score'};

% Set Times New Roman font for all text
set(0,'defaultAxesFontName','Times New Roman');
set(0,'defaultTextFontName','Times New Roman');


markerSize = 10;      
markerColor = [0.9, 0.2, 0.2]; 
lineWidth = 1;     

for i = 1:4
    subplot(2,2,i);
    scatter(unique_lengths, metrics{i}, markerSize, ...
            'MarkerEdgeColor', markerColor, ...
            'MarkerFaceColor', markerColor, ...
            'LineWidth', lineWidth);
    
    xlabel('RNA Sequence Length (nt)','FontName','Times New Roman');
    ylabel(ylabels{i},'FontName','Times New Roman');
    title(titles{i},'FontName','Times New Roman');
    grid on;
    set(gca,'FontName','Times New Roman');
end

% Adjust layout
sgtitle('Comparative Analysis of GRNADE Prediction Metrics','FontName','Times New Roman');

In addition, the modified models.py is as follows:

################################################################
# Generalisation of Geometric Vector Perceptron, Jing et al.
# for explicit multi-state biomolecule representation learning.
# Original repository: https://github.com/drorlab/gvp-pytorch
################################################################

from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical
import torch_geometric

from src.layers import *


class AutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.

    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Simple self-attention on pooled scalar node features
        self.attn = nn.MultiheadAttention(
            embed_dim=self.node_h_dim[0], num_heads=4, dropout=drop_rate, batch_first=True
        )
        self.attn_ln = nn.LayerNorm(self.node_h_dim[0])
        
        # Decoder layers
        self.W_s = nn.Embedding(self.out_dim, self.out_dim)
        self.edge_h_dim = (self.edge_h_dim[0] + self.out_dim, self.edge_h_dim[1])
        self.decoder_layers = nn.ModuleList(
                GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                             activations=activations, vector_gate=True, 
                             drop_rate=drop_rate, autoregressive=True, norm_first=True) 
            for _ in range(num_layers))
        
        # Output
        self.W_out = GVP(self.node_h_dim, (self.out_dim, 0), activations=(None, None))
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq

        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)

        # Apply simple self-attention over nodes (sequence length = n_nodes)
        x = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.attn_ln(x + attn_out)
        h_V = (x.squeeze(0), h_V[1])

        encoder_embeddings = h_V
        
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
        
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
        
        logits = self.W_out(h_V)
        
        return logits
    
    @torch.no_grad()
    def sample(
            self, 
            batch, 
            n_samples, 
            temperature: Optional[float] = 0.1, 
            logit_bias: Optional[torch.Tensor] = None,
            return_logits: Optional[bool] = False
        ):
        '''
        Samples sequences autoregressively from the distribution
        learned by the model.

        Args:
            batch (torch_geometric.data.Data): mini-batch containing one
                RNA backbone to design sequences for
            n_samples (int): number of samples
            temperature (float): temperature to use in softmax over 
                the categorical distribution
            logit_bias (torch.Tensor): bias to add to logits during sampling
                to manually fix or control nucleotides in designed sequences,
                of shape [n_nodes, 4]
            return_logits (bool): whether to return logits or not
        
        Returns:
            seq (torch.Tensor): int tensor of shape [n_samples, n_nodes]
                                based on the residue-to-int mapping of
                                the original training data
            logits (torch.Tensor): logits of shape [n_samples, n_nodes, 4]
                                   (only if return_logits is True)
        ''' 
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
    
        device = edge_index.device
        num_nodes = h_V[0].shape[0]
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
        
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        
        # Pool multi-conformation features
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        
        # Apply simple self-attention over nodes (sequence length = n_nodes)
        x = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.attn_ln(x + attn_out)
        h_V = (x.squeeze(0), h_V[1])
        
        # Repeat features for sampling n_samples times
        h_V = (h_V[0].repeat(n_samples, 1),
            h_V[1].repeat(n_samples, 1, 1))
        h_E = (h_E[0].repeat(n_samples, 1),
            h_E[1].repeat(n_samples, 1, 1))
        
        # Expand edge index for autoregressive decoding
        edge_index = edge_index.expand(n_samples, -1, -1)
        offset = num_nodes * torch.arange(n_samples, device=device).view(-1, 1, 1)
        edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
        # This is akin to 'batching' (in PyG style) n_samples copies of the graph
        
        seq = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.int)
        h_S = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)
        logits = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)

        h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]

        # Decode one token at a time
        for i in range(num_nodes):
            
            h_S_ = h_S[edge_index[0]]
            h_S_[edge_index[0] >= edge_index[1]] = 0
            h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
                    
            edge_mask = edge_index[1] % num_nodes == i  # True for all edges where dst is node i
            edge_index_ = edge_index[:, edge_mask]  # subset all incoming edges to node i
            h_E_ = tuple_index(h_E_, edge_mask)
            node_mask = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.bool)
            node_mask[i::num_nodes] = True  # True for all nodes i and its repeats
            
            for j, layer in enumerate(self.decoder_layers):
                out = layer(h_V_cache[j], edge_index_, h_E_,
                        autoregressive_x=h_V_cache[0], node_mask=node_mask)
                
                out = tuple_index(out, node_mask)  # subset out to only node i and its repeats
                
                if j < len(self.decoder_layers)-1:
                    h_V_cache[j+1][0][i::num_nodes] = out[0]
                    h_V_cache[j+1][1][i::num_nodes] = out[1]
                
            lgts = self.W_out(out)
            # Add logit bias if provided to fix or bias positions
            if logit_bias is not None:
                lgts += logit_bias[i]
            # Sample from logits
            seq[i::num_nodes] = Categorical(logits=lgts / temperature).sample()
            h_S[i::num_nodes] = self.W_s(seq[i::num_nodes])
            logits[i::num_nodes] = lgts

        if return_logits:
            return seq.view(n_samples, num_nodes), logits.view(n_samples, num_nodes, self.out_dim)
        else:    
            return seq.view(n_samples, num_nodes)
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E


class NonAutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Non-Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.
    
    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Output
        self.W_out = torch.nn.Sequential(
            LayerNorm(self.node_h_dim),
            GVP(self.node_h_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True),
            GVP(self.node_h_dim, (self.out_dim, 0), 
                activations=(None, None))   
        )
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))

        logits = self.W_out(h_V)  # (n_nodes, out_dim)
        
        return logits
    
    def sample(self, batch, n_samples, temperature=0.1, return_logits=False):
        
        with torch.no_grad():

            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            edge_index = batch.edge_index
        
            h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
            
            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            
            # Pool multi-conformation features
            # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
            h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))
            
            logits = self.W_out(h_V)  # (n_nodes, out_dim)
            probs = F.softmax(logits / temperature, dim=-1)
            seq = torch.multinomial(probs, n_samples, replacement=True)  # (n_nodes, n_samples)

            if return_logits:
                return seq.permute(1, 0).contiguous(), logits.unsqueeze(0).repeat(n_samples, 1, 1)
            else:
                return seq.permute(1, 0).contiguous()
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E

After some experiments and comparisons, the first issue was that with epoch = 1 the model is simply terrible:

Very poor results

But when the number of epochs is raised to 50, the results are visibly much better and almost identical to the author's released model, so the training code itself is fine:

Results at epoch = 50

However, after training I found that the model with attention behaved exactly the same as the one without. It turned out that I had added the attention block to the non-autoregressive class, while the pipeline selects the autoregressive model by default.

The results are virtually identical, which was puzzling
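A quick way to catch this kind of mistake is to diff the parameter names of two checkpoints and check whether any attention weights were actually trained. A minimal sketch, assuming both checkpoints store a plain state_dict (the file names are illustrative, not the exact ones used here):

import torch

baseline = torch.load("mymodel/best_checkpoint_no_attn.h5", map_location="cpu")
modified = torch.load("mymodel/best_checkpoint.h5", map_location="cpu")

# Parameters present only in the modified model; if no 'attn.*' keys appear,
# the attention block never made it into the class that was actually trained.
extra_keys = sorted(set(modified.keys()) - set(baseline.keys()))
print(extra_keys)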

After that, I revised the multi-head attention code again, retrained the model, and ran predictions:

My model
Inference works as expected

I then randomly sampled 500 data points for testing, keeping every other setting identical as a controlled comparison; the results are shown below:

In practice, there seems to be an effect, but it is small
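For reference, the 500 test structures can be drawn with a small helper along the following lines (a sketch; it writes to the same test.txt index file consumed by the automated test script above):

import os
import random

random.seed(0)  # fix the seed so the same 500 structures are reused across runs

pdb_files = sorted(f for f in os.listdir("data/raw") if f.endswith(".pdb"))
subset = random.sample(pdb_files, 500)

with open("test.txt", "w") as f:
    f.write("\n".join(subset) + "\n")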

So I am now considering a multi-scale attention design, following the idea of the paper "Atlas: Multi-Scale Attention Improves Long Context Image Modeling" and with help from AI assistants (which did most of the heavy lifting). The current model is shown below; testing will follow after training.

################################################################
# Generalisation of Geometric Vector Perceptron, Jing et al.
# for explicit multi-state biomolecule representation learning.
# Original repository: https://github.com/drorlab/gvp-pytorch
################################################################

from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical
import torch_geometric

from src.layers import *


class MultiScaleAttention(nn.Module):
    '''
    Multi-scale attention module to capture dependencies at different window sizes.
    '''
    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        window_sizes: list = [10, 50, 200, None],  # None for global scale
        dropout: float = 0.1
    ):
        super().__init__()
        self.window_sizes = window_sizes
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True
            ) for _ in window_sizes
        ])
    
    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, embed_dim)
        outputs = []

        for attn, window_size in zip(self.attentions, self.window_sizes):
            if window_size is None:
                # Global attention over the full sequence
                attn_output, _ = attn(x, x, x, attn_mask=mask)
            else:
                # Local window attention (placeholder): currently this branch also
                # applies full attention, so the scales differ only in their learned
                # weights. A proper sliding-window mask is sketched after this listing.
                attn_output, _ = attn(x, x, x, attn_mask=mask)
            outputs.append(attn_output)

        # Fuse the scales by averaging their outputs
        fused_output = torch.mean(torch.stack(outputs), dim=0)
        return fused_output


class AutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.

    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
        num_attention_heads = 8,
        attention_window_sizes = [10, 50, 200, None],  # Multi-scale windows
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Multi-scale attention for capturing long-distance dependencies at different scales
        self.multi_scale_attention = MultiScaleAttention(
            embed_dim=self.node_h_dim[0],  # Scalar dimension
            num_heads=num_attention_heads,
            window_sizes=attention_window_sizes,
            dropout=drop_rate
        )
        
        # Decoder layers
        self.W_s = nn.Embedding(self.out_dim, self.out_dim)
        self.edge_h_dim = (self.edge_h_dim[0] + self.out_dim, self.edge_h_dim[1])
        self.decoder_layers = nn.ModuleList(
                GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                             activations=activations, vector_gate=True, 
                             drop_rate=drop_rate, autoregressive=True, norm_first=True) 
            for _ in range(num_layers))
        
        # Output
        self.W_out = GVP(self.node_h_dim, (self.out_dim, 0), activations=(None, None))
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq

        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)

        # Apply multi-scale attention on pooled scalar features
        # (note: all nodes in the batch are treated as a single sequence here)
        h_V_s = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_output = self.multi_scale_attention(h_V_s)
        h_V = (attn_output.squeeze(0) + h_V[0], h_V[1])  # Residual connection

        encoder_embeddings = h_V
        
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
        
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
        
        logits = self.W_out(h_V)
        
        return logits
    
    @torch.no_grad()
    def sample(
            self, 
            batch, 
            n_samples, 
            temperature: Optional[float] = 0.1, 
            logit_bias: Optional[torch.Tensor] = None,
            return_logits: Optional[bool] = False
        ):
        '''
        Samples sequences autoregressively from the distribution
        learned by the model.

        Args:
            batch (torch_geometric.data.Data): mini-batch containing one
                RNA backbone to design sequences for
            n_samples (int): number of samples
            temperature (float): temperature to use in softmax over 
                the categorical distribution
            logit_bias (torch.Tensor): bias to add to logits during sampling
                to manually fix or control nucleotides in designed sequences,
                of shape [n_nodes, 4]
            return_logits (bool): whether to return logits or not
        
        Returns:
            seq (torch.Tensor): int tensor of shape [n_samples, n_nodes]
                                based on the residue-to-int mapping of
                                the original training data
            logits (torch.Tensor): logits of shape [n_samples, n_nodes, 4]
                                   (only if return_logits is True)
        ''' 
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
    
        device = edge_index.device
        num_nodes = h_V[0].shape[0]
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
        
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        
        # Pool multi-conformation features
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        
        # Apply multi-scale attention on pooled scalar features
        h_V_s = h_V[0].unsqueeze(0)  # (1, n_nodes, d_s)
        attn_output = self.multi_scale_attention(h_V_s)
        h_V = (attn_output.squeeze(0) + h_V[0], h_V[1])  # Residual connection
        
        # Repeat features for sampling n_samples times
        h_V = (h_V[0].repeat(n_samples, 1),
            h_V[1].repeat(n_samples, 1, 1))
        h_E = (h_E[0].repeat(n_samples, 1),
            h_E[1].repeat(n_samples, 1, 1))
        
        # Expand edge index for autoregressive decoding
        edge_index = edge_index.expand(n_samples, -1, -1)
        offset = num_nodes * torch.arange(n_samples, device=device).view(-1, 1, 1)
        edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
        # This is akin to 'batching' (in PyG style) n_samples copies of the graph
        
        seq = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.int)
        h_S = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)
        logits = torch.zeros(n_samples * num_nodes, self.out_dim, device=device)

        h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]

        # Decode one token at a time
        for i in range(num_nodes):
            
            h_S_ = h_S[edge_index[0]]
            h_S_[edge_index[0] >= edge_index[1]] = 0
            h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
                    
            edge_mask = edge_index[1] % num_nodes == i  # True for all edges where dst is node i
            edge_index_ = edge_index[:, edge_mask]  # subset all incoming edges to node i
            h_E_ = tuple_index(h_E_, edge_mask)
            node_mask = torch.zeros(n_samples * num_nodes, device=device, dtype=torch.bool)
            node_mask[i::num_nodes] = True  # True for all nodes i and its repeats
            
            for j, layer in enumerate(self.decoder_layers):
                out = layer(h_V_cache[j], edge_index_, h_E_,
                        autoregressive_x=h_V_cache[0], node_mask=node_mask)
                
                out = tuple_index(out, node_mask)  # subset out to only node i and its repeats
                
                if j < len(self.decoder_layers)-1:
                    h_V_cache[j+1][0][i::num_nodes] = out[0]
                    h_V_cache[j+1][1][i::num_nodes] = out[1]
                
            lgts = self.W_out(out)
            # Add logit bias if provided to fix or bias positions
            if logit_bias is not None:
                lgts += logit_bias[i]
            # Sample from logits
            seq[i::num_nodes] = Categorical(logits=lgts / temperature).sample()
            h_S[i::num_nodes] = self.W_s(seq[i::num_nodes])
            logits[i::num_nodes] = lgts

        if return_logits:
            return seq.view(n_samples, num_nodes), logits.view(n_samples, num_nodes, self.out_dim)
        else:    
            return seq.view(n_samples, num_nodes)
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E


class NonAutoregressiveMultiGNNv1(torch.nn.Module):
    '''
    Non-Autoregressive GVP-GNN for **multiple** structure-conditioned RNA design.
    
    Takes in RNA structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 4 bases at each position in a `torch.Tensor` of shape [n_nodes, 4].
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.
    
    Args:
        node_in_dim (tuple): node dimensions in input graph
        node_h_dim (tuple): node dimensions to use in GVP-GNN layers
        edge_in_dim (tuple): edge dimensions in input graph
        edge_h_dim (tuple): edge dimensions to embed in GVP-GNN layers
        num_layers (int): number of GVP-GNN layers in encoder/decoder
        drop_rate (float): rate to use in all dropout layers
        out_dim (int): output dimension (4 bases)
    '''
    def __init__(
        self,
        node_in_dim = (64, 4), 
        node_h_dim = (128, 16), 
        edge_in_dim = (32, 1), 
        edge_h_dim = (32, 1),
        num_layers = 3, 
        drop_rate = 0.1,
        out_dim = 4,
    ):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.node_h_dim = node_h_dim
        self.edge_in_dim = edge_in_dim
        self.edge_h_dim = edge_h_dim
        self.num_layers = num_layers
        self.out_dim = out_dim
        activations = (F.silu, None)
        
        # Node input embedding
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Edge input embedding
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim, 
                activations=(None, None), vector_gate=True)
        )
        
        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
                MultiGVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                                  activations=activations, vector_gate=True,
                                  drop_rate=drop_rate, norm_first=True)
            for _ in range(num_layers))
        
        # Output
        self.W_out = torch.nn.Sequential(
            LayerNorm(self.node_h_dim),
            GVP(self.node_h_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True),
            GVP(self.node_h_dim, (self.out_dim, 0), 
                activations=(None, None))   
        )
    
    def forward(self, batch):

        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        
        h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
        h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)

        # Pool multi-conformation features: 
        # nodes: (n_nodes, d_s), (n_nodes, d_v, 3)
        # edges: (n_edges, d_se), (n_edges, d_ve, 3)
        # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
        h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))

        logits = self.W_out(h_V)  # (n_nodes, out_dim)
        
        return logits
    
    def sample(self, batch, n_samples, temperature=0.1, return_logits=False):
        
        with torch.no_grad():

            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            edge_index = batch.edge_index
        
            h_V = self.W_v(h_V)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            h_E = self.W_e(h_E)  # (n_edges, n_conf, d_se), (n_edges, n_conf, d_ve, 3)
            
            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)  # (n_nodes, n_conf, d_s), (n_nodes, n_conf, d_v, 3)
            
            # Pool multi-conformation features
            # h_V, h_E = self.pool_multi_conf(h_V, h_E, batch.mask_confs, edge_index)
            h_V = (h_V[0].mean(dim=1), h_V[1].mean(dim=1))
            
            logits = self.W_out(h_V)  # (n_nodes, out_dim)
            probs = F.softmax(logits / temperature, dim=-1)
            seq = torch.multinomial(probs, n_samples, replacement=True)  # (n_nodes, n_samples)

            if return_logits:
                return seq.permute(1, 0).contiguous(), logits.unsqueeze(0).repeat(n_samples, 1, 1)
            else:
                return seq.permute(1, 0).contiguous()
        
    def pool_multi_conf(self, h_V, h_E, mask_confs, edge_index):

        if mask_confs.size(1) == 1:
            # Number of conformations is 1, no need to pool
            return (h_V[0][:, 0], h_V[1][:, 0]), (h_E[0][:, 0], h_E[1][:, 0])
        
        # True num_conf for masked mean pooling
        n_conf_true = mask_confs.sum(1, keepdim=True)  # (n_nodes, 1)
        
        # Mask scalar features
        mask = mask_confs.unsqueeze(2)  # (n_nodes, n_conf, 1)
        h_V0 = h_V[0] * mask
        h_E0 = h_E[0] * mask[edge_index[0]]

        # Mask vector features
        mask = mask.unsqueeze(3)  # (n_nodes, n_conf, 1, 1)
        h_V1 = h_V[1] * mask
        h_E1 = h_E[1] * mask[edge_index[0]]
        
        # Average pooling multi-conformation features
        h_V = (h_V0.sum(dim=1) / n_conf_true,               # (n_nodes, d_s)
               h_V1.sum(dim=1) / n_conf_true.unsqueeze(2))  # (n_nodes, d_v, 3)
        h_E = (h_E0.sum(dim=1) / n_conf_true[edge_index[0]],               # (n_edges, d_se)
               h_E1.sum(dim=1) / n_conf_true[edge_index[0]].unsqueeze(2))  # (n_edges, d_ve, 3)

        return h_V, h_E
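As noted in the code, the "local" branch of MultiScaleAttention still applies full attention at every scale. Below is a minimal sketch (not part of the trained model) of the banded sliding-window mask that the placeholder comment refers to. It uses the boolean `attn_mask` convention of `torch.nn.MultiheadAttention`, where `True` marks pairs that are not allowed to attend; the window and sequence length are illustrative.

```python
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Boolean attention mask for torch.nn.MultiheadAttention.

    True entries are blocked: position i may only attend to positions j
    with |i - j| <= window_size // 2.
    """
    idx = torch.arange(seq_len)
    dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # (seq_len, seq_len)
    return dist > window_size // 2

# Example: a 300-nt sequence with a local window of 50 nucleotides.
mask = sliding_window_mask(300, 50)

# Inside MultiScaleAttention.forward this could replace the placeholder branch:
#   attn_output, _ = attn(
#       x, x, x,
#       attn_mask=sliding_window_mask(x.size(1), window_size).to(x.device),
#   )
```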

After training the new model, I ran the tests again for comparison.

Multi-Scale Attention

It does seem to help somewhat, doesn't it?

Testing multi-scale attention on the dataset

On the same 500-structure test set, the baseline gRNAde model gives the following results:

| Length Range | Statistic | Perplexity | Recovery | Edit Distance | SC Score | Sample Count |
| --- | --- | --- | --- | --- | --- | --- |
| 0-100 | Median | 1.4128 | 0.56445 | 17.5972 | 0.62315 | 234 |
| 100-200 | Median | 1.26885 | 0.73585 | 30.2634 | 0.61495 | 72 |
| 200+ | Median | 1.3132 | 0.64605 | 226.875 | 0.42835 | 138 |
| 0-100 | Mean | 1.47854274 | 0.569795726 | 18.55480427 | 0.634651282 | 234 |
| 100-200 | Mean | 1.34514583 | 0.685369444 | 37.49759306 | 0.569984722 | 72 |
| 200+ | Mean | 1.41243696 | 0.639286232 | 211.1403986 | 0.421031884 | 138 |

With the single-scale multi-head attention added, the results are:

| Length Range | Statistic | Perplexity | Recovery | Edit Distance | SC Score | Sample Count |
| --- | --- | --- | --- | --- | --- | --- |
| 0-100 | Median | 1.30675 | 0.5532 | 17.21875 | 0.62105 | 234 |
| 100-200 | Median | 1.17275 | 0.7444 | 29.75 | 0.62205 | 72 |
| 200+ | Median | 1.30785 | 0.70485 | 183.84375 | 0.41915 | 140 |
| 0-100 | Mean | 1.36121325 | 0.572758974 | 18.34507564 | 0.638930342 | 234 |
| 100-200 | Mean | 1.24458056 | 0.704726389 | 35.62898472 | 0.605529167 | 72 |
| 200+ | Mean | 1.36845071 | 0.685876429 | 180.6455357 | 0.430865714 | 140 |

And with the multi-scale attention model described above:

| Length Range | Statistic | Perplexity | Recovery | Edit Distance | SC Score | Sample Count |
| --- | --- | --- | --- | --- | --- | --- |
| 0-100 | Median | 1.3358 | 0.5592 | 17.6667 | 0.6396 | 235 |
| 100-200 | Median | 1.2226 | 0.71205 | 33.78125 | 0.6771 | 72 |
| 200+ | Median | 1.32545 | 0.68665 | 193.0625 | 0.42465 | 140 |
| 0-100 | Mean | 1.36930766 | 0.557402979 | 18.98926 | 0.660275319 | 235 |
| 100-200 | Mean | 1.28780417 | 0.682672222 | 37.93399306 | 0.624776389 | 72 |
| 200+ | Mean | 1.40186786 | 0.673555 | 187.8361607 | 0.434380714 | 140 |
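For context, here is a minimal sketch of how the per-length-range medians and means in these tables can be aggregated from per-structure results. The column names and the dummy data are assumptions standing in for the real evaluation output.

```python
import numpy as np
import pandas as pd

# Dummy per-structure results table; in practice this would be loaded from the
# evaluation run (column names here mirror the tables above).
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "length": rng.integers(20, 400, size=500),
    "perplexity": rng.uniform(1.1, 1.6, size=500),
    "recovery": rng.uniform(0.4, 0.8, size=500),
    "edit_distance": rng.uniform(5, 250, size=500),
    "sc_score": rng.uniform(0.3, 0.7, size=500),
})

# Bin sequences into the same length ranges as the tables: 0-100, 100-200, 200+.
bins = [0, 100, 200, np.inf]
labels = ["0-100", "100-200", "200+"]
df["length_range"] = pd.cut(df["length"], bins=bins, labels=labels, right=False)

metrics = ["perplexity", "recovery", "edit_distance", "sc_score"]
summary = df.groupby("length_range", observed=True)[metrics].agg(["median", "mean"])
counts = df["length_range"].value_counts()
print(summary)
print(counts)
```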

After these larger-scale tests, I visualized the comparison, focusing mainly on SC Score and Recovery.

Both attention-augmented models outperform the original in practice, so it now comes down to a trade-off between them.

In the end I consider the SC Score the more important metric, so I evaluated the multi-scale model against gRNAde on 12,000 structures and plotted the comparison with sequence length on a log10 x-axis:

The differences are subtle; gRNAde is already essentially state of the art, so further gains are hard to come by.
Still, the comparison shows that the multi-scale model clearly improves SC Score on medium-to-long sequences, while average Recovery stays about the same.