HadoopTimes

ストリーミングアーキテクチャ Apache Kafka とMapR Streams による新しい設計手法
技術情報

Apache Spark機械学習ランダムフォレストを用いた信用貸付のリスク予測

このブログ投稿では、銀行融資の信用リスク分類のためにApache Sparkのspark.mlランダムフォレストの利用を開始するお手伝いをいたします。Sparkのspark.mlライブラリの目的は、データフレーム上でユーザー自身が機械学習ワークフローやパイプラインの作成や調整に役立つAPIのセットを提供することです。データフレーム上でspark.mlを利用すると、インテリジェントな最適化によりパフォーマンスが向上します。

クラシフィケーション(分類)

クラシフィケーション(分類)は、機械学習アルゴリズムを監視する機能の一つです。ラベルをつけられた既知の項目 (例えば、取引が詐欺か詐欺でないかがわかっている) に基づいて、どのカテゴリに属するか (例えば、取引が詐欺か詐欺でないか) を特定します。分類は、既知のラベルとあらかじめ定められた特徴のデータセットを受け取り、その情報に基づいてどのように新たなレコードをラベルするかを学習します。特徴はあなたが問うことになるif構文が使用され、ラベルはその質問への回答です。次の例では、もしもそれが、歩き、泳ぎ、アヒルのようにガーガー鳴くなら、ラベルは「アヒル」です。

クラシフィケーション(分類)

銀行融資の信用リスクの例を見てみましょう。

  • 何を予測するか?
    • 融資を返済するかどうか
    • ラベル: その人の信用力
  • 予測のために利用できるif構文またはプロパティは何か?
    • 申請者の人口統計的および社会経済的なプロファイル: 職業、年齢、婚姻区分、貯蓄額など。
    • これらは特徴であり、類別詞モデルの確立のために、分類にもっとも寄与する特徴を抽出します。

デシジョンツリー

デシジョンツリーでは、入力されたいくつかの特徴に基づいてクラスやラベルを予測するモデルを作ります。デシジョンツリーの仕組みは、すべてのノードで特徴を含む表現を評価し、その回答に基づいて次のノードへの分岐を選択します。信用リスクの予測のために考えられるデシジョンツリーを以下に示します。特徴の質問がノードで、「はい」か「いいえ」の回答が次のノード(質問)となります。

  •  Q1: 当座預金残高 > 200DM ですか?
    • いいえ
    •  Q2: 現在の職場での雇用 > 1年 ですか?
      • いいえ
      • 信用力なし

デシジョンツリー

ランダムフォレスト

アンサンブル学習アルゴリズムでは、より良いモデルを得るために機械学習アルゴリズムを複数組み合わせます。ランダムフォレストは分類と回帰のための最も一般的なアンサンブル学習法です。このアルゴリズムでは、トレーニングステージでの異なるデータサブセットに基づいて、複数のデシジョンツリーからなるモデルを確立します。すべてのデシジョンツリーの出力を組み合わせる方法で予測を行い、それによって変動を少なくし、予測の精度を向上させます。ランダムフォレスト分類では、それぞれのデシジョン・ツリーの予測は1つのクラスの投票として数えられます。ラベルはもっとも多くの票を獲得したクラスとして予測されます。

ランダムフォレスト

Spark機械学習シナリオによる信用リスクの分析

利用するサンプルデータはドイツのクレジットカードデータセットです。これは信用リスクが「良好」か「不良」か、という属性セットによる記述で分類しています。各銀行融資の申請については、以下の情報を利用できます。

ドイツのクレジットデータ

ドイツのクレジットカードのCSVファイルフォーマットは以下のとおりです。

1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1

このシナリオでは、以下の特徴に基づいて、信用力の「ある」、「なし」でラベル/分類を予測し、デシジョンツリーのランダムフォレストを作成します。

  • ラベル → 信用力が「ある」か「ない」 (1か0)
  • 特徴 → {残高、履歴、目的など}

ソフトウェア

このチュートリアルの実行環境はSpark 1.6.1です。

  • この例を実行するためのコードとデータはここからダウンロードできます。
  • この投稿での例は、Sparkシェルコマンドで起動した後に、Sparkシェルで実行できます。
  • コードはスタンドアロンのアプリケーションとしても実行できます。方法は以下のチュートリアルに記載されています。
    MapR SandboxでSparkを始める

MapR SandboxでSparkを始めるに記載したとおり、ユーザーIDをuser01、パスワードをmaprとして、MapR Sandboxにログインします。scpで、サンプルデータファイルをSandboxホームのdirectory/user/user01にコピーします。(Sandbox上でSparkのバージョンを更新する必要があるかもしれないことに注意してください。) 以下により、Sparkシェルを開始します。

$spark-shell --master local[1]

CSVファイルからのデータロード及びパース

最初に機械学習パッケージをインポートします。

(コードボックスでコメントは緑色で、出力は青色で記述されます)

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import sqlContext.implicits._
import sqlContext._
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.ml.{ Pipeline, PipelineStage }

CSVデータファイル内のある行に対応する信用スキーマを定めるために、Scalaケースクラスを利用します。

// define the Credit Schema
case class Credit(
    creditability: Double,
    balance: Double, duration: Double, history: Double, purpose: Double, amount: Double,
    savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double,
    residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double,
    credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double
  )

以下の関数は、データファイルの行を信用力クラスへとパースします。すべてが常に0から開始されるために、いくつかのカテゴリ値から1を差し引きます。

// function to create a  Credit class from an Array of Double
def parseCredit(line: Array[Double]): Credit = {
    Credit(
      line(0),
      line(1) - 1, line(2), line(3), line(4) , line(5),
      line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1,
      line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1,
      line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1
    )
  }
// function to transform an RDD of Strings into an RDD of Double
  def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
    rdd.map(_.split(",")).map(_.map(_.toDouble))
  }

以下では、germancredit.csvファイルのデータをStringのRDDにロードします。次に、RDD上のマップ変換を利用して、ParseRDD関数でRDDの各String要素をDoubleの配列に変換します。それから、別のマップ変換を利用して、ParseCredit関数でRDDの各Double配列を信用力オブジェクトの配列に変換します。toDF()メソッドは、信用力クラススキーマをもつデータフレームへ配列 [[信用力]] のRDDを変換します。

// load the data into a  RDD
val creditDF= parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache()
creditDF.registerTempTable("credit")

データフレームprintSchema() はスキーマをデシジョン・ツリーフォーマットでコンソールに表示します。

// Return the schema of this DataFrame
creditDF.printSchema

root
 |-- creditability: double (nullable = false)
 |-- balance: double (nullable = false)
 |-- duration: double (nullable = false)
 |-- history: double (nullable = false)
 |-- purpose: double (nullable = false)
 |-- amount: double (nullable = false)
 |-- savings: double (nullable = false)
 |-- employment: double (nullable = false)
 |-- instPercent: double (nullable = false)
 |-- sexMarried: double (nullable = false)
 |-- guarantors: double (nullable = false)
 |-- residenceDuration: double (nullable = false)
 |-- assets: double (nullable = false)
 |-- age: double (nullable = false)
 |-- concCredit: double (nullable = false)
 |-- apartment: double (nullable = false)
 |-- credits: double (nullable = false)
 |-- occupation: double (nullable = false)
 |-- dependents: double (nullable = false)
 |-- hasPhone: double (nullable = false)
 |-- foreign: double (nullable = false)

// Display the top 20 rows of DataFrame 
creditDF.show

+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     9.0|    4.0|    0.0|2799.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    1.0|    12.0|    2.0|    9.0| 841.0|    1.0|       3.0|        2.0|       1.0|       0.0|              3.0|   0.0|23.0|       2.0|      0.0|    0.0|       1.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|    12.0|    4.0|    0.0|2122.0|    0.0|       2.0|        3.0|       2.0|       0.0|              1.0|   0.0|39.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    0.0|    12.0|    4.0|    0.0|2171.0|    0.0|       2.0|        4.0|       2.0|       0.0|              3.0|   1.0|38.0|       0.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
|          1.0|    0.0|    10.0|    4.0|    0.0|2241.0|    0.0|       1.0|        1.0|       2.0|       0.0|              2.0|   0.0|48.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    0.0|     8.0|    4.0|    0.0|3398.0|    0.0|       3.0|        1.0|       2.0|       0.0|              3.0|   0.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
|          1.0|    0.0|     6.0|    4.0|    0.0|1361.0|    0.0|       1.0|        2.0|       2.0|       0.0|              3.0|   0.0|40.0|       2.0|      1.0|    0.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    3.0|    18.0|    4.0|    3.0|1098.0|    0.0|       0.0|        4.0|       1.0|       0.0|              3.0|   2.0|65.0|       2.0|      1.0|    1.0|       0.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    24.0|    2.0|    3.0|3758.0|    2.0|       0.0|        1.0|       1.0|       0.0|              3.0|   3.0|23.0|       2.0|      0.0|    0.0|       0.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|    11.0|    4.0|    0.0|3905.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    0.0|    30.0|    4.0|    1.0|6187.0|    1.0|       3.0|        1.0|       3.0|       0.0|              3.0|   2.0|24.0|       2.0|      0.0|    1.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     6.0|    4.0|    3.0|1957.0|    0.0|       3.0|        1.0|       1.0|       0.0|              3.0|   2.0|31.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    48.0|    3.0|   10.0|7582.0|    1.0|       0.0|        2.0|       2.0|       0.0|              3.0|   3.0|31.0|       2.0|      1.0|    0.0|       3.0|       0.0|     1.0|    0.0|
|          1.0|    0.0|    18.0|    2.0|    3.0|1936.0|    4.0|       3.0|        2.0|       3.0|       0.0|              3.0|   2.0|23.0|       2.0|      0.0|    1.0|       1.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     6.0|    2.0|    3.0|2647.0|    2.0|       2.0|        2.0|       2.0|       0.0|              2.0|   0.0|44.0|       2.0|      0.0|    0.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    0.0|    11.0|    4.0|    0.0|3939.0|    0.0|       2.0|        1.0|       2.0|       0.0|              1.0|   0.0|40.0|       2.0|      1.0|    1.0|       1.0|       1.0|     0.0|    0.0|
|          1.0|    1.0|    18.0|    2.0|    3.0|3213.0|    2.0|       1.0|        1.0|       3.0|       0.0|              2.0|   0.0|25.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    36.0|    4.0|    3.0|2337.0|    0.0|       4.0|        4.0|       2.0|       0.0|              3.0|   0.0|36.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    3.0|    11.0|    4.0|    0.0|7228.0|    0.0|       2.0|        1.0|       2.0|       0.0|              3.0|   1.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    0.0|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+             

データフレームをインスタンス化した後に、SQLクエリを利用して問い合わせができます。以下にScalaデータフレームAPIを利用したクエリ例をいくつか示します。

count、mean、stddev、min、maxを含むnumeric列のcomputes statisticsを記述します。

//  computes statistics for balance 
  creditDF.describe("balance").show

+-------+-----------------+
|summary|          balance|
+-------+-----------------+
|  count|             1000|
|   mean|            1.577|
| stddev|1.257637727110893|
|    min|              0.0|
|    max|              3.0|
+-------+-----------------+
 

// compute the avg balance by creditability (the label) 
 creditDF.groupBy("creditability").avg("balance").show

+-------------+------------------+
|creditability|      avg(balance)|
+-------------+------------------+
|          1.0|1.8657142857142857|
|          0.0|0.9033333333333333|
+-------------+------------------+
 

データフレームを所与の名前で一時テーブルとして登録できます。次に、sqlContextが提供するsqlメソッドを利用してSQL文を実行できます。以下にsqlContextを利用したクエリ例をいくつか示します。

// Compute the average balance, amount, duration grouped by creditability  
 sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur  FROM credit GROUP BY creditability ").show

+-------------+------------------+------------------+------------------+
|creditability|        avgbalance|            avgamt|            avgdur|
+-------------+------------------+------------------+------------------+
|          1.0|1.8657142857142857| 2985.442857142857|19.207142857142856|
|          0.0|0.9033333333333333|3938.1266666666666|             24.86|
+-------------+------------------+------------------+------------------+
 

フィーチャーエクストラクション(特徴の抽出)

類別詞モデルを確立するためには最初に、分類にもっとも寄与する特徴を抽出します。ドイツクレジットデータセットではデータが2つのクラス ― 1 (信用力がある) と0 (信用力がない) ― でラベルされています。

各項目の特徴は以下に示すフィールドからなります。

  • ラベル → 信用力: 0または1
  • 特徴 → {“balance”, “duration”, “history”, “purpose”, “amount”, “savings”, “employment”, “instPercent”, “sexMarried”, “guarantors”, “residenceDuration”, “assets”, “age”, “concCredit”, “apartment”, “credits”, “occupation”, “dependents”, “hasPhone”, “foreign”}

特徴配列の定義

(Learning Sparkを参照する)

機械学習アルゴリズムで特徴を利用することを目的として、特徴を変換して特徴ベクトルの中に入れます。そのベクトルは各特徴の値を表す数字からなるベクトルです。

以下では、ベクトル列の変換を行い、ベクトル列中のすべての特徴列をもつ新たなデータフレームを返すために、VectorAssemblerを利用します。

//define the feature columns to put in the feature vector
val featureCols = Array("balance", "duration", "history", "purpose", "amount",
    "savings", "employment", "instPercent", "sexMarried",  "guarantors",
    "residenceDuration", "assets",  "age", "concCredit", "apartment",
    "credits",  "occupation", "dependents",  "hasPhone", "foreign" )
//set the input and output column names
  val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
//return a dataframe with all of the  feature columns in  a vector column
val df2 = assembler.transform( creditDF)
// the transform method produced a new column: features.
df2.show

+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|

次に、ラベルとして追加された信用力の列をもつデータフレームを返すために、StringIndexerを利用します。

//  Create a label column with the StringIndexer  
val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label")
val df3 = labelIndexer.fit(df2).transform(df2)
// the  transform method produced a new column: label.
df3.show

+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|  0.0|

以下のデータは、トレーニング用データセットとテスト用データセットに分けられます。70%のデータをモデルのトレーニングで利用し、30%をテストで利用します。

//  split the dataframe into training and test data
val splitSeed = 5043 
val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

モデルのトレーニング

次に、 以下のパラメータを用いて、RandomForest類別詞のトレーニングを行います。

  • maxDepth:デシジョン・ツリーの最大の深さ。深さを増すとモデルは強力になるが、深いツリーはトレーニングに多くの時間がかかる。
  • maxBins:連続的な特徴を離散化するために、また、各ノードで特徴を分割する方法を選択するために、使用されるビンの最大数。
  • impurity:情報取得の計算に利用される基準。
  • auto:デシジョン・ツリーの各ノードでの分割を検討するために特徴の数を自動的に選択する。
  • seed:結果の反復を許して、無作為なシード番号を利用する。

入力した特徴と、それらの特徴のラベルされた出力との関連付けを行う方法により、モデルをトレーニングします。

// create the classifier,  set parameters for training
val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
//  use the random forest classifier  to train (fit) the model
val model = classifier.fit(trainingData) 

// print out the random forest trees
model.toDebugString
res20: String = 
res5: String = 
"RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 0 <= 1.0)
     If (feature 10 <= 0.0)
      If (feature 3 <= 6.0) Predict: 0.0 Else (feature 3 > 6.0)
       Predict: 0.0
     Else (feature 10 > 0.0)
      If (feature 12 <= 63.0) Predict: 0.0 Else (feature 12 > 63.0)
       Predict: 0.0
    Else (feature 0 > 1.0)
     If (feature 13 <= 1.0)
      If (feature 3 <= 3.0) Predict: 0.0 Else (feature 3 > 3.0)
       Predict: 1.0
     Else (feature 13 > 1.0)
      If (feature 7 <= 1.0) Predict: 0.0 Else (feature 7 > 1.0)
       Predict: 0.0
  Tree 1 (weight 1.0):
    If (feature 2 <= 1.0)
     If (feature 15 <= 0.0)
      If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0)
       Predict: 1.0
     Else (feature 15 > 0.0)
      If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0)
       Predict: 1.0
    Else (feature 2 > 1.0)
     If (feature 12 <= 31.0)
      If (feature 5 <= 0.0) Predict: 0.0 Else (feature 5 > 0.0)
       Predict: 0.0
     Else (feature 12 > 31.0)
      If (feature 3 <= 4.0) Predict: 0.0 Else (feature 3 > 4.0)
       Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 8 <= 1.0)
     If (feature 6 <= 2.0)
      If (feature 4 <= 10875.0) Predict: 0.0 Else (feature 4 > 10875.0)
       Predict: 1.0
     Else (feature 6 > 2.0)
      If (feature 1 <= 36.0) Predict: 0.0 Else (feature 1 > 36.0)
       Predict: 1.0
    Else (feature 8 > 1.0)
     If (feature 5 <= 0.0)
      If (feature 4 <= 4113.0) Predict: 0.0 Else (feature 4 > 4113.0)
       Predict: 1.0
     Else (feature 5 > 0.0)
      If (feature 11 <= 2.0) Predict: 0.0 Else (feature 11 > 2.0)
       Predict: 0.0
  Tree 3 ...

モデルのテスト

次に、予測を行うためにテストデータを用います。

// run the  model on test features to get predictions
val predictions = model.transform(testData) 
//As you can see, the previous model transform produced a new columns: rawPrediction, probablity and prediction.
predictions.show

+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
|          0.0|    0.0|    12.0|    0.0|    5.0|1108.0|    0.0|       3.0|        4.0|       2.0|       0.0|              2.0|   0.0|28.0|       2.0|      1.0|    1.0|       2.0|       0.0|     0.0|    0.0|(20,[1,3,4,6,7,8,...|  1.0|[14.1964586927573...|[0.70982293463786...|       0.0|

以下では、予測を評価します。テストラベル列をテスト予測列と比較することで精確なメトリック (ROC曲線下の面積) を返すためにBinaryClassificationEvaluatorを利用します。この場合には、評価は78%の精度を返します。

// create an Evaluator for binary classification, which expects two input columns: rawPrediction and label.
val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")
// Evaluates predictions and returns a scalar metric areaUnderROC(larger is better). 
val accuracy = evaluator.evaluate(predictions) 
accuracy: Double = 0.7824906081835722

MLパイプラインの利用

次に、パイプラインを利用してモデルをトレーニングします。それにより良好な結果を得ることができます。パイプラインは、パラメータの異なる組み合わせをテストするための簡単な方法を提供します。これはグリッドサーチと呼ばれるプロセスを利用して、テストするパラメータを設定することで、MLLibがすべての組み合わせをテストします。パイプラインでは、パイプラインの各要素を個別に調整するよりも、ワークフローを構成する全モデルを1度に調整する方が容易になります。

以下では、パラメータグリッドを構成するためにParamGridBuilderユーティリティを用います。

// We use a ParamGridBuilder to construct a grid of parameters to search over
val paramGrid = new ParamGridBuilder()
  .addGrid(classifier.maxBins, Array(25, 28, 31))
  .addGrid(classifier.maxDepth, Array(4, 6, 8))
  .addGrid(classifier.impurity, Array("entropy", "gini"))
  .build()

パイプラインを作成し、設定します。パイプラインは、そのそれぞれが推定者または変換者である、段階のシーケンスで構成されます。

  
val steps: Array[PipelineStage] = Array(classifier)
val pipeline = new Pipeline().setStages(steps)

モデルの選択のためにCrossValidatorクラスを利用します。CrossValidatorは推定者、ParamMapsセット、評価者を利用します。CrossValidatorの利用はきわめて高額になることがあるので注意してください。

// Evaluate model on test instances and compute test error
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(10)

パイプラインは、パラメータグリッドを探索することで自動的に最適化を行います。各ParamMapについて、CrossValidatorは所与の推定者をトレーニングし、それを所与の評価者を用いて評価します。そして、最良のParamMapと全データセットを用いて最良の推定者を定めます。

// When fit is called, the stages are executed in order. 
// Fit will run cross-validation,  and choose the best set of parameters 
//The fitted model from a Pipeline is an PipelineModel, which consists of fitted models and transformers

val pipelineFittedModel = cv.fit(trainingData)

これで、テスト予測をテストラベルと比較することによって、パイプラインの最良適合モデルを評価できます。この場合では、評価者は以前の78%よりも精確な82%の精度を返します。

//  call tranform to make predictions on test data. The fitted model will use the best model found 
val predictions = pipelineFittedModel.transform(testData)
val accuracy = evaluator.evaluate(predictions)  
Double = 0.8204386232104784
val rm2 = new RegressionMetrics(
  predictions.select("prediction", "label").rdd.map(x =>
  (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double])))
println("MSE: " + rm2.meanSquaredError)
println("MAE: " + rm2.meanAbsoluteError)
println("RMSE Squared: " + rm2.rootMeanSquaredError)
println("R Squared: " + rm2.r2)
println("Explained Variance: " + rm2.explainedVariance + "\n")

MSE: 0.2575250836120402
MAE: 0.25752508361204013
RMSE Squared: 0.5074692932700856
R Squared: -0.1687988628287138
Explained Variance: 0.15466269952237702

詳細情報について

このブログ記事では、Apache Sparkの機械学習ランダムフォレストと分類のためのmlパイプラインを記しました。さらに詳細を学びたい方は以下のページ(英語)をご覧ください。

著者情報

CarolMcDonald

キャロル・マクドナルド

(MapR Technologies ソリューションアーキテクト)

実践 機械学習 – レコメンデーションにおけるイノベーション –

実践 機械学習 – レコメンデーションにおけるイノベーション –
機械学習とレコメンデーションにおける、もっとも洗練され、効率的なアプローチの1つに至る鍵は、「仔馬が欲しい」という状況の観察の中にあります。

どれを選べばいいのかめまいがするほどの数多くのアルゴリズムがあり、それらの中から選択をするためだけでも、選択肢を理解し、合理的な判断を行うのに必要な、高度な数学の背景知識を十分に持っていることが前提になります。

こちらの資料は、機械学習とレコメンデーションについて学習したいけれど、どこから始めればよいか迷っているという方におすすめです。

無料ダウンロードはこちら

こちらの記事もおすすめです