提出关于模型异构型的解决方案,这是 FedMKT 的一个基础。
对于 No-IID 的数据,原有的解决方案有三:
- 贝叶斯方法:采用 变分推断(Variational Inference) 来建模模型参数的后验分布,通过概率建模捕捉数据分布的不确定性。通过共享先验分布,客户端之间可以共享统计强度,同时保留各自的后验分布以实现个性化。
- 元学习(学会学习)目标是学习一个良好的 模型初始化参数 或 优化算法。
- 迁移学习:通过迁移通用知识,弥补本地数据分布偏差带来的性能下降。FedMD 本身的核心就是利用公共数据进行迁移学习。
FedMD 提出,这三者没有办法解决模型异构的情况,所以提出了该框架。
明确一下 FedMD 的任务范围是在 MNIST 和 CIFAR 这种分类任务上,所有参与者的模型输出空间(类别标签)是一致的。这里的 logits 指的是模型在公共数据样本上的输出的原始分数,分数越高,概率越高。论文主要针对 CNN 而不是 LLM(Transformer)
算法流程:
- 对于边缘模型 $f_k$,在共有数据集 $D_0$ 上训练到收敛后,在私有数据集 $D_k$ 上训练。
- 对于每轮训练,先计算公用数据集上的分数,上传推理结果
- 服务端平均聚合边缘的知识,形成共识。
- Digest:边缘下载共识,在共有数据集上训练
- Revisit:在私有数据集上重新微调,回到 2
一些细节:
- 直接比较 Logits 和有温度的 Softmax 对于作者来说,期望没有什么影响
- 对于共有数据集,随机采样,而非全部使用
- 如果说中间出现了短暂的退化情况,可以适当的减少 Revisit 轮次,增加 Digest 的 batch size,避免偏差。
- 对于聚合的共识,可以采用加权平均(依据