こんにちは。最近GINZA SIXで本当のスタバ*1を知ってしまった福田です。
私たちの身の周りは、様々なデータで溢れています。 ある2つの異なるデータ集合を互いに紐付けたいこともよくあります。
どのように紐付けられるでしょうか。
一方のデータ集合から分類器をつくることができれば、分類結果を媒介として他のデータ集合とのマッチングができるかもしれません。
では、どうやって分類できるでしょう。
ここではSparkとHBaseを使って実装がシンプルで、文書分類でよく使われるナイーブベイズの分類器を実装してみます。
材料と調理器具
材料
特許の要約と分類のデータ
簡単のため以下のように正規化されたテーブル構造のデータがあるとします。
特許出願(appln)を親として、要約テキスト(appln_abstr)と、分類コード(appln_ipc)がぶら下がっています。今回使うのは右側の2つのデータのみです。
特許要約のデータ(appln_abstr)
約4千万行
サンプル
+--------+-------------------------------------------------------------------------------------------------------+ |appln_id|appln_abstract | +--------+-------------------------------------------------------------------------------------------------------+ |153620 |The present invention relates to uses, methods and compositions for modulating replication of viruse...| |197020 |A base station includes a group determination unit grouping mobile stations residing within a cell a...| |286620 |The programming method comprises supplying a turnoff voltage to the source terminal of the selected... | +--------+-------------------------------------------------------------------------------------------------------+
分類コードのデータ(appln_ipc)
約2億行
サンプル
+--------+----------------+ |appln_id|ipc_class_symbol| +--------+----------------+ |153620 |A61K 31/57 | |153620 |A61P 31/14 | |153620 |C12N 15/113 | |197020 |H04L 5/22 | |197020 |H04W 4/06 | |197020 |H04W 72/00 | |286620 |G11C 16/12 | |455820 |H04L 1/18 | |455820 |H04W 28/04 | +--------+----------------+
調理器具
- Spark 1.6
- HBase 1.2.0
設計編
ナイーブベイズの文書分類器*5では、分類対象の文書Dについて、分類毎に事後確率P(C|D)を計算し、その確率が最大のものを選択します。
分類アルゴリズム
事後確率
- P(C|D): 文書Dが与えられたときに分類Cである確率(事後確率)
- P(C): 分類Cが現れる確率(事前確率)
- P(D|C): 分類Cが与えられたときに文書Dが生成される確率(尤度)
P(D|C)は次の式のように、分類Cにワードが出現する確率の積で表されます。
ここで、分類時に必要なデータを事前に集計、計算し、永続化したものをモデルとします。
HBaseのデータ構造はは分散ソート済みマップとも呼ばれ、疎なデータを効率よく扱うことができます。ここではその構造を活かしてテーブル設計をしました。
テーブルレイアウト(イメージ)
row-key | feature: | stats: | ||
---|---|---|---|---|
cafebabe | スギ花粉 | 4 | n_occurence | 16 |
薬 | 2 | s_occurence | 300 | |
n_feature | 2 | |||
s_val | 6 | |||
label | 花粉症対策 | |||
prior_prob | 0.053 | |||
deadbeaf | ウィルス | 20 | n_occurence | 10 |
手洗い | 3 | s_occurence | 300 | |
うがい | 16 | n_feature | 3 | |
s_val | 39 | |||
label | インフルエンザ対策 | |||
prior_prob | 0.033 |
- row-keyはラベルのハッシュ値を分類IDとして使用
- feature:カラムファミリ
- ワードとその出現回数をKey-Valueとして格納
- stats:カラムファミリ
- prior_prob: 事前確率
- label: 分類ラベル
- s_val: 分類におけるの語彙毎の出現回数の合計
- n_occurence: 分類の出現回数(デバッグ用)
- s_occurence: 分類の総出現回数(デバッグ用)
- n_feature: 分類の語彙数(デバッグ用)
分類に必要な値以外に調査や検証に役立つ値も格納しています
実装編
テーブルの作成
CREATE 'test_model', { NAME => 'feature', VERSIONS => 1, COMPRESSION => 'LZ4', DATA_BLOCK_ENCODING => 'FAST_DIFF', 'IN_MEMORY' => 'true' }, { NAME => 'stats', VERSIONS => 1, COMPRESSION => 'LZ4', DATA_BLOCK_ENCODING => 'FAST_DIFF', 'IN_MEMORY' => 'true' }, { SPLITS => [ '1000000000000000000000000000000000000000', '2000000000000000000000000000000000000000', '3000000000000000000000000000000000000000', '4000000000000000000000000000000000000000', '5000000000000000000000000000000000000000', '6000000000000000000000000000000000000000', '7000000000000000000000000000000000000000', '8000000000000000000000000000000000000000', '9000000000000000000000000000000000000000', 'a000000000000000000000000000000000000000', 'b000000000000000000000000000000000000000', 'c000000000000000000000000000000000000000', 'd000000000000000000000000000000000000000', 'e000000000000000000000000000000000000000', 'f000000000000000000000000000000000000000' ] }
- スループットを稼ぎたいのでインメモリ指定しています
- データが分散するように分割ポイントを指定しています
モデルデータの生成
Sparkを使ってデータを変換していきます。
ステップ1 グループIDの生成
文書:分類が1:Nの関係となっているのを、ここでは複数コードの集合を1つのグループとして扱うように変換し、このグループのIDを生成します。*6
// 出願ごとに複数付与される階層分類コードの上位4桁の集合からカグループを生成し、ハッシュ値をキーとする val appln_group = appln_ipc.groupBy("appln_id") .agg(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4)))) as "ipc_sss", sha1(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4))))) as "group")
DataFrame APIを使用しています。
+--------+--------------+----------------------------------------+ |appln_id|ipc_sss |group | +--------+--------------+----------------------------------------+ |153620 |A61K,A61P,C12N|947ed80b48cae17d652fdd8fb6f6de2eff130710| |197020 |H04L,H04W |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd| |286620 |G11C |e3a1862a8f7ce681b57c4c41f711922f3b0bb490| |455820 |H04L,H04W |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd| +--------+--------------+----------------------------------------+
この段階で、1文書1行のデータに変換されます。
ステップ2 学習データの生成
// 約8割を学習データとする val trainingset = appln_group.filter($"appln_id" % lit(5) !== lit(0) // 学習データ(要約テキストからキーワードを抽出し、キーワード毎に行展開) val instances = trainingset.as("ag") .join(appln_abstr, "appln_id") .select($"ag.appln_id", explode(Util.wordcount($"appln_abstract", lit(30))) as "wc", $"ag.group" as "group", $"ag.ipc_sss" as "ipc") .select($"*", $"wc"("_1").as("keyword"), $"wc"("_2").as("n")) .drop("wc") }
ここはデータ全体の8割を学習データとし、残りをテストデータにし、実際の学習データを作ります。*7
要約のデータをjoinしつつキーワードを抽出し行を展開しています。 *8
データイメージ
+--------+----------------------------------------+--------------+----------------------------------------------------+---+ |appln_id|group |ipc |keyword |n | +--------+----------------------------------------+--------------+----------------------------------------------------+---+ |50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|virus |5 | |50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|cell proliferative disorders |4 | |50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|Ras-pathway |2 | |454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |ground stage |1 | |454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |process condition |1 | |454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |TR1 |1 | +--------+----------------------------------------+--------------+----------------------------------------------------+---+
ここまでで下ごしらえが完了です。次のステップではHBaseのテーブルの各カラムファミリに流しこむデータを生成していきます。
HDFS上のステージングディレクトリに一通りHFileを出力した後、最後にファイル群をHBaseのデータディレクトリに移動させてバルクロードを完了させます。
ステップ3 HFileの生成
hbase-sparkというライブラリ*9を使用してHFileを生成していきます。ここでは入力としてRDDへの変換をしています。
def createHFile(hbaseContext: HBaseContext, tableName: TableName, rdd: RDD[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])], stagingDir: String) = { hbaseContext.bulkLoad[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])]( rdd, tableName, (r) => { r._2.map { v => (new KeyFamilyQualifier(r._1, v._1, v._2), v._3) }.iterator }, stagingDir) } } val hbconf = HBaseConfiguration.create() val hbaseContext = new HBaseContext(sc, hbconf) // グループ毎のキーワード出現数(feature:列ファミリ)の集計データ生成 val wc_by_group = instances.groupBy($"group", $"keyword").agg(sum("n")).orderBy($"group", $"keyword") val feature_put_rdd = wc_by_group.rdd.map{ x => (Bytes.toBytes(x.getString(0)), Array((Bytes.toBytes("feature"), Bytes.toBytes(x.getString(1)), Bytes.toBytes(x.getLong(2))))) } createHFile(hbaseContext, TableName.valueOf(tableName), feature_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+----------------------------------------------------+------+ |group |keyword |sum(n)| +----------------------------------------+----------------------------------------------------+------+ |947ed80b48cae17d652fdd8fb6f6de2eff130710|HSV |1 | |947ed80b48cae17d652fdd8fb6f6de2eff130710|Methods |1 | |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first output signal |3 | |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first resistor |6 | |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first transistor |6 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|terminal |1 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|traffic |3 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|traffic differentiation |1 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|transfer |1 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|wireless LAN |1 | +----------------------------------------+----------------------------------------------------+------+
stats:n_featureデータの生成
// グループ毎の特徴の種類数(stats:n_feature)のデータ生成 val n_feature_by_group = wc_by_group.groupBy($"group").agg(count("*")).orderBy($"group") val n_feature_put_rdd = n_feature_by_group.rdd.map { x => (Bytes.toBytes(x.getString(0)), Array( (Bytes.toBytes("stats"), Bytes.toBytes("n_feature"), Bytes.toBytes(x.getLong(1))))) } createHFile(hbaseContext, TableName.valueOf(tableName), n_feature_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+--------+ |group |count(1)| +----------------------------------------+--------+ |947ed80b48cae17d652fdd8fb6f6de2eff130710|17 | |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|83 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|70 | +----------------------------------------+--------+
stats:s_valとlabelデータの生成
// グループ毎のキーワード出現数の合計(stats:s_val)と 可視化ラベル(stats:label)のデータ生成 val s_val_by_group = instances.groupBy($"group", $"ipc").agg(sum("n")).orderBy($"group") val s_val_put_rdd = s_val_by_group.rdd.map{ x => (Bytes.toBytes(x.getString(0)), Array( (Bytes.toBytes("stats"), Bytes.toBytes("s_val"), Bytes.toBytes(x.getLong(2))), (Bytes.toBytes("stats"), Bytes.toBytes("label"), Bytes.toBytes(x.getString(1))))) } createHFile(hbaseContext, TableName.valueOf(tableName), s_val_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+--------------+------+ |group |ipc |sum(n)| +----------------------------------------+--------------+------+ |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|29 | |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |263 | |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|H04L,H04W |216 | +----------------------------------------+--------------+------+
stats:prior_probのデータの生成
// グループごとの出現回数(stats:n_occurence)のデータ生成 val n_occurence_by_group = trainingset.groupBy($"group").agg(count("*")).orderBy($"group") n_occurence_by_group.persist(StorageLevel.DISK_ONLY) val s_group_occurence = trainingset.count().toDouble val n_occurence_put_rdd = n_occurence_by_group.rdd.map{ x => (Bytes.toBytes(x.getString(0)), Array( (Bytes.toBytes("stats"), Bytes.toBytes("prior_prob"), Bytes.toBytes(x.getLong(1) / s_group_occurence)), (Bytes.toBytes("stats"), Bytes.toBytes("s_occurence"), Bytes.toBytes(s_group_occurence)), (Bytes.toBytes("stats"), Bytes.toBytes("n_occurence"), Bytes.toBytes(x.getLong(1))))) } createHFile(hbaseContext, TableName.valueOf(tableName), n_occurence_put_rdd, stagingDirClassificationModel)
ステージングディレクトリに出力されたHFileをHBaseのデータディレクトリに移動させてバルクロードを完了させます。
テトリス*10のようなイメージです。
// バルクロードの完了 val conn = ConnectionFactory.createConnection(hbaseContext.config) val load = new LoadIncrementalHFiles(hbaseContext.config) load.doBulkLoad( new Path(stagingDirClassificationModel), conn.getAdmin, new HTable(hbaseContext.config, TableName.valueOf(tableName)), conn.getRegionLocator(TableName.valueOf(tableName)))
これで完成です。約100万分類のモデルができました。
分類フェーズ
学習フェーズで生成したモデルのテーブルを読み、実際に文書の分類を行うフェーズです。
HBaseの分類モデルのテーブルは1行が1分類となっており、分類対象の文書に含まれるキーワードから、どの分類から来たのかの確率をスコアとして計算し、スコアの高いものを10件ずつ出力しています。
下にコード(抜粋)とポイントを簡単に挙げます。
// (学習時と同じデータセット) val appln_group = appln_ipc.groupBy("appln_id").agg(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4)))) as "ipc_sss", sha1(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4))))) as "group") // 学習に使った残りの約2割からのデータをテストデータとする val testset = appln_group.filter($"appln_id" % lit(5) === lit(0)).limit(100000) testset.persist(StorageLevel.DISK_ONLY) val instances = testset.as("ag").join(appln_abstr, "appln_id").select($"ag.appln_id", Util.wordcount($"appln_abstract", lit(30)) as "keyword", $"ag.group" as "group", $"ag.ipc_sss" as "ipc", row_number().over(Window.partitionBy().orderBy($"appln_id")) as "row_num") // 行番号とバッチサイズを使ってバッチ分割(高速化のためバッチサイズ毎に1テーブルスキャンとするため) val input = instances.map{ r => (r.getInt(4) / batchSize, (r.getInt(0), r.getMap[String,Int](1), r.getString(2), r.getString(3))) }.groupByKey().repartition(2048) val hbconf = HBaseConfiguration.create() val tableNameObject = TableName.valueOf(tableName) val table = new HTable(hbconf, tableName) val conn = ConnectionFactory.createConnection(hbconf) val regionLocator: RegionLocator = conn.getRegionLocator(tableNameObject) val startEndKeys: org.apache.hadoop.hbase.util.Pair[Array[Array[Byte]], Array[Array[Byte]]] = regionLocator.getStartEndKeys() val range = startEndKeys.getFirst().zip(startEndKeys.getSecond()) // input: _1=appln_id, _2=keyword, _3=group, _4=label val result = input.map{ row => import scala.math.log // トップN件の分類結果を保持するための入れ物を用意しておく object EntryOrdering extends Ordering[(String, String, Double)] { def compare(a: (String, String, Double), b: (String, String, Double)) = -(a._3 compare b._3) } // バッチ内のエントリ(出願)毎にランキングのための優先キューを初期化 val topK = scala.collection.mutable.Map[Int, PriorityQueue[(String, String, Double)]]() for (d <- row._2) { topK.put(d._1, new PriorityQueue[(String, String, Double)]()(EntryOrdering)) } // バッチ毎に1テーブルスキャン(レンジスキャンをシャッフルしてクエリを分散させる) scala.util.Random.shuffle(range.toList).foreach { r => val hbconf = HBaseConfiguration.create() val table = new HTable(hbconf, tableName) val scan = new Scan(r._1, r._2) scan.setCaching(100) val scanner = table.getScanner(scan) for (r: Result <- scanner) { val classificationId = Bytes.toString(r.getRow()) val classificationStats = r.getFamilyMap(Bytes.toBytes("stats")) val classificationFeature = r.getFamilyMap(Bytes.toBytes("feature")) // 可視化のためのラベルを復元 val label = Bytes.toString(classificationStats.getOrElse(Bytes.toBytes("label"), Bytes.toBytes("N/A"))) // バッチサイズ件数分の処理 for (d <- row._2) { var score = log(Bytes.toDouble(classificationStats.getOrElse(Bytes.toBytes("prior_prob"), Bytes.toBytes(0.0)))) val features = d._2 for (f <- features) { val keyword = f._1 val n = f._2 val numerator = Bytes.toLong(classificationFeature.getOrElse(Bytes.toBytes(keyword), Bytes.toBytes(0L))).toDouble + 1.0 val denominator = Bytes.toLong(classificationStats.getOrElse(Bytes.toBytes("s_val"), Bytes.toBytes(1L))) + 70000000.0 val wordProb: Double = numerator / denominator score += n * log(wordProb) } val q = topK.get(d._1).get if (q.size < 10) { q.enqueue((classificationId, label, score)) } else { val smallest = q.dequeue() if (score > smallest._3) { q.enqueue((classificationId, label, score)) } else { q.enqueue(smallest) } } } } scanner.close() table.close() } // バッチの各行毎に計算結果を回収 val z = row._2.map{ d => // appln_id, keyword, group, label, result (d._1, d._2, d._3, d._4, topK.get(d._1).get.dequeueAll) } z }.flatMap(x => x) result.persist(StorageLevel.DISK_ONLY) result.map{ r => val keywords = r._2.toList.sortBy(-_._2).map{case (a, b) => s"${a}:${b}"}.mkString(",") val result = r._5.reverse.map(_.productIterator.mkString(":")).mkString("|") // appln_id, keyword, group, label, result s"${r._1}\t${keywords}\t${r._3}\t${r._4}\t${result}" }.saveAsTextFile("model_test.tsv")
ポイント
- 分類対象の件数に応じた計算リソースと時間が必要
- テーブルスキャンが分類対象の文書数分必要なので、バッチ化してHBase側の負荷とネットワークIOを抑制
- 上位10分類を出力するため優先度付きキューを使用
- レンジスキャンをシャッフルして実行することでHBaseのブロックキャッシュヒット率を上げることで速度を稼ぐ
実験結果
テストデータのセットから10万件を分類器にかけてみて、出力結果の評価をしました。 文書毎に返された上位10件の分類ラベルに対して、次の6通りを集計しました。
- 最上位で正解ラベルと完全一致する率
- 最上位で正解ラベルの構成要素の集合との共通部分が存在する率
- 5位以内で正解ラベルと完全一致する率
- 5位以内で正解ラベルの構成要素の集合との共通部分が存在する率
- 10位以内で正解ラベルと完全一致する率
- 10位以内で正解ラベルの構成要素の集合との共通部分が存在する率
それぞれ以下の数字でした。
- 24397/100000 0.2440
- 45486/100000 0.4549
- 40696/100000 0.4070
- 66264/100000 0.6626
- 46965/100000 0.4697
- 73903/100000 0.7390
考察と課題
- データの設計と実装上の工夫により比較的規模の大きなデータに対してジョブを完遂することができた
- 分類器の精度としては今ひとつな値だが、約100万の中から10個選んだものとしてみるとランダムよりは良さそう(ラベル付けが目的かマッチングが目的かで異なる解釈)というレベル
- 以下の要素についてもうすこし考える余地がありそう
- 分類の数と粒度
- テキストの分量と質
- 特徴(キーワード)抽出の方法
- 正解ラベルよりもよいラベルを引き当てた可能性についてはどう評価するか
まとめ
- SparkとHBaseを使ってスケーラブルな単純ベイズの文書分類器を実装しました。
- 数千万件の特許要約テキストと分類コードのデータから約百万通りの分類器を学習させてみました。
- 実装面、理論面の両方で更なる改良の余地がありそうです。
*1:http://www.starbucks.co.jp/coffee/reserve/
飲み方と豆の種類と淹れ方を選ぶとバリスタが目の前で作ってくれるスタイルです。匠の技やプロセスを間近で見れるのでエンジニアの方におすすめです。
*2:テキストは長いため切り詰めて表示しています。
*3:ipc_class_symbolは国際特許分類と呼ばれる階層化された分類コードで、出願案件毎に複数付与されます http://www.wipo.int/classifications/ipc/en/
*4:Cloudera社のCDH5を使わせていただいております。歴史的理由により、Spark1.6のコードとなっています。また、CDHでは2017年9月現在、hbase-sparkがSpark2系に対応していないようです。対応を心待ちにしています。
*5:ナイーブベイズの文書分類器については、Webや書籍で多くの情報があります。 以下の記事を参考にさせていただきました。
ナイーブベイズを用いたテキスト分類 - 人工知能に関する断創録
*6:これにより粒度と分類数を調節しています。今回実験のため、ヒューリスティックに約100万分類となるような変換を施しています
*7:id列の値の採番方法と分布に依存します。ここでは単純に連番を想定しています
*8:ここでは詳細を割愛していますが、wordcountはユーザ定義関数でテキストからキーワードを抽出し、スコアの高い順にキーワードと頻度のタプルのリスト(List[(String, Int)])として返してくれるモジュールを別途実装しています。
*9:hbase-spark
https://blog.cloudera.com/blog/2015/08/apache-spark-comes-to-apache-hbase-with-hbase-spark-module/ https://github.com/apache/hbase/tree/master/hbase-spark
*11:HBaseのBlock Cacheヒット率
実行開始後からヒット率が上昇し始め、高い値に収束しています。施策がない場合、並列実行する複数のタスクがほぼ同時に同じ領域を読むことになり、キャッシュ領域を有効に利用出来ずヒット率が低いまま推移していたため、ゆらぎを設けて出来るだけキャッシュに詰め込みました。