Generalizing to Unseen Domains: A Survey on Domain Generalization

1. Abstract

机器学习系统通常假设训练和测试分布是相同的。为此,一个关键的需求是开发可以泛化到不可见分布 (unseen distribution) 上的模型。域泛化 (Domain Generalization, DG),即O.O.D Generalization,引起了越来越多的兴趣。 Domain generalization deals with a challenging setting where one or several different but related domain(s) are given, and the goal is to learn a model that can generalize to an unseen test domain. 本文首次对该领域的最新进展进行了回顾。首先,我们提供了域泛化的正式定义,并讨论了几个相关领域。然后,我们彻底回顾了与领域泛化相关的理论,并仔细分析了泛化背后的理论。我们将最近的算法分为三类:数据操作 (data manipulation)、表征学习 (representation learning) 和学习策略 (learning strategy),并针对每一类详细介绍了几种流行的算法。第三,我们介绍了常用的数据集和应用。最后,我们总结了现有的文献并提出了一些未来的潜在研究课题。

2. Introduction

  • 为什么要做Domain Generalization?
    目前的ml models是假设训练和测试i.i.d., 但是现实中经常会遇到训练和测试有distribution gap的情况,在这种情况下ml models的表现会恶化 (deteriorate),但如果想把所有域 (domain)的数据都收集起来训练的话成本过高,甚至是很难做到,因而无论对于学术界还是工业界都要增强模型的泛化能力。

  • 域泛化的目标
    从一个或几个不同但相关的领域中学习一个模型,让它在未见的领域中也能泛化得很好。

the goal of domain generalization is to learn a model from one or several different but related domains (i.e., diverse training datasets) that will generalize well on unseen testing domains.

3. Background

3.1 Definitions

Definition 1 Domain
X\mathcal{X} 代表非空的输入空间 (nonempty input space),Y\mathcal{Y}代表输出空间,一个域是由从一个分布中抽样的数据组成的,我们记作S={(xi,yi)}i=1nPXY\mathcal{S}=\{(\mathrm{x}_i,y_i)\}_{i=1}^n\sim P_{\mathrm{XY}},其中xXRd\mathrm{x}\in\mathcal{X}\subset \mathbb{R}^d代表输入样本,yYRy\in\mathcal{Y}\subset\mathbb{R}代表标签,PXYP_{\mathrm{XY}}表示输入样本和输出标签的联合分布,X,Y\mathrm{X,Y}代表对应的随机变量。

Definition 2 Domain Generalization
如下图所示,在域泛化中我们有MM个训练域 (training domain, 或源域, source domain) Strain={Si  i=1,2,3,...,M}\mathcal{S}_{train}=\{\mathcal{S}^i\ |\ i=1,2,3,...,M \},其中Si={(xji,yji)}i=1n\mathcal{S}^i=\{(\mathrm{x}_j^i,y_j^i)\}_{i=1}^n代表第ii个域。域和域之间的联合分布是不同的:PXYiPXYj,1ijMP_{\mathrm{XY}}^i\neq P_{\mathrm{XY}}^j, 1\leq i\neq j\leq M。域泛化的目标是从MM个源域中学到一个鲁棒性好,可泛化的预测函数h:XYh: \mathcal{X}\rightarrow\mathcal{Y},从而在未见的目标域Stest\mathcal{S}_{test} (即,Stest\mathcal{S}_{test}在训练中不可达,并且PXYtestPXYiP_{\mathrm{XY}}^{test}\neq P_{\mathrm{XY}}^i,对于i{1,2,...,M}i\in\{1,2,...,M\})中最小化预测误差:

minhE(x,y)Stest[(h(x),y)]\min _h \mathbb{E}_{(\mathrm{x},y)\in\mathcal{S}_{test}}\left[\ell(h(\mathrm{x}),y)\right]

其中E是期望,\mathbb{E}是期望,(,)\ell(\cdot, \cdot)是损失函数。

下表是常用的一些标记符号。

本文对比了一些与域泛化有关的研究领域,包括:迁移学习 (transfer learning)、域适应 (domain adaptation)、多任务学习 (multi-task learning)、多领域学习 (multiple domain learning)、元学习 (meta-learning)、终生学习 (lifelong learning)、零样本学习 (zero-shot learning)。

  • Multi-task learning 在几个有关的任务上共同优化模型,任务间共享表示,可以让模型在原任务上泛化的更好,不是让模型泛化到新的 (unseen)任务中Multiple domain learning 是一种multi-task learning,在多个相关域上进行训练,为每个源域学习好的模型,不是新的目标域
  • Transfer learning 在一个源任务上训练模型,希望增强模型在不同但相关的目标域/任务中的表现。pretraining-finetuning是迁移学习的常用策略,其中源域和目标域具有不同的任务,并且在训练中访问目标域。在域泛化中,目标域是不可访问的,并且训练和测试集通常是相同的,只是他们拥有不同的分布。
  • Domain adaptation 希望根据已有的几个源域最大化模型在指定目标域上的表现。Domain adaptation 和 Domain generalization最大的区别是DA可以访问目标域数据,但是DG在训练时无法看到它们。
  • Meta-learning 希望算法可以通过以前的经验或任务自我学习,又叫 learn-to-learn。相比于域泛化 (学习任务相同),元学习的学习任务是不同的。 元学习也是一个DG常用的策略,我们可以通过在源域中模拟 meta-train 和 meta-test 任务来增进DG的表现。(meta-train, meta-test tbd)
  • Lifelong learning 或 continuous learning,注重模型在多个连续领域/任务之间的学习能力。它要求模型随着时间的推移不断学习新知识,同时保留以前学到的经验。这也与 DG 不同,因为它可以在每个时间步访问目标域,并且它没有显式地处理跨域的不同分布。
  • Zero-shot learning 旨在从可见类中学习模型,对训练中看不到类别的样本进行分类。域泛化研究的是训练数据和测试数据来自同一类但分布不同的问题。

此外,域泛化还与分布鲁棒优化(distributionally robust optimization, DRO)有关,DRO的目的是学习最坏分布情况 (worst-case distribution scenario) 下的模型,希望它能很好地推广到测试数据。DRO专注于优化过程,也可以用于域泛化的研究。此外,DG还可以通过数据操作 (data manipulation) 或表示学习 (representation learning) 方法来实现,这与DRO方法不同。

4. Theory

见论文部分。


5. Methodology

本文将现有的域泛化分为三组,分别是:

  • 数据操作 Data Manipulation: 生成样本以增强模型泛化能力的最便宜、最简单的方法之一。主要目标是使用不同的数据操作方法增加现有训练数据的多样性。同时,数据量也随之增加。学习目标是:

minhEx,y[(h(x),y)]+Ex,y[(h(x),y)]\min_h\mathbb{E}_{\mathbf{x},y}[\ell(h(\mathbf{x}),y)]+\mathbb E_{\mathbf{x}',y}[\ell(h(\mathbf{x}'),y)]

其中x=mani(x)\mathbf{x}'=\text{mani}(\mathbf{x}),代表着Data Manipulation操作。
- 数据增强 Data Augmentation: augmentation, randomization, and transformation of input data 典型的增强操作包括翻转、旋转、缩放、裁剪、添加噪声等。

- 数据生成 Data Generation: generates diverse samples to help generalization
  • 表示学习 Representation Learning:

    • 域不变表示学习 Domain-invariant Representation Learning: performs kernel, adversarial training, explicitly feature alignment between domains, or invariant risk minimization to learn domain-invariant representations

    • 特征解耦 Feature Disentanglement: disentangle the features into domain-shared or domain specific parts for better generalization

  • 学习策略 Learning Strategy:

    • 集成学习 Ensemble Learning: relies on the power of ensemble to learn a unified and generalized predictive function

    • 元学习 Meta-learning: based on the learning-to-learn mechanism to learn general knowledge by constructing meta-learning tasks to simulate domain shift

    • 梯度操作 Gradient Operation: learn generalized representations by directly operating on gradients.

    • 其他学习策略 Other Learning Strategy

tbd…