这篇文章应该是 Co-PLMs 的直接前身,其结构十分相似:
- Client SLM 通过私有数据集进行 LoRA 微调
- Server LLM 利用共有数据集吸收 SLM 的知识,达到类似集中微调的效果。
- Server LLM 聚合后同样利用共有数据集向 SLM 传输知识,提高 SLM 质量。
其核心问题也是 tokenizer 不一致,使用了最小距离编辑的方法,进行 Token 对齐。
在学习的过程中实际上有两段过程:
- 词表对齐(通过
MinED最小编辑距离衡量分词之间的距离) + 双向 Token 对齐 - 选择性双向知识传递
在双向 Token 对齐的同时,实际上对齐了两个东西:tokenizer 的分词方式和 Logits 的输出分布。也就是先通过 MinED 得出词表映射,然后根据这个映射进行 Token 的对齐。但实际上我觉得不能叫做双向对齐,因为逆过程是不完全一样的,只是方法上是完全一致的。
那么对于 Logits 分布的对齐,通过词元的映射,其实可以得到一个词元序列的对齐,对于同一个输出,可以切片为若干的 logits 分布:
$$\begin{bmatrix} x_{11} \\ x_{12} \\ \vdots \\ x_{1d} \end{bmatrix} \& \begin{bmatrix} y_{11} \\ y_{12} \\ \vdots \\ y_{1d} \end{bmatrix} \to \begin{bmatrix} x_{21} \\ x_{22} \\ \vdots \\ x_{2d} \end{bmatrix} \& \begin{bmatrix} y_{21} \\ y_{22} \\ \vdots \\ y_{2d} \end{bmatrix} \to \cdots $$
那么考虑到对于学习方的每一个词元对齐的时候存在三种情况:一对一,一对多。
- 对于一对一的情况,那么取出概率 Top-K 的位置,将其他位置置 0,然后放入对应位置。
- 对于一对多的情况(例如 Bloom 中
utilize对应 LLaMa 的util + ize),那么放弃同步 LLaMa 中的内容,转而使用一个 One-Hot 向量,只保留这个目前最可能的输出。
我总觉得论文里 Token 和 Vocabulary 的表述很混乱。这里就统一用词元表示吧。
形式化一点,这里有两种映射方式,词距映射和语义映射,其中词距映射是一个一对一的映射,而语义映射是个多映射,记为 $D(t)$ 和 $S(t)$。其中 $t$ 表示一个词元。记 $L_L(i)$ 表示教学方的第 $i$ 个 logits 池,$L_S(i)$ 表示学习方的第 $i$ 个 logits 池,$lt_i$ 表示教学方的第 $i$ 个输出词元,$st_i$ 表示学习方的第 $i$ 个输出词元。那么过程核心过程是这样的:
for i = 1 ~ len(L_S)
if |S(st_i)| > 1
L_S(i) = OneHot(st_i)
else if L_S(i) hasn't set
L_S(i) = {D(t): L_L(i)[t] for t in TopK(L_L(i))}
else
continue
虽然不是很符合规范,但是我看着最清晰
其实在这个过程中会有一个很大的问题:如果两个模型差距太大,即 Logits 映射后差距非常大,那么就可能破坏模型的能力。所以论文中还有一个 DualMinCE 的过程,用于限制 Loss 和数据,也就是说只有学生模型“认为”值得学习的才学习。
于是乎有了 选择性双向知识传递。具体来说:
- 对于共有数据集的每一个数据 $x^i$,计算每一个 Client 对这个数据的损失 $l_k^i$,以及本地模型对这个数据的损失 $l_0^i$
- 取损失最小的那个 $l_k^i$ 对应的 $k$,如果 $l_k^i < l_0^i$,那么将 $(x^i, p_k^i)$ 加入数据集中。其中 $p_k^i$ 表示第 $k$ 个 Client 在 $x^i$ 数据上的 Logits 分布。
- 反向学习的时候同理,如果 SLM 在 $x^i$ 上的 Loss 小于 LLM,那么就不学习 LLM 的 Logits。
值得注意的是,这个不学习只是不学习另一个模型的 Logits 分布,但是始终有在共有数据集的 Ground Truth 学习。