0%

adjoint state method

Adjoint State Method

Abstract

将深度神经网络视为连续动力系统的想法诞生了一类新的深度学习模型,它们拥有着一些独特的性质,在实践性能上也较为优越。需要指出,动力系统的解是以数值形式隐式表达的,通常的反向传播算法并不适用,ASM长久以来被用于此类模型的优化中。

Introduction

考虑时间序列向量\(\mathbf{y}_i\in\mathbb{R}^n,i=1,\ldots,N\),这里\(n\)是观测空间的维度,\(N\)是观测数,与落在区间\(I :=[0,T]\)\(t_i\)对应。在研究中经常使用如下初值问题来模型化这些数据: \[ \begin{aligned} d_{t} \boldsymbol{u} &=\boldsymbol{f}(t, \boldsymbol{u}, \boldsymbol{\phi}), \quad t \in[0, T] \\ \boldsymbol{u}(0) &=\boldsymbol{u}_{0}(\boldsymbol{\phi}) \end{aligned}\tag{1.1} \] 这里\(\boldsymbol{\phi}\)为模型参数,\(\boldsymbol{u}\in\mathbb{R}^m\),预测函数是对式子\((1.1)\)的积分结果与一个之后的后处理,比如样本降维。可以表示为\(\hat{y}=\mathscr{P}(\boldsymbol{u}(t, \boldsymbol{\phi}))=: \boldsymbol{g}(t, \boldsymbol{\phi}),\mathscr{P} : \mathbb{R}^{m} \rightarrow \mathbb{R}^{n}\)\(\mathscr{P}\)即为对解值的后处理算子。

伴随状态方法的目的,就是为了对上面的模型进行优化,即有效的计算下面式子的一二阶梯度(关于参量): \[ l(\boldsymbol{\phi})=\pm \sum_{i} d\left(\boldsymbol{y}_{i}, \boldsymbol{g}\left(t_{i}, \boldsymbol{\phi}\right)\right)\tag{1.2} \] 这里\(d : \mathbb{R}^{n} \times \mathbb{R}^{n} \rightarrow[0, \infty)\)是一个充分光滑的度量函数,式子\((1.2)\)给出了函数族的泛函。

这里必须指出,我们通常无法确定\(\boldsymbol{u}\)的解析形式,而是完全以数值的方式计算它在每个时刻的值,只有\(\boldsymbol{f}(\cdot)\)的解析形式是已知的,直接用链式法则解析地计算\(d_\phi l\)并不可行(\(J_\phi(\boldsymbol{u})\)无法计算),这正是这个问题的难点所在。

Adjoint-state method

ASM被用于不同的领域中,有着较长的历史,但很难准确地追溯它的起源,因为它只是基于另一个一般的法则-对偶。它的主要思想是得到式子\((1.1)\)的对偶,这使得我们能够以简单的形式写出式子\((1.2)\)的梯度,并且能够简单地计算。通常来说,在合适的hilbert空间中得到的内积都包含着对偶状态。

模型敏感性

定义\(f\)的一阶方向导数为\(\mathscr{D} f(\boldsymbol{x} ; \boldsymbol{h})\),令\(s :=\mathscr{D} \boldsymbol{u}(\boldsymbol{\phi} ; \boldsymbol{h})\),若它存在,则通过式子\((1.1)\)可得\(\boldsymbol{s}\)满足如下的初值问题: \[ \begin{aligned} d_{t} \boldsymbol{s} &=J_{u}(\boldsymbol{f}) \boldsymbol{s}+J_{\boldsymbol{\phi}}(\boldsymbol{f}) \boldsymbol{h}, \quad t \in[0, T] \\ \boldsymbol{s}(0) &=J_{\boldsymbol{\phi}}\left(\boldsymbol{u}_{0}\right) \boldsymbol{h} \end{aligned}\tag{2.1} \] 这被称为敏感性等式,这里\(J_{\boldsymbol{\phi}}(\boldsymbol{f}) : \mathbb{R}^{p} \rightarrow \mathbb{R}^{m}\)\(J_{\boldsymbol{u}}(\boldsymbol{f}) : \mathbb{R}^{m} \rightarrow \mathbb{R}^{m}\)\(\boldsymbol{f}\)的雅可比阵。令\(\boldsymbol{e}_i,i=1,\ldots,p\)为空间\(\mathbb{R}^p\)的规范基,对\(\boldsymbol{h}=\boldsymbol{e}_i,i=1,\ldots,p\)进行积分得到\(\boldsymbol{s}=\left(\frac{\partial \boldsymbol{u}_{1}}{\partial \boldsymbol{\phi}_{i}}, \frac{\partial \boldsymbol{u}_{2}}{\partial \boldsymbol{\phi}_{i}}, \ldots, \frac{\partial \boldsymbol{u}_{m}}{\partial \boldsymbol{\phi}_{i}}\right)^{*}\). 这意味着计算整个雅可比阵\(J_{\phi}(\boldsymbol{u})\)需要对初值问题\((2.1)\)进行\(p\)次积分,这样做的复杂性与一阶有限差分相同。

定理1.假设映射\(\boldsymbol{f} : U \subseteq \mathbb{R} \times \mathbb{R}^{m} \times \mathbb{R}^{p} \rightarrow \mathbb{R}^{m}\)\(u_{0} : V \subseteq \mathbb{R}^{p} \rightarrow \mathbb{R}^{m}\)都是\(C^{k}\)的,\(k \geq 1\)并且\(U,V\)是分别包含\(\left(0, \boldsymbol{u}_{0}\left(\boldsymbol{\phi}_{0}\right), \boldsymbol{\phi}_{0}\right)\)\(\boldsymbol{\phi}_{0}\)的开集,那么:

  • (a)存在一个区间\((-a,a)\)\(a>0\),以及一个开邻域\(U\left(\boldsymbol{\phi}_{0}\right)\)使得初值问题\((1.1)\)对每个\(\phi \in U\left(\boldsymbol{\phi}_{0}\right)\)都正好有一个解。
  • (b)映射\((t, \boldsymbol{\phi}) \mapsto \boldsymbol{u}(t ; \boldsymbol{\phi})\)\((-a, a) \times U\left(\boldsymbol{\phi}_{0}\right)\)上是\(C^{k}\)的,并且\((2.1)\)也满足。

连续化处理

这里将离散数据与连续模型通过似然泛函的方式进行联系,我们可以作如下表示: \[ \sum_{i} d\left(\boldsymbol{y}_{i}, \boldsymbol{g}\left(t_{i}, \boldsymbol{\phi}\right)\right)=\int_{0}^{T} \delta\left\{t-t_{i}\right\} d(\boldsymbol{y}(t), \boldsymbol{g}(t, \boldsymbol{\phi})) \mathrm{d} t\tag{2.2} \] 这里\(\delta\left\{t-t_{i}\right\}\)为所有测量时间上的狄拉克函数。同时为了使上面的积分定义良好,我们定义\(\boldsymbol{y}(t) :=\boldsymbol{y}_{i}, t \in\left(t_{i}-\varepsilon, t_{i}+\varepsilon\right)\),这里\(\varepsilon\)为一个小的正常量。使用线性开拓,我们使\(\boldsymbol{y}(t)\)在区间\([0,T]\)上连续。

为了式子\((2.2)\)的良好定义,\(\boldsymbol{g}(t, \boldsymbol{\phi})\)应该至少在每个\(t_i\)附近是连续的。这里假定\(\mathscr{P} \in \mathscr{L}\left(\mathbb{R}^{m}, \mathbb{R}^{n}\right)\),这个条件足够充分使下式成立: \[ \mathscr{P}(\boldsymbol{u}(t, \boldsymbol{\phi}))-\mathscr{P}\left(\boldsymbol{u}_{0}(\boldsymbol{\phi})\right)=\int_{0}^{t} \mathscr{P}(\boldsymbol{f}(t, \boldsymbol{\phi}, \boldsymbol{u})) \mathrm{d} t\tag{2.3} \] 由定理1以及假设,可以保证\(\boldsymbol{g}(t,\boldsymbol{\phi})\)是连续的。

利用希尔伯特空间\(L^{2}([0, T])\)上的内积,式子\((2.2)\)可以表达为: \[ \int_{0}^{T} \delta\left\{t-t_{i}\right\} d(\boldsymbol{y}(t), \boldsymbol{g}(t, \boldsymbol{\phi})) \mathrm{d} t=\left\langle\delta\left\{t-t_{i}\right\}, d(\boldsymbol{y}(t), \boldsymbol{g}(t, \boldsymbol{\phi}))\right\rangle\tag{2.4} \] 为简便起见,我们将\(\delta\{t-t_i\}\)记为\(\{\delta\}\),使用上面的记号我们可以证明如下的引理

引理1.令定理1的假设在\(k=1\)的情况满足,度量函数\(d\)\(C^1\)的。则泛函\((1.2)\)为frechet可微的,并且导数\(\mathscr{D} l(\boldsymbol{\phi} ; \boldsymbol{h})\)可以表示为: \[ \mathscr{D} l(\boldsymbol{\phi} ; \boldsymbol{h})=\left\langle\{\delta\} d_{\boldsymbol{u}}(\boldsymbol{y}, \boldsymbol{g}(t, \boldsymbol{\phi})), \boldsymbol{s}\right\rangle\tag{2.3} \] 这里\(\boldsymbol{s}\)为敏感性方程\((2.1)\)的唯一解。用链式法则即可证明,而且由于式子\((2.1)\)的线性性质,导数\(\mathscr{D l}(\boldsymbol{\phi} ; \boldsymbol{h})\)可以被简单地写作线性化的形式\(\mathscr{D} l(\boldsymbol{\phi} ; \boldsymbol{h})=l^{\prime}(\boldsymbol{\phi}) \boldsymbol{h}\),并且我们在有限维的情况考虑问题,下面的式子: \[ \nabla l(\boldsymbol{\phi}) \cdot \boldsymbol{h} :=l^{\prime}(\boldsymbol{\phi}) \boldsymbol{h} \quad \text { for all } \boldsymbol{h} \in \mathbb{R}^{p}\tag{2.4} \] 良好地定义了\(l\)的梯度\(\nabla l(\boldsymbol{\phi})\),作为\(\boldsymbol{\phi}\)的映射,即\(\nabla l : \mathbb{R}^{p} \rightarrow \mathbb{R}^{p}\).

ASM

直接来说,方法的核心是如下的定理:

定理2.令引理1的假设被满足,那么式子\((2.3)\)中的一阶frechet导数也能够被写作: \[ \mathscr{D} l(\boldsymbol{\phi} ; \boldsymbol{h})=-\boldsymbol{v}^{*}(0) J_{\boldsymbol{\phi}}\left(\boldsymbol{u}_{0}\right) \boldsymbol{h}-\left(J_{\boldsymbol{\phi}}(\boldsymbol{f}) \boldsymbol{h}, \boldsymbol{v}\right)\tag{2.5} \]

这里\(\boldsymbol{v}\)是下面初值问题的唯一解: \[ \begin{aligned} d_{t} \boldsymbol{v} &=-J_{u}^{*}(\boldsymbol{f}) \boldsymbol{v}+\{\delta\} d_{\boldsymbol{u}}(\boldsymbol{y}, \boldsymbol{g}(t, \boldsymbol{\phi})), \quad t \in[0, T] \\ \boldsymbol{v}(T) &=0 \end{aligned}\tag{2.6} \] 式子\((2.5)\)可以写为线性形式,即: \[ \nabla l=-\boldsymbol{v}^{*}(0) J_{\boldsymbol{\phi}}\left(\boldsymbol{u}_{0}\right)-\left(\boldsymbol{v}, J_{\boldsymbol{\phi}}(\boldsymbol{f})\right)\tag{2.7} \] 我们只需要对伴随问题\((2.6)\)进行积分,然后估计式子\((2.7)\)即可,式子中的每一项都可以直接计算。

证明

这里只说定理2的证明:

首先假设已经得到了定理2中式子\((2.6)\)对于某个测量时刻\(t_i\)的解\(\boldsymbol{v}\),从结果上说,这里会将其延拓至\([t_i,t_{i-1})\).

\(\boldsymbol{v}_{i}^{+}\)\(t_i^+\)时刻方程的解,在这个时刻中断积分,将\(d_{\boldsymbol{u}}\left(\boldsymbol{y}_{i}, \boldsymbol{g}\left(t_{i}, \boldsymbol{\phi}\right)\right)\)加入\(\boldsymbol{v}_{i}^{+}\)中然后求解 \[ \begin{aligned} d_{t} \boldsymbol{v} &=-J_{\boldsymbol{u}}^{*}(\boldsymbol{f}) \boldsymbol{v}, \quad t \in\left(t_{i}, t_{i-1}\right) \\ \boldsymbol{v}\left(t_{i}\right) &=\boldsymbol{v}_{i}^{+}+d_{\boldsymbol{u}}\left(\boldsymbol{y}_{i}, \boldsymbol{g}\left(t_{i}, \boldsymbol{\phi}\right)\right) \end{aligned}\tag{3.1} \] 这是一个简单的线性ODE,系数\(J_{u}^{*}(f)\)为连续的,因为\(f \in \boldsymbol{C}^{1}\). 经典理论告诉我们这会得到一个\((t_i,t_{i-1})\)上的全局解。上面即说明了存在以及唯一性,下面对式子\((2.5)\)进行证明。

假定在起始与终止时刻都没有进行测量,有下面分部积分公式: \[ \int_{0}^{T} d_{t} \boldsymbol{v} \boldsymbol{w} d t=[\boldsymbol{v} \boldsymbol{w}]_{0}^{T}-\int_{0}^{T} \boldsymbol{v} d_{t} w d t\tag{3.2} \] 对于\(\boldsymbol{w}\in\boldsymbol{C}^1\)成立,尽管\(d_t\boldsymbol{v}\)仅仅是几乎处处存在。又因为\(\boldsymbol{s} \in \boldsymbol{C}^{1}([0, T])\),所以有: \[ \begin{aligned}\left\langle s,\{\delta\} d_{u}(\boldsymbol{y}, \boldsymbol{g}(t, \boldsymbol{\phi}))\right\rangle &=\left\langle\boldsymbol{s}, d_{t} \boldsymbol{v}+J_{\boldsymbol{u}}^{*}(\boldsymbol{f}) \boldsymbol{v}\right\rangle\quad (2.6) \\ &=-v^{*}(0) J_{\boldsymbol{\phi}}\left(\boldsymbol{u}_{0}\right) \boldsymbol{h}-\left(d_{t} \boldsymbol{s}-J_{\boldsymbol{u}}(\boldsymbol{f}) \boldsymbol{s}, \boldsymbol{v}\right)\quad(3.2)\\ &=-v^{*}(0) J_{\boldsymbol{\phi}}\left(\boldsymbol{u}_{0}\right) \boldsymbol{h}-\left(J_{\boldsymbol{\phi}}(\boldsymbol{f}) \boldsymbol{h}, \boldsymbol{v}\right).\quad(2.1) \end{aligned}\tag{3.3} \] 即证。