XGBoostをJavaのwrapperを使用して実行する

スポンサーリンク

Kaggle や KDD cup などの機械学習コンペで人気な分類器である XGBoost に関して、R と Python での使用方法はいくつか日本語の記事があるのですが Java は見つからなかったのでまとめました。現状では Java を使う理由が無いのであれば R を使った方が視覚的にも分かり易くオススメです。

XGBoostとは

XGBoost とは Gradient Boosting Decision Tree(以下GBDT)の高速な C++ 実装です。今までの GBM より10倍高速らしいです。Java 以外にも R と Python のラッパーもあります。

 xgboost/xgboost4j.md at master · dmlc/xgboost · GitHub

専門職でないので理論の説明はできませんが、調べた結果を簡単にまとめると以下の通りです。

複数の弱学習記(GBDTの場合は決定木)を1つずつ順番に構築して集団で学習させます。新しい学習器を構築する際にそれまでに構築された学習器を利用するため、計算を並列化することはできません。 前のステップで間違って識別されたものへのウェイトを重くして、その状態で次のステップで間違ったものをうまく識別できるように学習器を構築し...ということを指定した回数だけ繰り返します。各ステップごとに弱学習記を構築して損失関数を最小化します。弱学習器を集めているせいか、予測精度は高いのに過適合しにくいのが特徴であり、誤差に対して学習しなおしてくれるので良いモデルができあがります。

チューニング次第では今話題のディープラーニングにも迫る精度が出るらしいです。すごいですね。

XGBoostのビルド

XGBoost は Mavenリポジトリに登録されていないのでソースからビルドする必要があります。

$ git clone https://github.com/dmlc/xgboost.git
$ cd xgboost/java/
$ ./create_wrap.sh
build java wrapper
...
move native lib
complete

以下のエラーが出たらyum install gcc-c++でコンパイラをインストールしてください。

$ ./create_wrap.sh
build java wrapper
which: no g++-5 in (/usr/local/bin:/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/sbin:/home/vagrant/bin)
g++ -c -Wall -O3 -msse2  -Wno-unknown-pragmas -funroll-loops -fopenmp -fPIC -o updater.o src/tree/updater.cpp
which: no gcc-5 in (/usr/local/bin:/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/sbin:/home/vagrant/bin)
which: no g++-5 in (/usr/local/bin:/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/sbin:/home/vagrant/bin)
make: g++: コマンドが見つかりませんでした
make: *** [updater.o] エラー 127

ビルドできたら Maven で jar ファイルにパッケージングします。

$ cd xgboost4j/
$ mvn package
...
[INFO] Building jar: /home/vagrant/xgboost/java/xgboost4j/target/xgboost4j-1.1.jar
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 55.182 s
[INFO] Finished at: 2015-10-12T23:24:57+09:00
[INFO] Final Memory: 17M/42M
[INFO] ------------------------------------------------------------------------

データの用意

まず実行するために学習データの用意が必要なので定番のアヤメ(iris)を使用します。ただし Java では入力データが  LIBSVM 形式 でなければいけません。加工する必要がありますが、今回は  LIBSVM Data: Classification (Multi Class) に iris の LIBSVM 形式への加工済データがありますのでこれを使用します。

$ wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale

LIBSVM 形式は次のように label に目的変数を、value に1から始まるインデックスを付けて説明変数を指定します。

<label> <index1>:<value1> <index2>:<value2> ... <indexN>:<valueN>

アヤメのデータの場合、目的変数はsetosa,versicolor,virginicaですがこのままでは使用できませんので数値に変換する必要があります。XGBoost のクラスは 0 ベースなのでsetosa -> 0,versicolor -> 1,virginica -> 2という感じです。

しかし、ダウンロードしたiris.scaleを見るとラベルが 1 〜 3 になっていますので修正する必要があります。ついでに各ラベルの30件をiris.train.scaleとして学習データに、20件をiris.test.scaleとしてテストデータに使用できるように分けましょう。R とかなら簡単にできるのですが、Java なので今回は件数も少ないですし手動でやった方が早いです。

XGBoostの実行

以上で準備が出来たのでビルドして出来た jar ファイルtarget/xgboost4j-1.1.jarをクラスパスに追加して次のサンプルコードを実行してください。今回は 3 クラスの分類になるので、多クラス分類のmulti:softmaxとクラス数num_classに 3 を設定するだけの最低限のパラメータで実行します。

実行結果は以下のようになります。

10 18, 2015 12:27:03 午前 org.dmlc.xgboost4j.util.Trainer train
情報: [0] train-merror:0.044444   test-merror:0.050000
10 18, 2015 12:27:03 午前 org.dmlc.xgboost4j.util.Trainer train
情報: [1] train-merror:0.044444   test-merror:0.050000
...
10 18, 2015 1:17:12 午後 org.dmlc.xgboost4j.util.Trainer train
情報: [9] train-merror:0.044444   test-merror:0.033333

predict length1: 60
predict length2: 1
error num: 2.0
error of predicts: 0.033333335

予測の結果はどの学習方法を選択してもfloat[][]の2次元配列です。multi:softmaxではどのクラスに分類されたかだけが出力されるので2次元目は1要素しかありません。

multi:softprobの場合はデータがどのクラスに属するかの予測確率が出るので、クラスが 0〜2 の場合は 0 の確率がpredicts[n][0]に、2 の確率がpredicts[n][2]のように出力されます。

0.98633957 0.0047045643 0.008955895
0.9847445 0.005253896 0.010001636
0.98764324 0.0042555854 0.008101191
0.6970763 0.052834447 0.25008926
0.66417605 0.07275859 0.26306537
...

予測の対象が単体の場合

まとめて予測する場合を除いて、1件、または数件ずつのデータを予測したい場合は、以下のように floatの特徴量配列からDMatrixを生成します。ファイルから読み込む場合との違いとしてLIBSVM形式でなくていいです。また、ラベルも含める必要があるので注意しましょう。ラベルの値は何でもいいです。

// {0, 0.111111F, 0.0833333F, 0.694915F, 0.1F} は {ラベル, 説明変数1, ... 説明変数4} で1セットとなる。
float[] test = new float[] {0, 0.111111F, 0.0833333F, 0.694915F, 0.1F, 0, -0.166667F, -0.416667F, 0.38983F, 0.5F....};
int nrow = 3;
int ncol = 15;
DMatrix testMat = new DMatrix(test, nrow, ncol);

モデルの保存と読み込み

作成したモデルはバイナリファイルで保存することができます。保存したモデルは R や Python で使用することもでき、逆に R や Python で作成したモデルを Java で読み込むことも可能です。

// モデルの保存
Booster booster = Trainer.train(param.entrySet(), trainMat, round, watchs, null, null);
booster.saveModel(modelPath);

// モデルの読み込み
Booster booste = new Booster(param.entrySet(), modelPath);