Spark を使った実験データ解析
これまでC/C++でCERN ROOTを使って実験データの解析をしてきた筆者が、Sparkで解析をやってみてよく行う操作の例をいくつか挙げてみる
全体的にSQL関数を上手く使えるとROOTでは面倒だった操作がとても楽にできることが分かった。
Sparkの設定については過去の記事を参照
検出器のキャリブレーション
以前の記事でも使ったtest.parquetファイルには
>>> df = spark.read.parquet("hdfs:///test/test.parquet")
>>> df.show(10)
+---+--------------------+-------+
|fID| fTiming|fCharge|
+---+--------------------+-------+
| 1|1.114106081875542E14| 5742|
| 0|1.114106105298723...| 155|
| 0|1.114106112329469...| 1703|
| 1|1.114106112329506...| 958|
| 0|1.114106120337208...| 1818|
| 0|1.114106125184000...| 178|
| 1|1.114106132377145E14| 105|
| 1|1.114106144537507...| 794|
| 0|1.114106417419858E14| 4164|
| 0|1.114106424617431...| 200|
+---+--------------------+-------+
only showing top 10 rows
2台の検出器でイベントが観測された時刻(fTiming)と信号波高(fCharge)のリストが記録されている。
ROOTのTTreeから切り出してきたものだ。fChargeは生のADCの値で整数値なので、まずは0から1の乱数を足してfloatにする。
>>> df = df.withColumn("fCharge", F.col("fCharge") + F.rand())
>>> df.show(10)
+---+--------------------+------------------+
|fID| fTiming| fCharge|
+---+--------------------+------------------+
| 1|1.114106081875542E14| 5742.854963218908|
| 0|1.114106105298723...| 155.9766192554091|
| 0|1.114106112329469...| 1703.609628179982|
| 1|1.114106112329506...| 958.8928912853729|
| 0|1.114106120337208...|1818.0142345787904|
| 0|1.114106125184000...|178.70001054164862|
| 1|1.114106132377145E14|105.33545583166983|
| 1|1.114106144537507...| 794.3959406237444|
| 0|1.114106417419858E14|4164.2946565839875|
| 0|1.114106424617431...| 200.6593168109559|
+---+--------------------+------------------+
only showing top 10 rows
次に、検出器0と1 (fID)のfChargeにそれぞれのキャリブレーションパラメータを適用し、
fCalibrated = p0 + p1 * fCharge
となるようにする。
まずは、キャリブレーションパラメータのCSVを用意する。
id,p0,p1
0,-2.9456,0.6775
1,-2.3264,0.6785
一行目がヘッダーになっているCSVからデータフレームを作るには以下のようにする。
>>> param = spark.read.csv("hdfs:///test/param.csv", header=True, inferSchema=True)
+---+-------+------+
| id| p0| p1|
+---+-------+------+
| 0|-2.9456|0.6775|
| 1|-2.3264|0.6785|
+---+-------+------+
次にdfとparamをjoin()する。このときの条件に (df.fID == param.id) を指定することで、fID=0の行にはid=0の値、fID=1の行にはid=1の値が入る。
>>> df_with_param = df.join(param, df.fID == param.id)
>>> df_with_param.show(10)
+---+--------------------+------------------+---+-------+------+
|fID| fTiming| fCharge| id| p0| p1|
+---+--------------------+------------------+---+-------+------+
| 1|1.114106081875542E14| 5742.546022498016| 1|-2.3264|0.6785|
| 0|1.114106105298723...| 155.7764090911184| 0|-2.9456|0.6775|
| 0|1.114106112329469...|1703.0362556833722| 0|-2.9456|0.6775|
| 1|1.114106112329506...| 958.9431829283574| 1|-2.3264|0.6785|
| 0|1.114106120337208...| 1818.836486345097| 0|-2.9456|0.6775|
| 0|1.114106125184000...|178.91403731560973| 0|-2.9456|0.6775|
| 1|1.114106132377145E14|105.81584950245579| 1|-2.3264|0.6785|
| 1|1.114106144537507...| 794.21126862575| 1|-2.3264|0.6785|
| 0|1.114106417419858E14| 4164.679102345666| 0|-2.9456|0.6775|
| 0|1.114106424617431...|200.06057931639475| 0|-2.9456|0.6775|
+---+--------------------+------------------+---+-------+------+
only showing top 10 rows
あとは、p0 + p1 * fCharge のカラムをwithColumn()で追加すればよい。
>>> df_calibrated = df_with_param.withColumn("fCalibrated", df_with_param.p0 + df_with_param.p1 * df_with_param.fCharge)
>>> df_calibrated.show(10)
+---+--------------------+------------------+---+-------+------+------------------+
|fID| fTiming| fCharge| id| p0| p1| fCalibrated|
+---+--------------------+------------------+---+-------+------+------------------+
| 1|1.114106081875542E14| 5742.546022498016| 1|-2.3264|0.6785|3893.9910762649038|
| 0|1.114106105298723...| 155.7764090911184| 0|-2.9456|0.6775|102.59291715923273|
| 0|1.114106112329469...|1703.0362556833722| 0|-2.9456|0.6775|1150.8614632254846|
| 1|1.114106112329506...| 958.9431829283574| 1|-2.3264|0.6785| 648.3165496168905|
| 0|1.114106120337208...| 1818.836486345097| 0|-2.9456|0.6775|1229.3161194988031|
| 0|1.114106125184000...|178.91403731560973| 0|-2.9456|0.6775| 118.2686602813256|
| 1|1.114106132377145E14|105.81584950245579| 1|-2.3264|0.6785| 69.46965388741624|
| 1|1.114106144537507...| 794.21126862575| 1|-2.3264|0.6785| 536.5459457625714|
| 0|1.114106417419858E14| 4164.679102345666| 0|-2.9456|0.6775| 2818.624491839189|
| 0|1.114106424617431...|200.06057931639475| 0|-2.9456|0.6775|132.59544248685742|
+---+--------------------+------------------+---+-------+------+------------------+
only showing top 10 rows
ちなみに、データフレームの操作は遅延評価なので、今の段階ではshow(10)した最初の10行にしか計算は行われていない。
ファイルに書き込む用に不要な列を除いたデータフレームを定義する。
>>> df_calibrated = df_calibrated.select("fID", "fTiming", "fCharge", "fCalibrated")
>>> df_calibrated.show(10)
+---+--------------------+------------------+------------------+
|fID| fTiming| fCharge| fCalibrated|
+---+--------------------+------------------+------------------+
| 1|1.114106081875542E14| 5742.546022498016|3893.9910762649038|
| 0|1.114106105298723...| 155.7764090911184|102.59291715923273|
| 0|1.114106112329469...|1703.0362556833722|1150.8614632254846|
| 1|1.114106112329506...| 958.9431829283574| 648.3165496168905|
| 0|1.114106120337208...| 1818.836486345097|1229.3161194988031|
| 0|1.114106125184000...|178.91403731560973| 118.2686602813256|
| 1|1.114106132377145E14|105.81584950245579| 69.46965388741624|
| 1|1.114106144537507...| 794.21126862575| 536.5459457625714|
| 0|1.114106417419858E14| 4164.679102345666| 2818.624491839189|
| 0|1.114106424617431...|200.06057931639475|132.59544248685742|
+---+--------------------+------------------+------------------+
only showing top 10 rows
ファイルに保存
df_calibrated.write.parquet("hdfs:///test/calibrated.parquet")
この時点で全ての行に操作が行われる。
イベントビルド
今まで見てきたデータフレームは、一行に一つのヒット情報が入っていた。検出器0と1が両方同時にイベントを観測した場合の相関をとりたい時には、検出器0と1の時間差が一定幅以内だったら同じイベントとみなすという作業が必要となる。これをイベントビルド呼んでいるが、ROOTでこれを行うには面倒なコーディングが必要だった。データフレームではWindowを使って簡単にできるのでやってみる。
まず、Sliding Windowというデータフレームを行ごとにある列の値のレンジでなめていく関数を定義する。
以下のWindowは自分よりfTimingが1000大きい行までを取り込む。
>>> from pyspark.sql import Window
>>> window_spec = Window.orderBy("fTiming").rangeBetween(0,1000)
次に、Window毎にfID, fTiming, fCalibratedのStructureのリストが入った列"hits"を定義する。
>>> df_evb = df.withColumn("hits",F.collect_list(F.struct("fID","fTiming","fCalibrated")).over(window_spec)).select("fTiming","hits")
これで、1000 ns以内のイベントがリストに入った。ちなみに、show() に trancate=Falseを渡すと行が省略なく全て表示される。
>>> df_evb.show(10,truncate=False)
24/08/27 18:34:25 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---------------------+-----------------------------------------------------------------------------------------------+
|fTiming |hits |
+---------------------+-----------------------------------------------------------------------------------------------+
|1.114106081875542E14 |[{1, 1.114106081875542E14, 3893.9910762649038}] |
|1.1141061052987234E14|[{0, 1.1141061052987234E14, 102.59291715923273}] |
|1.1141061123294692E14|[{0, 1.1141061123294692E14, 1150.8614632254846}, {1, 1.1141061123295069E14, 648.3165496168905}]|
|1.1141061123295069E14|[{1, 1.1141061123295069E14, 648.3165496168905}] |
|1.1141061203372088E14|[{0, 1.1141061203372088E14, 1229.3161194988031}] |
|1.1141061251840005E14|[{0, 1.1141061251840005E14, 118.2686602813256}] |
|1.114106132377145E14 |[{1, 1.114106132377145E14, 69.46965388741624}] |
|1.1141061445375077E14|[{1, 1.1141061445375077E14, 536.5459457625714}, {0, 1.1141061445379472E14, 683.0947309364051}] |
|1.1141061445379472E14|[{0, 1.1141061445379472E14, 683.0947309364051}] |
|1.11410614970118E14 |[{1, 1.11410614970118E14, 1500.4333400659127}] |
+---------------------+-----------------------------------------------------------------------------------------------+
only showing top 10 rows
ただし、これだと全ての行が保持されるので、同じイベントが複数回アレイ内に現れる。そこで、一つ前のrowのfTimingを詰めた列を定義して、1000 ns以内に起こったものは排除してしまおう。
一つ前の行を詰めるにはlag()関数を使う。引数はカラム、オフセット、Nullの場合のデフォルト値。
>>> window = Window.orderBy("fTiming")
>>> df_prev = df_evb.withColumn("prev_timing", F.lag("fTiming",1,0).over(window))
+--------------------+--------------------+--------------------+
| fTiming| hits| prev_timing|
+--------------------+--------------------+--------------------+
|1.114106081875542E14|[{1, 1.1141060818...| 0.0|
|1.114106105298723...|[{0, 1.1141061052...|1.114106081875542E14|
|1.114106112329469...|[{0, 1.1141061123...|1.114106105298723...|
|1.114106112329506...|[{1, 1.1141061123...|1.114106112329469...|
|1.114106120337208...|[{0, 1.1141061203...|1.114106112329506...|
|1.114106125184000...|[{0, 1.1141061251...|1.114106120337208...|
|1.114106132377145E14|[{1, 1.1141061323...|1.114106125184000...|
|1.114106144537507...|[{1, 1.1141061445...|1.114106132377145E14|
|1.114106144537947...|[{0, 1.1141061445...|1.114106144537507...|
| 1.11410614970118E14|[{1, 1.1141061497...|1.114106144537947...|
+--------------------+--------------------+--------------------+
only showing top 10 rows
prev_timingカラムに一つ前の行のfTimingが入り、最初の行はデフォルト値として渡した0が入っている。
あとは時間差でフィルタリングすればOK
>>> df_filtered = df_prev.filter(F.expr("fTiming - prev_timing > 1000"))
>>> df_filtered.select("hits").show(10, truncate=False)
+-----------------------------------------------------------------------------------------------+
|hits |
+-----------------------------------------------------------------------------------------------+
|[{1, 1.114106081875542E14, 3893.9910762649038}] |
|[{0, 1.1141061052987234E14, 102.59291715923273}] |
|[{0, 1.1141061123294692E14, 1150.8614632254846}, {1, 1.1141061123295069E14, 648.3165496168905}]|
|[{0, 1.1141061203372088E14, 1229.3161194988031}] |
|[{0, 1.1141061251840005E14, 118.2686602813256}] |
|[{1, 1.114106132377145E14, 69.46965388741624}] |
|[{1, 1.1141061445375077E14, 536.5459457625714}, {0, 1.1141061445379472E14, 683.0947309364051}] |
|[{1, 1.11410614970118E14, 1500.4333400659127}] |
|[{1, 1.114106157838E14, 4728.992252793331}] |
|[{0, 1.1141061660129662E14, 1098.1495110454523}] |
+-----------------------------------------------------------------------------------------------+
only showing top 10 rows
これでイベントの重複はなくなった。
今回は検出器0と1のcharge情報の相関を見たいので、charge0とcharge1のカラムを作る。
hitsカラムはアレイが入っているので、そのアレイに対して条件を掛ける場合、Spark SQL の filter(array, lambda)関数を使う。expr()内で使うとアレイに対してフィルタ操作ができる。Spark DataFrame API の filter() 関数とは違うので注意
>>> df_charge = df_filtered.withColumn("charge0",F.expr("filter(hits, x -> x.fID=0)"))
>>> df_charge = df_charge.withColumn("charge1",F.expr("filter(hits, x -> x.fID=1)"))
charge0にはStructureではなくfCalibratedの値だけ入っていた方が使い易いのでtransform()で置き換える。
>>> df_charge = df_charge.withColumn("charge0",F.expr("transform(charge0, x -> x.fCalibrated)"))
>>> df_charge = df_charge.withColumn("charge1",F.expr("transform(charge1, x -> x.fCalibrated)"))
ヒット数が2以上のイベントを選んで必要なカラムだけ表示
>>> df_charge.filter(F.expr("size(hits)>1")).select("fTiming","charge0","charge1").show(10)
+--------------------+--------------------+--------------------+
| fTiming| charge0| charge1|
+--------------------+--------------------+--------------------+
|1.114106112329469...|[1150.8614632254846]| [648.3165496168905]|
|1.114106144537507...| [683.0947309364051]| [536.5459457625714]|
|1.114106170728324...|[113.68025553823864]| [85.078082202156]|
| 1.1141061750287E14|[1124.0347392818417]| [599.6530719662841]|
|1.114106194005185...| [858.0833248841577]|[1497.8838829750975]|
|1.114106229145187...| [3502.54560390515]| [4360.022655290631]|
|1.114106291164085...|[1211.1048336276579]|[1686.1653273619022]|
|1.114106433824964...|[3720.4027528439847]|[3881.0775926168853]|
|1.114106434624439...|[1034.1065929671859]| [416.1863023988994]|
| 1.11410645384022E14| [5231.815924090422]|[3982.0608745084955]|
+--------------------+--------------------+--------------------+
only showing top 10 rows
これで検出器0と1のchargeの相関がリストされた。
デコードされた生データのマッピング
以前の記事で説明したように、DAQの生データをデコードした段階ではデータは検出器ごとではなく、信号をデジタイズしたモジュールが所属するクレート番号、モジュール番号、モジュール上のチャンネル番号などでタグ付けされている。
RIDFフォーマットの場合、信号を特定するためにはdevice, focal, detector, geometry, channelの5つのid情報が必要である。これを検出器の種類ごとに通し番号を振りなおす操作をマッピングと呼んでいる。この操作も、join()を使えば簡単に実装できた。
例えばデコードされた状態のデータがこちら
(ちなみに大きなファイルを開くときはdriver memoryがデフォルトだと足りなくなることがあるので pyspark --driver-memory 8g などとオプションを付けて起動するとよい。)
>>> df = spark.read.parquet("hdfs:///test/calib1029.parquet")
>>> df.show(10, truncate=100)
+--------+---------+--------------+----------------------------------------------------------------------------------------------------+
|event_id|runnumber| ts| segdata|
+--------+---------+--------------+----------------------------------------------------------------------------------------------------+
| 0| 1029|20507102988732|[{0, 63, 24, 1, [{7, 0, 6190, 0}, {7, 0, 6736, 1}]}, {0, 63, 24, 1, [{8, 0, 6046, 0}, {8, 0, 6598...|
| 1| 1029|20507103048025|[{0, 63, 24, 1, [{7, 0, 6243, 0}, {7, 0, 6792, 1}]}, {0, 63, 24, 1, [{8, 0, 6278, 0}, {8, 0, 6831...|
| 2| 1029|20507103113109|[{0, 63, 24, 1, [{7, 0, 6256, 0}, {7, 0, 6807, 1}]}, {0, 63, 24, 1, [{8, 0, 6205, 0}, {8, 0, 6761...|
| 3| 1029|20507103197795|[{0, 63, 24, 1, [{7, 0, 6106, 0}, {7, 0, 6657, 1}]}, {0, 63, 24, 1, [{8, 0, 6200, 0}, {8, 0, 6757...|
| 4| 1029|20507103292457|[{0, 63, 24, 1, [{7, 0, 6164, 0}, {7, 0, 6715, 1}]}, {0, 63, 24, 1, [{8, 0, 6136, 0}, {8, 0, 6690...|
| 5| 1029|20507103358210|[{0, 63, 24, 1, [{7, 0, 6036, 0}, {7, 0, 6590, 1}]}, {0, 63, 24, 1, [{8, 0, 6176, 0}, {8, 0, 6737...|
| 6| 1029|20507103435349|[{0, 63, 24, 1, [{7, 0, 6030, 0}, {7, 0, 6588, 1}]}, {0, 63, 24, 1, [{8, 0, 6071, 0}, {8, 0, 6636...|
| 7| 1029|20507103483913|[{0, 63, 24, 1, [{7, 0, 6223, 0}, {7, 0, 6778, 1}]}, {0, 63, 24, 1, [{8, 0, 6200, 0}, {8, 0, 6760...|
| 8| 1029|20507103566635|[{0, 63, 24, 1, [{7, 0, 6027, 0}, {7, 0, 6580, 1}]}, {0, 63, 24, 1, [{8, 0, 6151, 0}, {8, 0, 6709...|
| 9| 1029|20507103626476|[{0, 63, 24, 1, [{7, 0, 6215, 0}, {7, 0, 6768, 1}]}, {0, 63, 24, 1, [{8, 0, 6263, 0}, {8, 0, 6820...|
+--------+---------+--------------+----------------------------------------------------------------------------------------------------+
only showing top 10 rows
segdataカラムはセグメント(1VMEモジュールに相当)データのアレイで、アレイの中身はdevice, focal, module, detector, hits のStructureである。hitsはさらにStructureのアレイで、中身はgeometry, channel, value, edge の4つである。これを、検出器ごとにid=Xのvalue=YYY、といったように抽出したい。
マッピング情報のCSVを用意
ここではまず、sr91 という名前の検出器について、X方向の読み出しチャンネルそれぞれがどのVMEモジュールのどのチャンネルに入っていたかの情報をCSVとして用意する。
id,dev,fp,det,geo,ch
0, 11,9,31,0,63
1, 11,9,31,0,62
2, 11,9,31,0,61
3, 11,9,31,0,60
4, 11,9,31,0,47
5, 11,9,31,0,46
6, 11,9,31,0,45
7, 11,9,31,0,44
8, 11,9,31,0,59
9, 11,9,31,0,58
...(以下略)
アレイをexplodeする
データがアレイに入ったままだと扱いづらいので、まずSpark SQLのexplode()関数を使って行に展開する。
>>> exp_df = df.withColumn("ex_segs", F.explode("segdata"))
>>> exp_df.show(10)
+--------+---------+--------------+--------------------+--------------------+
|event_id|runnumber| ts| segdata| ex_segs|
+--------+---------+--------------+--------------------+--------------------+
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{0, 63, 24, 1, [{...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{0, 63, 24, 1, [{...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{0, 63, 8, 60, [{...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 63, 25, 31, []}|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 63, 36, 63, ...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{0, 3, 8, 60, [{0...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 3, 25, 31, [...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 7, 24, 31, [...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 8, 24, 31, [...|
| 0| 1029|20507102988732|[{0, 63, 24, 1, [...|{11, 9, 24, 31, [...|
+--------+---------+--------------+--------------------+--------------------+
少々分かりづらいが、segdataのアレイが外れて複数行に展開されているのが分かる。(ex_segs列)
続いてselect()でStructureを列に展開する。
>>> exp_df = exp_df.select("event_id", "ex_segs.dev", "ex_segs.fp", "ex_segs.det", "ex_segs.hits")
>>> exp_df.show(10)
+--------+---+---+---+--------------------+
|event_id|dev| fp|det| hits|
+--------+---+---+---+--------------------+
| 0| 0| 63| 1|[{7, 0, 6190, 0},...|
| 0| 0| 63| 1|[{8, 0, 6046, 0},...|
| 0| 0| 63| 60| [{0, 0, 188, -1}]|
| 0| 11| 63| 31| []|
| 0| 11| 63| 63|[{0, 0, 1, -1}, {...|
| 0| 0| 3| 60| [{0, 0, 0, -1}]|
| 0| 11| 3| 31|[{0, 14, 109354, ...|
| 0| 11| 7| 31|[{1, 12, 102193, ...|
| 0| 11| 8| 31|[{1, 0, 102262, 0...|
| 0| 11| 9| 31|[{0, 1, 31266, 0}...|
+--------+---+---+---+--------------------+
only showing top 10 rows
hitsアレイについても同様に展開
>>> exp_df = exp_df.withColumn("ex_hits", F.explode("hits"))
>>> exp_df = exp_df.select("event_id", "dev", "fp", "det", "ex_hits.geo", "ex_hits.ch", "ex_hits.value")
>>> exp_df.show(10)
+--------+---+---+---+---+---+-----+
|event_id|dev| fp|det|geo| ch|value|
+--------+---+---+---+---+---+-----+
| 0| 0| 63| 1| 7| 0| 6190|
| 0| 0| 63| 1| 7| 0| 6736|
| 0| 0| 63| 1| 8| 0| 6046|
| 0| 0| 63| 1| 8| 0| 6598|
| 0| 0| 63| 60| 0| 0| 188|
| 0| 11| 63| 63| 0| 0| 1|
| 0| 11| 63| 63| 0| 1| 1|
| 0| 11| 63| 63| 0| 2|18838|
| 0| 11| 63| 63| 0| 3| 1884|
| 0| 11| 63| 63| 0| 4| 0|
+--------+---+---+---+---+---+-----+
only showing top 10 rows
これで一行一ヒットに展開された。後でevent_idでアレイに戻せるようにevent_id列は残しておく。
join()
join()は二つのデータフレームを条件(on
引数)でくっつける操作である。CSVファイルからデータフレームを作って各idが全て一致するという条件でjoinする。
>>> map_df = spark.read.csv("hdfs:///test/sr91_x.csv", header=True, inferSchema=True)
>>> join_df = exp_df.join(map_df, (exp_df.dev == map_df.dev) & (exp_df.fp == map_df.fp) & (exp_df.det == map_df.det) & (exp_df.geo == map_df.geo) & (exp_df.ch == map_df.ch), "i
nner")
>>> join_df.show(10)
+--------+---+---+---+---+---+-----+----+----+---+---+---+---+
|event_id|dev| fp|det|geo| ch|value| id| dev| fp|det|geo| ch|
+--------+---+---+---+---+---+-----+----+----+---+---+---+---+
| 1| 11| 9| 31| 0| 71|31029|52.0|11.0| 9| 31| 0| 71|
| 1| 11| 9| 31| 0| 84|31018|51.0|11.0| 9| 31| 0| 84|
| 1| 11| 9| 31| 0| 71|31311|52.0|11.0| 9| 31| 0| 71|
| 1| 11| 9| 31| 0| 85|31084|50.0|11.0| 9| 31| 0| 85|
| 1| 11| 9| 31| 0| 84|31332|51.0|11.0| 9| 31| 0| 84|
| 1| 11| 9| 31| 0| 85|31232|50.0|11.0| 9| 31| 0| 85|
| 2| 11| 9| 31| 0| 72|30983|47.0|11.0| 9| 31| 0| 72|
| 2| 11| 9| 31| 0| 87|31044|48.0|11.0| 9| 31| 0| 87|
| 2| 11| 9| 31| 0| 73|30982|46.0|11.0| 9| 31| 0| 73|
| 2| 11| 9| 31| 0| 87|31210|48.0|11.0| 9| 31| 0| 87|
+--------+---+---+---+---+---+-----+----+----+---+---+---+---+
only showing top 10 rows
valueカラムまでが元のデータフレームの列でそれより右がjoinされたCSVの列。
"inner"は両方のデータフレームに存在する行のみ出力されるオプション。今回はsr91検出器のXだけのデータフレームを作るので"inner"にする。ちなみに"left"だと左のデータフレームの行は全て残り、"outer"だと両方の行が全て残る。
必要な列をセレクト
>>> join_df = join_df.select("event_id","id","value")
>>> join_df.show(10)
+--------+----+-----+
|event_id| id|value|
+--------+----+-----+
| 1|52.0|31029|
| 1|51.0|31018|
| 1|52.0|31311|
| 1|50.0|31084|
| 1|51.0|31332|
| 1|50.0|31232|
| 2|47.0|30983|
| 2|48.0|31044|
| 2|46.0|30982|
| 2|48.0|31210|
+--------+----+-----+
only showing top 10 rows
CSVから読み込んだidカラムがfloatになっているのでintegerにキャスト
>>> join_df = join_df.withColumn("id", F.col("id").cast("integer"))
>>> join_df.show(10)
+--------+---+-----+
|event_id| id|value|
+--------+---+-----+
| 1| 52|31029|
| 1| 51|31018|
| 1| 52|31311|
| 1| 50|31084|
| 1| 51|31332|
| 1| 50|31232|
| 2| 47|30983|
| 2| 48|31044|
| 2| 46|30982|
| 2| 48|31210|
+--------+---+-----+
only showing top 10 rows
最後に、必要であればヒットごとにexplodeしたデータフレームをevent_idごとにまとめなおす。これにはgroupBy()関数(DataFrame API)とcollect_list()関数(Spark SQL)を使う。groupByは引数で渡したカラムの値が同じもの同士をグルーピングして、何らかの操作を行うために使われる。返り値はDataFrameGroupByオブジェクトなので、続けて何らかの関数を呼ぶ。今回はaggrigation(collect_list())でグループごとにアレイに詰める。
>>> event_df = join_df.groupBy("event_id").agg(F.collect_list("id").alias("sr91x_id"), F.collect_list("value").alias("sr91x_value"))
>>> event_df.show(10)
+--------+--------------------+--------------------+
|event_id| sr91x_id| sr91x_value|
+--------+--------------------+--------------------+
| 21|[47, 49, 46, 48, ...|[31121, 31193, 31...|
| 55|[55, 56, 54, 56, ...|[30965, 30986, 31...|
| 59|[44, 43, 44, 42, ...|[31097, 31059, 31...|
| 176|[53, 51, 52, 50, ...|[31025, 30907, 30...|
| 182|[55, 56, 54, 56, ...|[31125, 31161, 31...|
| 189|[31, 31, 34, 33, ...|[30963, 31111, 30...|
| 215|[60, 47, 49, 60, ...|[127593, 31242, 3...|
| 218|[55, 54, 53, 52, ...|[30996, 30959, 30...|
| 229|[55, 56, 54, 57, ...|[31004, 31009, 31...|
| 269|[55, 47, 57, 54, ...|[31021, 96226, 31...|
+--------+--------------------+--------------------+
only showing top 10 rows
これで、イベント毎にsr91xのヒットのアレイが作られた。
ちなみに分散処理で行の順番が崩れるので、並び変えたければorderBy()を使う。
>>> event_df.orderBy("event_id").show(10)
+--------+--------------------+--------------------+
|event_id| sr91x_id| sr91x_value|
+--------+--------------------+--------------------+
| 1|[52, 51, 52, 50, ...|[31029, 31018, 31...|
| 2|[47, 48, 46, 48, ...|[30983, 31044, 30...|
| 3|[47, 49, 46, 48, ...|[31330, 31285, 31...|
| 7|[55, 54, 53, 52, ...|[31013, 30949, 30...|
| 9|[54, 53, 52, 54, ...|[30932, 30901, 30...|
| 10|[55, 56, 54, 56, ...|[31152, 31211, 31...|
| 11|[59, 58, 57, 59, ...|[30914, 30884, 30...|
| 15|[47, 49, 47, 48, ...|[30957, 30987, 31...|
| 17|[47, 50, 47, 49, ...|[31271, 31233, 31...|
| 21|[47, 49, 46, 48, ...|[31121, 31193, 31...|
+--------+--------------------+--------------------+
また、もとのデータフレームから"event_id"をセレクトして"left" join()すれば空の行も復活する。
>>> join_df = df.select("event_id").join(event_df, "event_id", "left").orderBy("event_id")
>>> join_df.show(10)
+--------+--------------------+--------------------+
|event_id| sr91x_id| sr91x_value|
+--------+--------------------+--------------------+
| 0| NULL| NULL|
| 1|[52, 51, 52, 50, ...|[31029, 31018, 31...|
| 2|[47, 48, 46, 48, ...|[30983, 31044, 30...|
| 3|[47, 49, 46, 48, ...|[31330, 31285, 31...|
| 4| NULL| NULL|
| 5| NULL| NULL|
| 6| NULL| NULL|
| 7|[55, 54, 53, 52, ...|[31013, 30949, 30...|
| 8| NULL| NULL|
| 9|[54, 53, 52, 54, ...|[30932, 30901, 30...|
+--------+--------------------+--------------------+
only showing top 10 rows
最終的なマッピングのコードがこちら
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
# Initialize Spark session
spark = SparkSession.builder \
.appName("testApp") \
.config("spark.driver.memory", "8g") \
.config("spark.executor.memory", "10g") \
.getOrCreate()
# Read decoded rawdata file
df = spark.read.parquet("hdfs:///test/calib1029.parquet")
# Explode arrays
exp_df = df.withColumn("ex_segs", F.explode("segdata"))
exp_df = exp_df.select("event_id", "ex_segs.dev", "ex_segs.fp", "ex_segs.det", "ex_segs.hits")
exp_df = exp_df.withColumn("ex_hits", F.explode("hits"))
exp_df = exp_df.select("event_id", "dev", "fp", "det", "ex_hits.geo", "ex_hits.ch", "ex_hits.value")
# Join with mapping DataFrame
map_df = spark.read.csv("hdfs:///test/sr91_x.csv", header=True, inferSchema=True)
join_df = exp_df.join(map_df,
(exp_df.dev == map_df.dev) & (exp_df.fp == map_df.fp) & \
(exp_df.det == map_df.det) & (exp_df.geo == map_df.geo) & \
(exp_df.ch == map_df.ch),
"inner")
# Select necessary columns
join_df = join_df.select("event_id","id","value")
join_df = join_df.withColumn("id", F.col("id").cast("integer"))
# Group by evnet_id an make arrays
event_df = join_df.groupBy("event_id").agg(
F.collect_list("id").alias("sr91x_id"),
F.collect_list("value").alias("sr91x_value"))
# Reorder and save result to file
join_df = df.select("event_id").join(event_df, "event_id", "left").orderBy("event_id")
join_df.write.mode("overwrite").parquet("hdfs:///test/calib1029_sr91x.parquet")
ジョブをクラスタに投げてみる。
spark-submit --master spark://hostname:7077 test2.py
executor memory を10GBにしたので、対応しているworkerの計48Coreが使われている。
おわりに
これまでCERN ROOTではC++の面倒なコーティングが必要だった操作がかなり簡潔に記述できるようになったと思う。SQL関数の知識がないと始めはとっつきにくいが慣れれば基本的にはSparkでほとんどの解析が完結しそうに思えてきた。
Discussion