こんにちは。最近GINZA SIXで本当のスタバ*1を知ってしまった福田です。
私たちの身の周りは、様々なデータで溢れています。
ある2つの異なるデータ集合を互いに紐付けたいこともよくあります。
どのように紐付けられるでしょうか。
一方のデータ集合から分類器をつくることができれば、分類結果を媒介として他のデータ集合とのマッチングができるかもしれません。
では、どうやって分類できるでしょう。
ここではSparkとHBaseを使って実装がシンプルで、文書分類でよく使われるナイーブベイズの分類器を実装してみます。
材料と調理器具
材料
特許の要約と分類のデータ
簡単のため以下のように正規化されたテーブル構造のデータがあるとします。
特許出願(appln)を親として、要約テキスト(appln_abstr)と、分類コード(appln_ipc)がぶら下がっています。今回使うのは右側の2つのデータのみです。
特許要約のデータ(appln_abstr)
約4千万行
サンプル
+--------+-------------------------------------------------------------------------------------------------------+
|appln_id|appln_abstract |
+--------+-------------------------------------------------------------------------------------------------------+
|153620 |The present invention relates to uses, methods and compositions for modulating replication of viruse...|
|197020 |A base station includes a group determination unit grouping mobile stations residing within a cell a...|
|286620 |The programming method comprises supplying a turnoff voltage to the source terminal of the selected... |
+--------+-------------------------------------------------------------------------------------------------------+
*2
分類コードのデータ(appln_ipc)
約2億行
サンプル
+--------+----------------+
|appln_id|ipc_class_symbol|
+--------+----------------+
|153620 |A61K 31/57 |
|153620 |A61P 31/14 |
|153620 |C12N 15/113 |
|197020 |H04L 5/22 |
|197020 |H04W 4/06 |
|197020 |H04W 72/00 |
|286620 |G11C 16/12 |
|455820 |H04L 1/18 |
|455820 |H04W 28/04 |
+--------+----------------+
*3
調理器具
*4
設計編
ナイーブベイズの文書分類器*5では、分類対象の文書Dについて、分類毎に事後確率P(C|D)を計算し、その確率が最大のものを選択します。
分類アルゴリズム
事後確率
- P(C|D): 文書Dが与えられたときに分類Cである確率(事後確率)
- P(C): 分類Cが現れる確率(事前確率)
- P(D|C): 分類Cが与えられたときに文書Dが生成される確率(尤度)
P(D|C)は次の式のように、分類Cにワードが出現する確率の積で表されます。
ここで、分類時に必要なデータを事前に集計、計算し、永続化したものをモデルとします。
HBaseのデータ構造はは分散ソート済みマップとも呼ばれ、疎なデータを効率よく扱うことができます。ここではその構造を活かしてテーブル設計をしました。
テーブルレイアウト(イメージ)
row-key |
feature: |
|
stats: |
|
cafebabe |
スギ花粉 |
4 |
n_occurence |
16 |
|
薬 |
2 |
s_occurence |
300 |
|
|
|
n_feature |
2 |
|
|
|
s_val |
6 |
|
|
|
label |
花粉症対策 |
|
|
|
prior_prob |
0.053 |
deadbeaf |
ウィルス |
20 |
n_occurence |
10 |
|
手洗い |
3 |
s_occurence |
300 |
|
うがい |
16 |
n_feature |
3 |
|
|
|
s_val |
39 |
|
|
|
label |
インフルエンザ対策 |
|
|
|
prior_prob |
0.033 |
- row-keyはラベルのハッシュ値を分類IDとして使用
- feature:カラムファミリ
- ワードとその出現回数をKey-Valueとして格納
- stats:カラムファミリ
- prior_prob: 事前確率
- label: 分類ラベル
- s_val: 分類におけるの語彙毎の出現回数の合計
- n_occurence: 分類の出現回数(デバッグ用)
- s_occurence: 分類の総出現回数(デバッグ用)
- n_feature: 分類の語彙数(デバッグ用)
分類に必要な値以外に調査や検証に役立つ値も格納しています
実装編
テーブルの作成
CREATE 'test_model', { NAME => 'feature', VERSIONS => 1, COMPRESSION => 'LZ4', DATA_BLOCK_ENCODING => 'FAST_DIFF', 'IN_MEMORY' =>
'true' },
{ NAME => 'stats', VERSIONS => 1, COMPRESSION => 'LZ4', DATA_BLOCK_ENCODING => 'FAST_DIFF', 'IN_MEMORY' => 'true' },
{ SPLITS => [
'1000000000000000000000000000000000000000',
'2000000000000000000000000000000000000000',
'3000000000000000000000000000000000000000',
'4000000000000000000000000000000000000000',
'5000000000000000000000000000000000000000',
'6000000000000000000000000000000000000000',
'7000000000000000000000000000000000000000',
'8000000000000000000000000000000000000000',
'9000000000000000000000000000000000000000',
'a000000000000000000000000000000000000000',
'b000000000000000000000000000000000000000',
'c000000000000000000000000000000000000000',
'd000000000000000000000000000000000000000',
'e000000000000000000000000000000000000000',
'f000000000000000000000000000000000000000'
] }
- スループットを稼ぎたいのでインメモリ指定しています
- データが分散するように分割ポイントを指定しています
モデルデータの生成
Sparkを使ってデータを変換していきます。
ステップ1 グループIDの生成
文書:分類が1:Nの関係となっているのを、ここでは複数コードの集合を1つのグループとして扱うように変換し、このグループのIDを生成します。*6
val appln_group = appln_ipc.groupBy("appln_id")
.agg(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4)))) as "ipc_sss",
sha1(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4))))) as "group")
DataFrame APIを使用しています。
+--------+--------------+----------------------------------------+
|appln_id|ipc_sss |group |
+--------+--------------+----------------------------------------+
|153620 |A61K,A61P,C12N|947ed80b48cae17d652fdd8fb6f6de2eff130710|
|197020 |H04L,H04W |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|
|286620 |G11C |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|
|455820 |H04L,H04W |fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|
+--------+--------------+----------------------------------------+
この段階で、1文書1行のデータに変換されます。
ステップ2 学習データの生成
val trainingset = appln_group.filter($"appln_id" % lit(5) !== lit(0)
val instances = trainingset.as("ag")
.join(appln_abstr, "appln_id")
.select($"ag.appln_id",
explode(Util.wordcount($"appln_abstract", lit(30))) as "wc",
$"ag.group" as "group",
$"ag.ipc_sss" as "ipc")
.select($"*", $"wc"("_1").as("keyword"), $"wc"("_2").as("n"))
.drop("wc")
}
ここはデータ全体の8割を学習データとし、残りをテストデータにし、実際の学習データを作ります。*7
要約のデータをjoinしつつキーワードを抽出し行を展開しています。
*8
データイメージ
+--------+----------------------------------------+--------------+----------------------------------------------------+---+
|appln_id|group |ipc |keyword |n |
+--------+----------------------------------------+--------------+----------------------------------------------------+---+
|50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|virus |5 |
|50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|cell proliferative disorders |4 |
|50020 |947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|Ras-pathway |2 |
|454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |ground stage |1 |
|454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |process condition |1 |
|454620 |e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |TR1 |1 |
+--------+----------------------------------------+--------------+----------------------------------------------------+---+
ここまでで下ごしらえが完了です。次のステップではHBaseのテーブルの各カラムファミリに流しこむデータを生成していきます。
HDFS上のステージングディレクトリに一通りHFileを出力した後、最後にファイル群をHBaseのデータディレクトリに移動させてバルクロードを完了させます。
ステップ3 HFileの生成
hbase-sparkというライブラリ*9を使用してHFileを生成していきます。ここでは入力としてRDDへの変換をしています。
def createHFile(hbaseContext: HBaseContext, tableName: TableName,
rdd: RDD[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])], stagingDir: String) = {
hbaseContext.bulkLoad[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
rdd,
tableName,
(r) => {
r._2.map { v =>
(new KeyFamilyQualifier(r._1, v._1, v._2), v._3)
}.iterator
},
stagingDir)
}
}
val hbconf = HBaseConfiguration.create()
val hbaseContext = new HBaseContext(sc, hbconf)
val wc_by_group = instances.groupBy($"group", $"keyword").agg(sum("n")).orderBy($"group", $"keyword")
val feature_put_rdd = wc_by_group.rdd.map{ x =>
(Bytes.toBytes(x.getString(0)), Array((Bytes.toBytes("feature"), Bytes.toBytes(x.getString(1)), Bytes.toBytes(x.getLong(2))))) }
createHFile(hbaseContext, TableName.valueOf(tableName), feature_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+----------------------------------------------------+------+
|group |keyword |sum(n)|
+----------------------------------------+----------------------------------------------------+------+
|947ed80b48cae17d652fdd8fb6f6de2eff130710|HSV |1 |
|947ed80b48cae17d652fdd8fb6f6de2eff130710|Methods |1 |
|e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first output signal |3 |
|e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first resistor |6 |
|e3a1862a8f7ce681b57c4c41f711922f3b0bb490|first transistor |6 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|terminal |1 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|traffic |3 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|traffic differentiation |1 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|transfer |1 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|wireless LAN |1 |
+----------------------------------------+----------------------------------------------------+------+
stats:n_featureデータの生成
val n_feature_by_group = wc_by_group.groupBy($"group").agg(count("*")).orderBy($"group")
val n_feature_put_rdd = n_feature_by_group.rdd.map { x =>
(Bytes.toBytes(x.getString(0)), Array(
(Bytes.toBytes("stats"), Bytes.toBytes("n_feature"), Bytes.toBytes(x.getLong(1))))) }
createHFile(hbaseContext, TableName.valueOf(tableName), n_feature_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+--------+
|group |count(1)|
+----------------------------------------+--------+
|947ed80b48cae17d652fdd8fb6f6de2eff130710|17 |
|e3a1862a8f7ce681b57c4c41f711922f3b0bb490|83 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|70 |
+----------------------------------------+--------+
stats:s_valとlabelデータの生成
val s_val_by_group = instances.groupBy($"group", $"ipc").agg(sum("n")).orderBy($"group")
val s_val_put_rdd = s_val_by_group.rdd.map{ x =>
(Bytes.toBytes(x.getString(0)), Array(
(Bytes.toBytes("stats"), Bytes.toBytes("s_val"), Bytes.toBytes(x.getLong(2))),
(Bytes.toBytes("stats"), Bytes.toBytes("label"), Bytes.toBytes(x.getString(1))))) }
createHFile(hbaseContext, TableName.valueOf(tableName), s_val_put_rdd, stagingDirClassificationModel)
出力イメージ
+----------------------------------------+--------------+------+
|group |ipc |sum(n)|
+----------------------------------------+--------------+------+
|947ed80b48cae17d652fdd8fb6f6de2eff130710|A61K,A61P,C12N|29 |
|e3a1862a8f7ce681b57c4c41f711922f3b0bb490|G11C |263 |
|fc38214c2f0b7f1636cc0ee9206d2023537b0dcd|H04L,H04W |216 |
+----------------------------------------+--------------+------+
stats:prior_probのデータの生成
val n_occurence_by_group = trainingset.groupBy($"group").agg(count("*")).orderBy($"group")
n_occurence_by_group.persist(StorageLevel.DISK_ONLY)
val s_group_occurence = trainingset.count().toDouble
val n_occurence_put_rdd = n_occurence_by_group.rdd.map{ x =>
(Bytes.toBytes(x.getString(0)), Array(
(Bytes.toBytes("stats"), Bytes.toBytes("prior_prob"), Bytes.toBytes(x.getLong(1) / s_group_occurence)),
(Bytes.toBytes("stats"), Bytes.toBytes("s_occurence"), Bytes.toBytes(s_group_occurence)),
(Bytes.toBytes("stats"), Bytes.toBytes("n_occurence"), Bytes.toBytes(x.getLong(1))))) }
createHFile(hbaseContext, TableName.valueOf(tableName), n_occurence_put_rdd, stagingDirClassificationModel)
ステージングディレクトリに出力されたHFileをHBaseのデータディレクトリに移動させてバルクロードを完了させます。
テトリス*10のようなイメージです。
val conn = ConnectionFactory.createConnection(hbaseContext.config)
val load = new LoadIncrementalHFiles(hbaseContext.config)
load.doBulkLoad(
new Path(stagingDirClassificationModel),
conn.getAdmin,
new HTable(hbaseContext.config, TableName.valueOf(tableName)),
conn.getRegionLocator(TableName.valueOf(tableName)))
これで完成です。約100万分類のモデルができました。
分類フェーズ
学習フェーズで生成したモデルのテーブルを読み、実際に文書の分類を行うフェーズです。
HBaseの分類モデルのテーブルは1行が1分類となっており、分類対象の文書に含まれるキーワードから、どの分類から来たのかの確率をスコアとして計算し、スコアの高いものを10件ずつ出力しています。
下にコード(抜粋)とポイントを簡単に挙げます。
val appln_group = appln_ipc.groupBy("appln_id").agg(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4)))) as "ipc_sss", sha1(concat_ws(",", sort_array(collect_set(substring($"ipc_class_symbol", 0, 4))))) as "group")
val testset = appln_group.filter($"appln_id" % lit(5) === lit(0)).limit(100000)
testset.persist(StorageLevel.DISK_ONLY)
val instances = testset.as("ag").join(appln_abstr, "appln_id").select($"ag.appln_id", Util.wordcount($"appln_abstract", lit(30)) as "keyword", $"ag.group" as "group", $"ag.ipc_sss" as "ipc", row_number().over(Window.partitionBy().orderBy($"appln_id")) as "row_num")
val input = instances.map{ r =>
(r.getInt(4) / batchSize, (r.getInt(0), r.getMap[String,Int](1), r.getString(2), r.getString(3)))
}.groupByKey().repartition(2048)
val hbconf = HBaseConfiguration.create()
val tableNameObject = TableName.valueOf(tableName)
val table = new HTable(hbconf, tableName)
val conn = ConnectionFactory.createConnection(hbconf)
val regionLocator: RegionLocator = conn.getRegionLocator(tableNameObject)
val startEndKeys: org.apache.hadoop.hbase.util.Pair[Array[Array[Byte]], Array[Array[Byte]]] = regionLocator.getStartEndKeys()
val range = startEndKeys.getFirst().zip(startEndKeys.getSecond())
val result = input.map{ row =>
import scala.math.log
object EntryOrdering extends Ordering[(String, String, Double)] {
def compare(a: (String, String, Double), b: (String, String, Double)) = -(a._3 compare b._3)
}
val topK = scala.collection.mutable.Map[Int, PriorityQueue[(String, String, Double)]]()
for (d <- row._2) {
topK.put(d._1, new PriorityQueue[(String, String, Double)]()(EntryOrdering))
}
scala.util.Random.shuffle(range.toList).foreach { r =>
val hbconf = HBaseConfiguration.create()
val table = new HTable(hbconf, tableName)
val scan = new Scan(r._1, r._2)
scan.setCaching(100)
val scanner = table.getScanner(scan)
for (r: Result <- scanner) {
val classificationId = Bytes.toString(r.getRow())
val classificationStats = r.getFamilyMap(Bytes.toBytes("stats"))
val classificationFeature = r.getFamilyMap(Bytes.toBytes("feature"))
val label = Bytes.toString(classificationStats.getOrElse(Bytes.toBytes("label"), Bytes.toBytes("N/A")))
for (d <- row._2) {
var score = log(Bytes.toDouble(classificationStats.getOrElse(Bytes.toBytes("prior_prob"), Bytes.toBytes(0.0))))
val features = d._2
for (f <- features) {
val keyword = f._1
val n = f._2
val numerator = Bytes.toLong(classificationFeature.getOrElse(Bytes.toBytes(keyword), Bytes.toBytes(0L))).toDouble + 1.0
val denominator = Bytes.toLong(classificationStats.getOrElse(Bytes.toBytes("s_val"), Bytes.toBytes(1L))) + 70000000.0
val wordProb: Double = numerator / denominator
score += n * log(wordProb)
}
val q = topK.get(d._1).get
if (q.size < 10) {
q.enqueue((classificationId, label, score))
} else {
val smallest = q.dequeue()
if (score > smallest._3) {
q.enqueue((classificationId, label, score))
} else {
q.enqueue(smallest)
}
}
}
}
scanner.close()
table.close()
}
val z = row._2.map{ d =>
(d._1, d._2, d._3, d._4, topK.get(d._1).get.dequeueAll)
}
z
}.flatMap(x => x)
result.persist(StorageLevel.DISK_ONLY)
result.map{ r =>
val keywords = r._2.toList.sortBy(-_._2).map{case (a, b) => s"${a}:${b}"}.mkString(",")
val result = r._5.reverse.map(_.productIterator.mkString(":")).mkString("|")
s"${r._1}\t${keywords}\t${r._3}\t${r._4}\t${result}"
}.saveAsTextFile("model_test.tsv")
ポイント
- 分類対象の件数に応じた計算リソースと時間が必要
- テーブルスキャンが分類対象の文書数分必要なので、バッチ化してHBase側の負荷とネットワークIOを抑制
- 上位10分類を出力するため優先度付きキューを使用
- レンジスキャンをシャッフルして実行することでHBaseのブロックキャッシュヒット率を上げることで速度を稼ぐ
*11
実験結果
テストデータのセットから10万件を分類器にかけてみて、出力結果の評価をしました。
文書毎に返された上位10件の分類ラベルに対して、次の6通りを集計しました。
- 最上位で正解ラベルと完全一致する率
- 最上位で正解ラベルの構成要素の集合との共通部分が存在する率
- 5位以内で正解ラベルと完全一致する率
- 5位以内で正解ラベルの構成要素の集合との共通部分が存在する率
- 10位以内で正解ラベルと完全一致する率
- 10位以内で正解ラベルの構成要素の集合との共通部分が存在する率
それぞれ以下の数字でした。
- 24397/100000 0.2440
- 45486/100000 0.4549
- 40696/100000 0.4070
- 66264/100000 0.6626
- 46965/100000 0.4697
- 73903/100000 0.7390
考察と課題
- データの設計と実装上の工夫により比較的規模の大きなデータに対してジョブを完遂することができた
- 分類器の精度としては今ひとつな値だが、約100万の中から10個選んだものとしてみるとランダムよりは良さそう(ラベル付けが目的かマッチングが目的かで異なる解釈)というレベル
- 以下の要素についてもうすこし考える余地がありそう
- 分類の数と粒度
- テキストの分量と質
- 特徴(キーワード)抽出の方法
- 正解ラベルよりもよいラベルを引き当てた可能性についてはどう評価するか
まとめ
- SparkとHBaseを使ってスケーラブルな単純ベイズの文書分類器を実装しました。
- 数千万件の特許要約テキストと分類コードのデータから約百万通りの分類器を学習させてみました。
- 実装面、理論面の両方で更なる改良の余地がありそうです。