0%

FusionProt - 论文阅读

Fusing Sequence and Structural Information for Unified Protein Representation Learning

FusionProt

1 蛋白质表示学习:

  • 内容:

FusionProt :可学习融合 token和迭代双向信息交换,实现序列与结构的动态协同学习,而非静态拼接。

2. 一维(1D)氨基酸序列和三维(3D)空间结构:

  • 单模态依赖: ProteinBERT、ESM-2仅基于序列

  • 静态融合缺陷 :ESM-GearNet、SaProt 结合序列与结构,但采用 “单向 / 一次性融合”

好的,完全没有问题。这是对 FusionNetwork 模型架构代码的中文复述分析。

3. 模型总体

fusion

1
2
3
4
5
6
7
8
@R.register("models.FusionNetwork")
class FusionNetwork(nn.Module, core.Configurable):
def __init__(self, sequence_model, structure_model, fusion="series", cross_dim=None):
super(FusionNetwork, self).__init__()
self.sequence_model = sequence_model
self.structure_model = structure_model
self.output_dim = sequence_model.output_dim + structure_model.output_dim
self.inject_step = 5 # (sequence_layers / structure_layers) layers

  • class FusionNetwork(...): 定义了模型类,它继承自 PyTorch 的基础模块 nn.Module
  • __init__(...): 构造函数,接收已经初始化好的 sequence_modelstructure_model 作为输入。
  • self.output_dim: 定义了模型最终输出特征的维度。因为最后会将两个模型的特征拼接起来,所以是两者输出维度之和。
  • self.inject_step = 5:定义了信息“注入”或“交流”的频率。这里设置为 5,意味着每经过序列模型的 5 层,就会进行一次信息交换
1
2
3
4
# Structure embeddings layer
raw_input_dim = 21 # amino acid tokens
self.structure_embed_linear = nn.Linear(raw_input_dim, structure_model.input_dim)
self.embedding_batch_norm = nn.BatchNorm1d(structure_model.input_dim)
  • self.structure_embed_linear: 一个线性层,用于将原始的结构输入(比如 21 种氨基酸的独热编码)转换为结构模型(GNN)所期望的输入维度。
  • self.embedding_batch_norm: 批归一化层,用于稳定结构嵌入层的训练过程。
1
2
3
4
# Normal Initialization of the 3D structure token
structure_token = nn.Parameter(torch.Tensor(structure_model.input_dim).unsqueeze(0))
nn.init.normal_(structure_token, mean=0.0, std=0.01)
self.structure_token = nn.Parameter(structure_token.squeeze(0))
  • self.structure_token: 一个可学习的向量 (nn.Parameter)。这个“令牌”不代表任何真实的原子或氨基酸,而是一个抽象的载体。在训练过程中,它将学习如何编码和表示整个蛋白质的全局 3D 结构信息。它就像一个信息信使。
1
2
3
# Linear Transformation between structure to sequential spaces
self.structure_linears = nn.ModuleList([...])
self.seq_linears = nn.ModuleList([...])
  • self.structure_linears / self.seq_linears: 序列模型和结构模型内部处理的特征向量维度可能不同。当“3D 令牌”需要在两个模型之间传递时,这些线性层负责将它的表示从一个模型的特征空间转换到另一个模型的特征空间。

4. 前向

1
2
3
def forward(self, graph, input, all_loss=None, metric=None):
# Build a new protein graph with the 3D token (the lase node)
new_graph = self.build_protein_graph_with_3d_token(graph)
  • 首先调用辅助函数,将输入的蛋白质图谱进行改造:为图谱增加一个代表“3D 令牌”的新节点,并将这个新节点与图中所有其他节点连接起来。
序列模型的初始化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Sequence (ESM) model initialization
sequence_input = self.sequence_model.mapping[graph.residue_type]
sequence_input[sequence_input == -1] = graph.residue_type[sequence_input == -1]
size = graph.num_residues

# Check if sequence size is not bigger than max seq length
if (size > self.sequence_model.max_input_length).any():
starts = size.cumsum(0) - size
size = size.clamp(max=self.sequence_model.max_input_length)
ends = starts + size
mask = functional.multi_slice_mask(starts, ends, graph.num_residues)
sequence_input = sequence_input[mask]
graph = graph.subresidue(mask)
size_ext = size

# BOS == CLS
if self.sequence_model.alphabet.prepend_bos:
bos = torch.ones(graph.batch_size, dtype=torch.long, device=self.sequence_model.device) * self.sequence_model.alphabet.cls_idx
sequence_input, size_ext = functional._extend(bos, torch.ones_like(size_ext), sequence_input, size_ext)

if self.sequence_model.alphabet.append_eos:
eos = torch.ones(graph.batch_size, dtype=torch.long, device=self.sequence_model.device) * self.sequence_model.alphabet.eos_idx
sequence_input, size_ext = functional._extend(sequence_input, size_ext, eos, torch.ones_like(size_ext))

# Padding
tokens = functional.variadic_to_padded(sequence_input, size_ext, value=self.sequence_model.alphabet.padding_idx)[0]
repr_layers = [self.sequence_model.repr_layer]
assert tokens.ndim == 2
padding_mask = tokens.eq(self.sequence_model.model.padding_idx) # B, T
  • 序列数据进行 Transformer 模型(如 ESM)所需的标准预处理。
  • 包括添加序列开始(BOS)和结束(EOS)标记,以及将所有序列填充(Padding)到相同长度,以便进行批处理。
模型初始化与初次融合
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Sequence embedding layer
x = self.sequence_model.model.embed_scale * self.sequence_model.model.embed_tokens(tokens)

if self.sequence_model.model.token_dropout:
x.masked_fill_((tokens == self.sequence_model.model.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.sequence_model.model.mask_idx).sum(-1).to(x.dtype) / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

# Structure model initialization
structure_hiddens = []
batch_size = graph.batch_size
structure_embedding = self.embedding_batch_norm(self.structure_embed_linear(input))
structure_token_batched = self.structure_token.unsqueeze(0).expand(batch_size, -1)
structure_input = torch.cat([structure_embedding.squeeze(1), structure_token_batched], dim=0)

# Add the 3D token representation
structure_token_expanded = self.structure_token.unsqueeze(0).expand(x.size(0), -1).unsqueeze(1)
x = torch.cat((x[:, :-1], structure_token_expanded, x[:, -1:]), dim=1)
padding_mask = torch.cat([padding_mask[:, :-1],
torch.zeros(padding_mask.size(0), 1).to(padding_mask), padding_mask[:, -1:]], dim=1)
size_ext += 1

if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x

# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)
if not padding_mask.any():
padding_mask = None
  • 将 3D 令牌插入序列。
    1. 为序列数据生成初始的词嵌入表示 x
    2. self.structure_token 的初始状态插入到序列嵌入 x 中,通常是放在序列结束标记(EOS)之前。
    3. 序列模型看到的输入序列变成了 [BOS, 残基1, 残基2, ..., 残基N, **3D令牌**, EOS] 的形式。
融合循环
1
2
3
4
5
6
7
8
for seq_layer_idx, seq_layer in enumerate(self.sequence_model.model.layers):
x, attn = seq_layer(
x,
self_attn_padding_mask=padding_mask,
need_head_weights=False,
)
if (seq_layer_idx + 1) in repr_layers:
hidden_representations[seq_layer_idx + 1] = x.transpose(0, 1)
  • 模型开始逐层遍历序列模型的所有层(例如 Transformer 的编码器层)。x 在每一层都会被更新。
1
if seq_layer_idx > 0 and seq_layer_idx % self.inject_step == 0:
  • 信息注入点:每当层数的索引能被 inject_step (即 5) 整除时,就触发一次信息交换。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 1. 从序列中提取 3D 令牌的表示
if structure_layer_index == 0:
structure_input = torch.cat((structure_input[:-1 * batch_size], x[-2, :, :]), dim=0)
else:
structure_input = torch.cat((structure_input[:-1 * batch_size],
self.seq_linears[structure_layer_index](x[-2, :, :])), dim=0)

# 2. 用结构模型的一层来处理
hidden = self.structure_model.layers[structure_layer_index](new_graph, structure_input)
if self.structure_model.short_cut and hidden.shape == structure_input.shape:
hidden = hidden + structure_input
if self.structure_model.batch_norm:
hidden = self.structure_model.batch_norms[structure_layer_index](hidden)

structure_hiddens.append(hidden)
structure_input = hidden

# 3. 将更新后的 3D 令牌表示插回序列
updated_structure_token = self.structure_linears[...](structure_input[-1 * batch_size:])
x = torch.cat((x[:-2, :, :], updated_structure_token.unsqueeze(0), x[-1:, :, :]), dim=0)
structure_layer_index += 1
  • 信息流程
    1. 从序列到结构:模型从序列表示 x 中提取出“3D 令牌”的最新向量。这个向量此时已经吸收了前面几层序列模型的上下文信息。然后,通过(seq_linears)将其转换后,更新到结构模型的输入中。
    2. 结构信息处理:运行一层结构模型(GNN)。GNN 根据图的连接关系更新所有节点的表示,当然也包括“3D 令牌”这个特殊节点。
    3. 从结构到序列:从 GNN 的输出中,再次提取出“3D 令牌”的向量。这个向量包含更新后的结构信息。再通过(structure_linears)转换后,把它插回到序列表示 x 中,替换掉旧的版本。

这个循环不断重复。

输出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Structural Output
if self.structure_model.concat_hidden:
structure_node_feature = torch.cat(structure_hiddens, dim=-1)[:-1 * batch_size]
else:
structure_node_feature = structure_hiddens[-1][:-1 * batch_size]

structure_graph_feature = self.structure_model.readout(graph, structure_node_feature)

# Sequence Output
x = self.sequence_model.model.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)

# last hidden representation should have layer norm applied
if (seq_layer_idx + 1) in repr_layers:
hidden_representations[seq_layer_idx + 1] = x
x = self.sequence_model.model.lm_head(x)

output = {"logits": x, "representations": hidden_representations}

# Sequence (ESM) model outputs
residue_feature = output["representations"][self.sequence_model.repr_layer]
residue_feature = functional.padded_to_variadic(residue_feature, size_ext)
starts = size_ext.cumsum(0) - size_ext
if self.sequence_model.alphabet.prepend_bos:
starts = starts + 1
ends = starts + size
mask = functional.multi_slice_mask(starts, ends, len(residue_feature))
residue_feature = residue_feature[mask]
graph_feature = self.sequence_model.readout(graph, residue_feature)

# Combine both models outputs
node_feature = torch.cat(...)
graph_feature = torch.cat(...)

return {"graph_feature": graph_feature, "node_feature": node_feature}
  • 提取输出:循环结束后,分别从两个模型中提取最终的特征表示。
  • 读出(Readout):使用一个“读出函数”(如求和或平均)将节点级别的特征聚合成一个代表整个蛋白质的图级别特征。
  • 最终组合:将来自序列模型和结构模型的节点特征(node_feature)和图特征(graph_feature)分别拼接(concatenate)起来。
  • 返回结果:返回一个包含组合后特征的字典,可用于下游任务(如功能预测、属性回归等)。