Goでオンライン線形分類器
とりあえず作ってみた. 今回は以下のアルゴリズムを実装した.
- Perceptron
- Passive Aggressive
- AROW
それぞれの細かい説明は省く. AROWのcovarianceの計算は対角成分のみを利用する実装となっている.
ソースコードはこちら saiias/goml · GitHub
ベクトル計算
ベクトルは配列ではなく今後スパースなベクトルも扱えるようにmap[int]float64とした. (各要素は倍精度でななく単精度でも十分だったかもしれない…)
オンライン学習でよく用いられる内積計算と定数倍とベクトルとベクトルの足し算を実装した.
type Array map[int]float64 func (a Array) Dot(b Array) float64 { s := 0.0 for i, elem := range a { if value, ok := b[i]; ok { s += elem * value } } return s } func (a *Array) ConsFactor(f float64) *Array { mutex.Lock() defer mutex.Unlock() vec := make(Array, 0) for i, elem := range *a { vec[i] = f * elem } return &vec } func (a *Array) Add(b *Array) { mutex.Lock() defer mutex.Unlock() for i, elem := range *b { if _, ok := (*a)[i]; ok { (*a)[i] += elem } else { (*a)[i] = elem } } }
実装は更新部分のみ記載します.
Perceptron
func (p *Perceptron) Update(label float64, vec array.Array) { pred := p.W.Dot(vec) if pred*label <= 0 { grad := vec.ConsFactor(p.Eta * label * pred) p.W.Add(grad) } }
Passive Aggressive
func (p *PA) pa2(l float64, vec *array.Array) float64 { return l / (math.Pow(vec.Norm(), 2) + (1.0 / 2 * p.C)) } func (p *PA) pa1(l float64, vec *array.Array) float64 { return math.Min(p.C, l/math.Pow(vec.Norm(), 2)) } func (p *PA) Update(label float64, vec array.Array) { pred := p.W.Dot(vec) if pred*label <= 0 { l := math.Max(0, 1-label*p.W.Dot(vec)) tau := 0.0 if p.Loss == "hinge" { tau = p.pa1(l, &vec) } else if p.Loss == "squere_hinge" { tau = p.pa2(l, &vec) } p.W.Add(vec.ConsFactor(tau * label)) } }
AROW
func (a *Arow) Update(label float64, vec array.Array) { margin := a.W.Dot(vec) if margin*label >= 1.0 { return } confidence := 0.0 for i, value := range vec { if _, ok := a.Confidence[i]; ok { confidence += a.Confidence[i] * value * value } else { confidence += value * value } } beta := 1.0 / (confidence + a.R) alpha := (1.0 - label*margin) * beta for i, value := range vec { if _, ok := a.Confidence[i]; !ok { a.Confidence[i] = 1.0 } a.W[i] += alpha * a.Confidence[i] * value * label a.Confidence[i] = 1.0 / ((1.0 / a.Confidence[i]) + value*value/a.R) } }
精度確認
実装間違ってるかもしれないがとりあえずLIVSVMのサンプルデータセットのa9aで精度確認をしてみた.
サンプルコードはこちら
デフォルト引数がないのでデフォルトのパラメータみたいのを定義できないのがすこし残念
データを1イテレーションのみで学習&Accuracyで精度比較を行った. ハイパーパラメータの調整は適当.
Perceptron Acc:0.7637737239727289 PA-1 Acc:0.7839813279282599 AROW Acc:0.8499477919046742
とりあえず大間違いはしてなさそう.