ゼロからAI理論を再構築する

ー 文系エンジニアがAIの内部構造をゼロから理解する記録 ー

ミニバッチ学習とSGD:全部見なくても学習できる理由

前回、勾配降下法で損失を減らしていくという話を書きました。ただ、あの説明は「全データの勾配をまとめて計算する」前提になっていて、データが数百万件ある現実のタスクだとそのまま使うのは厳しいです。1回パラメータを更新するだけで全データを走査するので、とにかく遅い。

今回は、この計算コストの問題をどう回避するかという話です。

バッチ勾配降下法の問題

理論通りの勾配降下法(バッチ勾配降下法)では、全データ \( n \) 個に対する損失の平均勾配を計算します。

\[ \nabla \hat{R}(w) = \frac{1}{n} \sum_{i=1}^{n} \nabla L(f(x_i), y_i) \]

全データを使うので勾配の推定は正確です。ただ、 \( n \) が大きいと1ステップの計算に時間がかかりすぎて、実用的ではありません。

SGDは1点だけ見て更新する

確率的勾配降下法(SGD)は割り切りがすごくて、全データの中から1つだけランダムに選んで、その1点の勾配でパラメータを更新します。

\[ w_{t+1} = w_t - \eta \nabla L(f(x_i), y_i) \]

1点しか見ていないので、更新の方向はかなりブレます。全体の谷底とは違う方向に進んでしまうことも普通にある。ただ、繰り返せば平均的には正しい方向に向かうことが理論的に保証されています。

このブレ(ノイズ)には副産物もあって、損失の地形に浅い窪み(局所解)がある場合、正確な勾配だとそこにハマって出られなくなることがあります。SGDのノイジーな更新は、浅い窪みを飛び越える効果があるとされています。欠点に見えるものが利点にもなるというのは、自分がこの分野を勉強していて面白いと感じるところです。

ミニバッチが実質的な標準

全データだと遅い、1点だと不安定。その間を取ったのがミニバッチ学習です。

データを \( m \) 個(32や64、256あたりが多い)のグループに分けて、グループ内の平均勾配で更新します。

\[ \nabla \hat{R}_{batch}(w) = \frac{1}{m} \sum_{j=1}^{m} \nabla L(f(x_j), y_j) \]

SGDより勾配の推定が安定するのに加えて、GPUとの相性がいいのも大きいです。GPUは同じ計算を大量に並列で回すのが得意なので、ミニバッチ単位で処理するとかなり速くなります。PyTorchやTensorFlowも基本的にミニバッチ前提で設計されていて、データローダーにバッチサイズを指定するだけで勝手にこの仕組みが動きます。

まとめ

全データを使わず、断片的なサンプルで勾配を推定して更新を繰り返す。精度は落ちるけど速いし、ノイズが局所解の回避に役立つこともある。ミニバッチはそのちょうどいい落としどころです。

次回は、ここまで暗黙の前提にしてきた「線形モデル」について、何ができて何ができないのかを整理します。


参考文献