👌

前処理大全をPysparkで試みる(8章)

2022/08/31に公開

はじめに

こちらの記事のつづき
https://zenn.dev/tjjj/articles/1fc22bf9fe3160

前提

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

8章:数値型

8-1:数値型への変換

  • Sparkではこちら[1]の型の種類がある。
  • castの引数にはdatatypeでの指定だけでなく、文字列の指定も可能である。datatypeと文字列のマッピングはこちら[2]である。
from pyspark.sql.types import *

df = spark.createDataFrame([(40000/3, 40000/3, 40000/3, 40000/3, 40000/3, 40000/3, 40000/3)], [
                           'ByteType', 'ShortType', 'IntegerType', 'LongType', 'FloatType', 'DoubleType', 'DecimalType'])

df.show()
df.printSchema()
+------------------+------------------+------------------+------------------+------------------+------------------+------------------+
|          ByteType|         ShortType|       IntegerType|          LongType|         FloatType|        DoubleType|       DecimalType|
+------------------+------------------+------------------+------------------+------------------+------------------+------------------+
|13333.333333333334|13333.333333333334|13333.333333333334|13333.333333333334|13333.333333333334|13333.333333333334|13333.333333333334|
+------------------+------------------+------------------+------------------+------------------+------------------+------------------+

root
 |-- ByteType: double (nullable = true)
 |-- ShortType: double (nullable = true)
 |-- IntegerType: double (nullable = true)
 |-- LongType: double (nullable = true)
 |-- FloatType: double (nullable = true)
 |-- DoubleType: double (nullable = true)
 |-- DecimalType: double (nullable = true)
 
 
df = df.withColumn('ByteType', F.col('ByteType').cast(ByteType())).withColumn('ShortType', F.col('ShortType').cast(ShortType())).withColumn('IntegerType', F.col('IntegerType').cast(IntegerType())).withColumn('LongType', F.col(
    'LongType').cast(LongType())).withColumn('FloatType', F.col('FloatType').cast(FloatType())).withColumn('DoubleType', F.col('DoubleType').cast(DoubleType())).withColumn('DecimalType', F.col('DecimalType').cast(DecimalType()))

df.show()
df.printSchema()

+--------+---------+-----------+--------+---------+------------------+-----------+
|ByteType|ShortType|IntegerType|LongType|FloatType|        DoubleType|DecimalType|
+--------+---------+-----------+--------+---------+------------------+-----------+
|      21|    13333|      13333|   13333|13333.333|13333.333333333334|      13333|
+--------+---------+-----------+--------+---------+------------------+-----------+

root
 |-- ByteType: byte (nullable = true)
 |-- ShortType: short (nullable = true)
 |-- IntegerType: integer (nullable = true)
 |-- LongType: long (nullable = true)
 |-- FloatType: float (nullable = true)
 |-- DoubleType: double (nullable = true)
 |-- DecimalType: decimal(10,0) (nullable = true)

8-2:対数化

  • Sparkのlog関数[3]の第一引数は浮動小数点(float)で指定する必要がある点[4]は注意である。整数で指定した場合にはエラーとなる。
df_8_2 = df_reserve.withColumn('total_price_log', F.log(10.0, F.col('total_price')/1000 + 1))

8-3:カテゴリ化

  • floorはこちらのRoundingで言う、Round downである。なお、Round upはceilである。
df_8_3 = df_customer.withColumn('age_rank', F.floor(F.col('age')/10) * 10)

8-4:正規化

  • PySparkにも正規化の関数standardscalerが用意されているため、これを用いる。
  • standardscalerでは正規化対象の特徴量をベクトル化したものを引数と持つので、その対応を行う。(これらの処理はついては、こちらを参照)
from pyspark.ml.feature import VectorAssembler, StandardScaler

vecAssembler = VectorAssembler(
    inputCols=['people_num', 'total_price'], outputCol='features')
df_reserve_pre = vecAssembler.transform(df_reserve)

scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures",
                        withStd=True, withMean=True)

scalerModel = scaler.fit(df_reserve_pre)

df_reserve_scaled = scalerModel.transform(df_reserve_pre)
df_reserve_scaled.show(3, truncate=False)

# 一番右側のカラムにscaledFeatures
+----------+--------+-----------+-------------------+------------+------------+-------------+----------+-----------+-------------+-----------------------------------------+
|reserve_id|hotel_id|customer_id|reserve_datetime   |checkin_date|checkin_time|checkout_date|people_num|total_price|features     |scaledFeatures                           |
+----------+--------+-----------+-------------------+------------+------------+-------------+----------+-----------+-------------+-----------------------------------------+
|r1        |h_75    |c_1        |2016-03-06 13:09:42|2016-03-26  |10:00:00    |2016-03-29   |4         |97200      |[4.0,97200.0]|[1.3005476974271204,-0.0531873782588139] |
|r2        |h_219   |c_1        |2016-07-16 23:39:55|2016-07-20  |11:30:00    |2016-07-21   |2         |20600      |[2.0,20600.0]|[-0.4836930585324679,-0.7477295187903918]|
|r3        |h_179   |c_1        |2016-09-24 10:03:17|2016-10-19  |09:00:00    |2016-10-22   |2         |33600      |[2.0,33600.0]|[-0.4836930585324679,-0.6298568317550066]|
+----------+--------+-----------+-------------------+------------+------------+-------------+----------+-----------+-------------+-----------------------------------------+
only showing top 3 rows

8-5:外れ値の除去

  • 各統計値をスカラー値として扱うために、先に算出する。
  • 標準偏差については、こちらを参照。
stddev_pop_total_price = df_reserve.select(F.stddev_pop('total_price')).first()[0]
mean_total_price = df_reserve.select(F.mean('total_price')).first()[0]

df_8_5 = df_reserve.where(F.abs(df_reserve.total_price - mean_total_price) <= (stddev_pop_total_price * 3))

8-6:主成分分析

  • PySparkにもPCAクラスがあるため、Pandasとほぼ同様。
from pyspark.ml.feature import VectorAssembler, PCA

vecAssembler = VectorAssembler(
    inputCols=['length', 'thickness'], outputCol='features')
df_production_pre = vecAssembler.transform(df_production)

pca = PCA(k=2, inputCol="features", outputCol="pcaFeatures")
model = pca.fit(df_production_pre)

print(f'寄与率:{model.explainedVariance}')
print(f'累積寄与率:{model.explainedVariance.values.sum()}')

result = model.transform(df_production_pre)
result.show(3, truncate=False)

寄与率:[0.9789779361436818,0.021022063856318253]
累積寄与率:1.0
+----+------------------+------------------+---------+---------------------------------------+-----------------------------------------+
|type|length            |thickness         |fault_flg|features                               |pcaFeatures                              |
+----+------------------+------------------+---------+---------------------------------------+-----------------------------------------+
|E   |274.0273827080609 |40.24113135955541 |false    |[274.0273827080609,40.24113135955541]  |[-276.6297044755196,13.651436873595388]  |
|D   |86.31926860506081 |16.906714630016268|false    |[86.31926860506081,16.906714630016268] |[-87.54662953719902,8.511215469595069]   |
|E   |123.94038830419984|1.0184619943950775|false    |[123.94038830419984,1.0184619943950775]|[-123.46188900724863,-10.927903760176704]|
+----+------------------+------------------+---------+---------------------------------------+-----------------------------------------+
only showing top 3 rows

8-7:数値の補完

8-7-1:欠損レコードの削除

  • 読み込み対象のファイルにおけるnull値がNoneの文字列となっているので読み込み時に当値をnullとして処理するように設定する。
  • where条件内はクエリ表記で'thickness is not NULL'とすることも可能である。
df_production_missn = spark.read.option('header', True).option(
    "inferSchema", True).option('nullValue', 'None').csv('../../../data/production_missing_num.csv')

df_8_7_1 = df_production_missn.where(F.col('thickness').isNotNull())

8-7-2:定数補完

  • DataFrame.fillna()とDataFrameNaFunctions.fill()はお互いのエイリアスである。
df_8_7_2_1 = df_production_missn.fillna(1)
df_8_7_2_2 = df_production_missn.na.fill(1)
df_8_7_2_1.sameSemantics(df_8_7_2_2)
True

8-7-3:平均値補完

  • 1点目は平均値を自ら集計して、補完をしているものであり、2点目はImputerクラスを用いたものである。2点目の方が汎用性含めて良さそうに思える。
パターン1
mean_thickness = df_production_missn.select(F.mean('thickness')).first()[0]
df_8_7_3_1 = df_production_missn.na.fill(mean_thickness)
パターン2
from pyspark.ml.feature import Imputer

imputer = Imputer(inputCols=["thickness"], outputCols=["out_thickness"])
model = imputer.fit(df_production_missn)

df_8_7_3_2 = model.transform(df_production_missn)

8-7-4:PMM

  • PythonのfancyimputeライブラリのMICEクラスのようなものが用意されていないので、SparkDataframeをtoPandasなりで変換して、MICEクラスを適用するのがよいと思われる。
脚注
  1. https://spark.apache.org/docs/latest/sql-ref-datatypes.html ↩︎

  2. https://github.com/apache/spark/blob/1b6cdf1040645486ae9b5cbb0247d8869f4f259f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala#L2665 ↩︎

  3. https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.log.html ↩︎

  4. https://spark.apache.org/docs/1.6.2/api/java/org/apache/spark/sql/functions.html#log(double, org.apache.spark.sql.Column) ↩︎

Discussion