log.saiias

あてにならない備忘録

オンライン学習器AdamをC++で実装してみた

  • 元論文はこちらです.

  • 自分の実装はここに置いてあります.

1データに対する学習率の調整部分は部分は以下の通りです. (論文内の擬似コードをそのまま実装している感じです.

double pred = sigma(X,i);
for(int idx = 0; idx < d; idx++){
  if(X(i,idx) != 0){
    double tbeta = 1-(1-beta1) * pow(lambda,iter);
    double grad = (pred - label(i)) * X(i,idx)+ C * w(idx);
    m[idx] = tbeta * grad + (1-tbeta) * m[idx];
    v[idx] = beta2*pow(grad,2) + (1-beta2) * v[idx];
    double hat_m = m[idx]/(1-pow((1-beta1),iter+1));
    double hat_v = v[idx]/(1-pow((1-beta2),iter+1));
    w(idx) -= alpha*hat_m/(sqrt(hat_v) + epsilon);
  }
}

実装が間違っている可能性は高いですが,この実装ではa9aデータセットにおいてAdagradやAdadeltaなどと大きな差は出ませんでした.(損失関数やサンプルサイズが小さすぎるのかもしれません)

参考にしたサイト