ゼロからAI理論を再構築する

ー 文系エンジニアがAIの内部構造をゼロから理解する記録 ー

LSTM:忘れる・覚える・取り出すを制御する

前回、RNNは長い系列で勾配消失が起きるため、長期的な依存関係を学ぶのが苦手だという話を書きました。LSTM(Long Short-Term Memory)は、この問題に対処するためにゲート機構を導入したRNNの拡張です。

セル状態という別経路

LSTMが通常のRNNと違うのは、隠れ状態とは別に「セル状態(Cell State)」という経路を持っている点です。

通常のRNNだと、情報は毎ステップ行列演算と活性化関数を通るので、ステップを重ねるうちに劣化していきます。LSTMのセル状態は、この複雑な計算を通さずに情報を次のステップへ流せる経路です。不要な変換を経由しないぶん、勾配が消えにくい。

3つのゲートで情報を選別する

セル状態にどの情報を残し、何を捨てるかを制御するのがゲートです。シグモイド関数の出力(0から1)を使って、情報の通過量を調整します。

忘却ゲートは、過去のセル状態のどの部分を捨てるかを決めます。

\[ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \]

入力ゲートは、新しい情報のどの部分をセル状態に書き込むかを決めます。

\[ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \]

出力ゲートは、更新されたセル状態のどの部分を隠れ状態として外に出すかを決めます。

\[ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \]

それぞれのゲートは学習されるパラメータを持っていて、どのタイミングで何を忘れ、何を覚え、何を出力するかをデータから学びます。名前だけ見ると複雑ですが、やっていることは「0から1の値で情報を絞る」という同じ操作の繰り返しです。

なぜ長期記憶が可能になるのか

セル状態の更新式を見るとわかります。

\[ C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \]

前のセル状態に忘却ゲートをかけて(一部を忘れて)、新しい情報を入力ゲートで絞って足す。ポイントは「足し算」で更新されている点です。

RNNの勾配消失は、行列の掛け算が連鎖することで起きていました。LSTMではセル状態が足し算で更新されるので、勾配が過去へ遡るときに減衰しにくい。ResNet(残差接続)と似た発想で、足し算による勾配のバイパスが長期記憶を可能にしています。

まとめ

LSTMはセル状態という別経路と3つのゲートで、情報の忘却・記憶・出力を制御します。セル状態が足し算で更新されるため勾配が消えにくく、RNNでは難しかった長期的な依存関係を学習できます。構造は複雑に見えますが、勾配消失への対策としては「足し算でバイパスを作った」というシンプルな一点に帰着します。

ゲートによる「どの情報に注目するか」の制御は、後のAttentionメカニズムに通じる考え方です。次回はEncoder-Decoderモデルについて書きます。


参考文献