Encoder-decoderモデルとTeacher Forcing、それを拡張したScheduled Sampling、Professor Forcingについて簡単に書きました。
概要
Encoder-decoderモデルは、ソース系列をEncoderと呼ばれるLSTMを用いて固定長のベクトルに変換(Encode)し、Decoderと呼ばれる別のLSTMを用いてターゲット系列に近くなるように系列を生成するモデルです。もちろん、LSTMでなくてGRUでもいいです。機械翻訳のほか、文書要約や対話生成にも使われます。
Encoder-decoderモデルの概略図
Encoder
ソース系列X=(x1,x2,…,xS)が与えられたとき、ソースの辞書サイズVSに基づく埋め込み重みUS∈RES×VSにより、埋め込み表現へと変換され、各タイムスタンプ毎に順にセル(cell∈Rh)と隠れ状態(hidden∈Rh)が計算される。
cellt+1,hiddent+1=LSTM(USxt,cellt+1,hiddent+1)
ここで、xtはVS次元のone-hotベクトル、hは隠れ状態の次元、ESはソース側の埋め込み次元を表す。なお、セルと隠れ状態は一般に0で初期化されることが多い。
LSTMについて前回の記事を参照ください。
この段階でソース系列は固定長ベクトルであるセルと隠れ状態に変換(Encode)されました2。
Decoder
Encoder側で出力されたセルと隠れ状態からDecoderを用いて系列を生成していきます。
初期状態は”BOS”記号や”null”記号を開始記号として下記のような数式で一単語目w1を生成します。
cell1,hidden1=LSTM(UTwnull,cell0,hidden0)w1=argmax(WThidden1)
ここで、UTはターゲットの辞書サイズVTに基づく埋め込み重み、WTはWT∈RVT×hである。
この調子で
cell2,hidden2=LSTM(UTw1,cell1,hidden1)w2=argmax(WThidden2)
と、(w1,w2,…)を計算して、ターゲット系列Y=(y1,y2,…,yT)との交差エントロピーを使って学習すると思われます。
前の出力をそのまま次の入力として使って学習を行う例
しかし、この方法だと連鎖的に誤差が大きくなっていき、学習が不安定になったり、収束が遅かったりしてしまうという問題があります。
Teacher Forcing
この問題に対して、Teacher Forcingという方法を取ることが多いですが、この方法にも評価時に問題が生じます(後述)。
Teacher Forcingとは、訓練時には下図のようにDecoder側の入力にはターゲット系列Yをそのまま使うというものです。
Encoder-decoderモデルにおけるTeacher Forcingの概略図
こうすることによって学習が安定し、収束が早くなるというメリットがありますが、
逆に、評価時はDecoderの入力が自動生成されたものが使われるため、学習時と分布が異なってしまうというデメリットもあります。
Teacher Forcingの拡張
Scheduled Sampling
Teacher Forcingの拡張として、Scheduled Sampling [Samy Bengio+ 2015]がある。
Scheduled Sampling
Schedule Samplingは上図のように、ターゲットytを入力とするか、生成されたwtを入力とするか確率的にサンプルするというもの。
Professor Forcing
Professor Forcing[Lamb+ 2016]とは、Free Running(Decoder側で自動生成されたものを入力とする)とTeacher Forcingで出力されるセルと隠れ状態の差を小さくするようにGANを用いたと言うもの。
- Generator:Free RunningとTeacher Forcingの出力の区別がつかないようにする。
- Discreminator:Free RunningとTeacher Forcingを正しく分類できるようにする。
Professor Forcing