0%

PAPER READING - Contrastive Learning

用于无监督视觉表示学习的动量对比 Momentum Contrast for Unsupervised Visual Representation Learning

1. Momentum Contrast

1.1 定义

  • 内容: 对比学习(Contrastive Learning) 通过让模型学习区分相似正样本与不相似负样本的数据点来学习有用的特征。

1.2 正样本对(Positive Pairs):

  • 内容: 通常来自同一数据点的不同数据增强视图(例如,同一张图片的两次随机裁剪、颜色抖动等)。它们应该具有相似的语义信息。

1.3 负样本(Negatives):

  • 内容: 来自与查询样本不同的其他数据点。它们代表不同的语义内容。

1.4 目标: 模型的目标是学习一个编码器(Encoder)

  • 内容: 查询样本与其对应的正样本在特征空间中的距离很近(相似度高)。
  • 内容: 查询样本与大量负样本在特征空间中的距离很远(相似度低)。

2. 创新

2.1 动态字典(Dynamic Dictionary):

  • 内容: MoCo 维护一个先进先出FIFO的队列来存储编码后的特征Keys
  • 内容:当前批次的数据经过键编码器编码后,其特征被入队添加到字典队列尾部。
  • 内容:同时,队列中最老的批次特征被出队dequeue 移除。
  • 内容:队列可以将字典大小 \(K\) 设计得远大于单个批次的大小,从而提供海量且一致的负样本来源(一致性由下面的动量更新保证)。队列解耦了字典大小与批次大小的限制。

2.2 动量更新编码器(Momentum Update of Key Encoder):

  • 内容: 查询编码器使用标准的梯度下降更新(SGD
  • 内容: 键编码器不通过反向传播更新
  • 内容: 键编码器的参数 \(θ_k\) 通过动量更新Momentum Update从查询编码器的参数 \(θ_q\) 获得:\[ \theta_k \gets m \cdot \theta_k + (1 - m) \cdot \theta_q \]其中 \(m\) 是一个动量系数(如 \(m\) = \(0.999\)),非常接近\(1\)。 ## 2.3 优势:
  • 内容:动量更新使得键编码器 \(f_k\) 的参数变化非常缓慢和平滑。

3. 对比损失函数(InfoNCE Loss:

3.1 InfoNCE Loss (Noise-Contrastive Estimation Loss):

对比损失函数(InfoNCE/NT-Xent Loss)定义为: \(\mathcal{L}_{\text{q}} = -\log \underbrace{\left( \frac{\exp\left( \mathbf{q} \cdot \mathbf{k}^{+} / \tau \right)}{\exp\left( \mathbf{q} \cdot \mathbf{k}^{+} / \tau \right) + \sum\limits_{i=1}^{N} \exp\left( \mathbf{q} \cdot \mathbf{k}_{i}^{-} / \tau \right)} \right)}_{\text{Softmax 概率}}\)

3.2其中:

参数 含义
\(\mathbf{q}\) 查询向量(Query Vector):由查询编码器 \(f_q\) 输出(如 \(2048\) 维)。
\(\mathbf{k}^{+}\) 正样本键向量(Positive Key):由键编码器 \(f_k\) 输出(来自同一数据的不同增强视图)。
\(\mathbf{k}_{i}^{-}\) 负样本键向量(Negative Keys):来自字典队列的其他数据样本(数量为 \(N\),如 \(65536\))。
\(\tau\) 温度参数(Temperature):控制相似度分布的尖锐程度(典型值 \(0.05 \sim 0.2\))。
\(\cdot\) 向量点积(L2归一化后等价于余弦相似度,即 \(\mathbf{q} \cdot \mathbf{k} = \cos\theta\))。

3.2 数据流:

  • 原始输入: 一张图片 P.jpg (256x256 原始尺寸)
  • 预处理Step1: 随机裁剪出 224x224 的区域
  • 预处理Step2: 随机轻微改变颜色和亮度。
  • 预处理Step3: 归一化像素值。
  • 结果: [3, 224, 224] (一个 3x224x224 的张量)。输入到编码器的形式。
  • 查询编码器 \(f_q\) 处理:
  • Step1: 输入:[3, 224, 224] 张量。
  • Step2: 经过Model
  • Step3: 全局平均池化层 (Global Average Pooling) 将空间维度压缩掉。
  • Step4: 一个线性投影层将特征维度映射到 D
  • 结果:一个 \(D\) 维(如 \(M\) 维)的归一化向量 \(q\)。这个 \(q\) 代表了经过裁剪、颜色扰动后的猫头像的抽象特征。例如, [0.12, -0.05, 0.87, …, 0.03] (\(M\)个数值)。
  • 键编码器 \(f_k\) 处理:
  • Step1: 输入:[3, 224, 224] 张量。(对 P.jpg 应用另一组随机预处理得到的另一个 [3, 224, 224] 张量。)
  • Step2: 经过结构相同但参数由动量更新的 \(f_k\)
  • Step3: 一个 \(D\) 维(如 \(M\) 维)的归一化向量 \(k\)。例如 [0.15, -0.08, 0.84, …, 0.02]。这个 \(k\) 代表了同一数据但不同视角/颜色下的抽象特征。
  • 动态字典:
  • Step1:包含之前通过 \(f_k\) 计算出的 \(k\) 向量。例如,队列大小是 \(L\),里面存储了 \(L\) 个不同的 D=128 维向量,每个代表处理过的一张数据的特征。