Encoder-decoderモデルとTeacher Forcing、Scheduled Sampling、Professor Forcing

Encoder-decoderモデルとTeacher Forcing、それを拡張したScheduled Sampling、Professor Forcingについて簡単に書きました。

概要

Encoder-decoderモデルは、ソース系列をEncoderと呼ばれるLSTMを用いて固定長のベクトルに変換(Encode)し、Decoderと呼ばれる別のLSTMを用いてターゲット系列に近くなるように系列を生成するモデルです。もちろん、LSTMでなくてGRUでもいいです。機械翻訳のほか、文書要約や対話生成にも使われます。

encdec Encoder-decoderモデルの概略図

Encoder

ソース系列X=(x1,x2,,xS)X = (x_1, x_2, \dots, x_S)が与えられたとき、ソースの辞書サイズVSV_Sに基づく埋め込み重みUSRES×VSU_S \in \mathbb{R}^{E_S \times V_S}により、埋め込み表現へと変換され、各タイムスタンプ毎に順にセル(cellRh\text{cell} \in \mathbb{R}^{h})と隠れ状態(hiddenRh\text{hidden} \in \mathbb{R}^{h})が計算される。

cellt+1,hiddent+1=LSTM(USxt,cellt+1,hiddent+1)\text{cell}_{t+1}, \text{hidden}_{t+1} = \text{LSTM}(U_Sx_t, \text{cell}_{t+1}, \text{hidden}_{t+1})

ここで、xtx_tVSV_S次元のone-hotベクトル、hhは隠れ状態の次元、ESE_Sはソース側の埋め込み次元を表す。なお、セルと隠れ状態は一般に0で初期化されることが多い。 LSTMについて前回の記事を参照ください。

この段階でソース系列は固定長ベクトルであるセルと隠れ状態に変換(Encode)されました2。

Decoder

Encoder側で出力されたセルと隠れ状態からDecoderを用いて系列を生成していきます。 初期状態は”BOS”記号や”null”記号を開始記号として下記のような数式で一単語目w1w_1を生成します。

cell1,hidden1=LSTM(UTwnull,cell0,hidden0)w1=argmax(WThidden1)\text{cell}_{1}, \text{hidden}_{1} = \text{LSTM}(U_T w_{\text{null}}, \text{cell}_{0}, \text{hidden}_{0})\\ w_1 = \text{argmax}({W_{T} \text{hidden}_{1}})

ここで、UTU_Tはターゲットの辞書サイズVTV_Tに基づく埋め込み重み、WTW_TWTRVT×hW_T \in \mathbb{R}^{V_T \times h}である。

この調子で

cell2,hidden2=LSTM(UTw1,cell1,hidden1)w2=argmax(WThidden2)\text{cell}_{2}, \text{hidden}_{2} = \text{LSTM}(U_T w_1, \text{cell}_{1}, \text{hidden}_{1})\\ w_2 = \text{argmax}({W_{T} \text{hidden}_{2}})

と、(w1,w2,)(w_1, w_2, \dots)を計算して、ターゲット系列Y=(y1,y2,,yT)Y = (y_1, y_2, \dots, y_T)との交差エントロピーを使って学習すると思われます。

without teacher forcing 前の出力をそのまま次の入力として使って学習を行う例

しかし、この方法だと連鎖的に誤差が大きくなっていき、学習が不安定になったり、収束が遅かったりしてしまうという問題があります。

Teacher Forcing

この問題に対して、Teacher Forcingという方法を取ることが多いですが、この方法にも評価時に問題が生じます(後述)。 Teacher Forcingとは、訓練時には下図のようにDecoder側の入力にはターゲット系列YYをそのまま使うというものです。

teacher forcing Encoder-decoderモデルにおけるTeacher Forcingの概略図

こうすることによって学習が安定し、収束が早くなるというメリットがありますが、 逆に、評価時はDecoderの入力が自動生成されたものが使われるため、学習時と分布が異なってしまうというデメリットもあります。

Teacher Forcingの拡張

Scheduled Sampling

Teacher Forcingの拡張として、Scheduled Sampling [Samy Bengio+ 2015]がある。

scheduled sampling Scheduled Sampling

Schedule Samplingは上図のように、ターゲットyty_tを入力とするか、生成されたwtw_tを入力とするか確率的にサンプルするというもの。

Professor Forcing

Professor Forcing[Lamb+ 2016]とは、Free Running(Decoder側で自動生成されたものを入力とする)とTeacher Forcingで出力されるセルと隠れ状態の差を小さくするようにGANを用いたと言うもの。

  • Generator:Free RunningとTeacher Forcingの出力の区別がつかないようにする。
  • Discreminator:Free RunningとTeacher Forcingを正しく分類できるようにする。
professor forcing Professor Forcing

Written by@Minato Sato
Senior Software Engineer - Embedded AI

GitHubTwitterFacebookLinkedIn

© 2023 Minato Sato. All Rights Reserved