AdaGrad + RDAを実装してみた
AdaGrad + RDA
「結局オンライン最適化はAdaGrad1択だよね」
「AdaGrad + RDAでの分類が精度一番良いみたいです」
「AdaGrad + RDAの方が実装がはるかに簡単」
と立て続けに聞いたのでAdaGrad + RDA*1を理解するために実装してみた。
結論から言うと、確かに実装は簡単だし精度もでました。
損失関数と正則化項
AdaGrad自身は最適化手法です。すると適用するには最適化問題の形で記述する必要があります。
分類問題における定式化では損失関数と正則化項を足したものになります。今回の損失関数は
ヒンジロスとなります。
は1 or -1のラベル、はベクトル空間上にマッピングしたデータです。
正則化項は-正則化、-正則化など色々有りますが今回は-正則化となります。
劣微分
で、最適化問題を解くわけです。今回はやの偏微分を用いるのですが、max関数や絶対値とか入ってて微分できないわけです。こういう時のための劣微分を用います。劣微分の説明は省略しますが、平たく言うと微分不可能な値での微分値としては微分の値としてとり得る範囲の値のどれかを使おうというものです。例えばヒンジロスの部分の劣微分は次のようになります。
ちょうどとなるところでは微分不可能ですがどちらの値を採用しても構わないので0とします。
省略できる処理
は0にしても問題ない*2。は学習データが多い時はあんま影響ないので1にしとけばいい*3などがあります。また、学習データの値が0の要素では値は変わらないので、計算を省略することができます。最後にが対角行列になる場合に限定します。
アルゴリズム
以上の項目と合わせつつ冒頭に書いたアルゴリズムをうにゃうにゃ変形させる*4とt番目のデータによる学習は次の様になります。
//ヒンジロスが0の場合は劣微分は0になるので学習にならない double loss = 1.0 - t.first * calcInnerProduct(t.second); if(loss <= 0.0){ return; } //update gradient numTrain++; for(SparseVector::const_iterator i=t.second.begin();i!=t.second.end();++i){ int feature = i->first; double value = i->second; SumOfGradients[feature] -= t.first * value; SumOfSquaredGradients[feature] += value * value; //sign(u_{t,i}) int sign = SumOfGradients[feature] > 0 ? 1 : -1; //|u_{t,i}|/t - \lambda double meansOfGradients = sign * SumOfGradients[feature] / (double)numTrain - lambda; if(meansOfGradients < 0.0){ // x_{t,i} = 0 SeparationPlain.erase(feature); }else{ // x_{t,i} = sign(- u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda) SeparationPlain[feature] = -1 * sign * (double)numTrain * meansOfGradients / sqrtf(SumOfSquaredGradients[feature]); } }
実装は https://github.com/jnishi/rda こちらにあげています。
実行結果
News20*5のデータのうち、15000件を正解データとし、残る4996件をテストデータとして分類をしてみたところ96.7174%であった。 https://code.google.com/p/oll/wiki/OllMainJaこちらによると様々なオンラインアルゴリズムの中で一番良いConfidence Weightedでも96.437%なので確かにAdagrad + RDAの方が性能が良い*6。
*1:原論文:「Adaptive Subgradient Methods for Online Learning and Stochastic Optimization.」 http://www.magicbroom.info/Papers/DuchiHaSi10.pdf
*2:原論文参照
*3:「Notes on AdaGrad」http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf
*4:原論文参照
*5: http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#news20.binary
*6:Soft Confidence Weightedだとどうなるのだろう