Open30

pysparkチートシート

antyuntyunantyuntyun

ピリオドが入っている列を扱うのを諦めない

列名をバッククォートで囲う.

# spdf: spark dataframe
columns = spdf[99:999]
columns_fixed = ['`' + column + '`' for column in columns ]
spdf_tmp = spdf.select(columns_fixed)
antyuntyunantyuntyun

諦めて素直に変換するときもある.

new_cols=(column.replace('.', '_') for column in df.columns)
df2 = df.toDF(*new_cols)
antyuntyunantyuntyun

列数多いデータのshow()でげんなりしない

pandasで表示してあげる

def printDf(sprkDF): 
    # format Spark Dataframe like pandas dataframe
    # sparkのdataframeをpandasのdataframeのように整形して出力する
    newdf = sprkDF.toPandas()
    from IPython.core.display import display, HTML
    return HTML(newdf.to_html())

参考: https://www.creationline.com/lab/21695

antyuntyunantyuntyun

複数列をまとめてcast

df2 = df.select(*(F.col(c).cast("long").alias(c) for c in df.columns))
antyuntyunantyuntyun

sparkでrow_number()

import pyspark.sql.functions as F
from pyspark.sql import Window

# spdf: spark dataframe
spdf = spdf.withColumn('new_column',F.row_number().over(Window.partitionBy('partition_column').orderBy('order_by_column')))
antyuntyunantyuntyun

desc

.withColumn('new_column',F.row_number().over(Window.partitionBy('partition_column').orderBy(F.col('order_by_column').desc())))
antyuntyunantyuntyun

sparkでtimestampを扱う

objectで読み込まれているものをtimestampにして上書き

import pyspark.sql.functions as F

spdf = spdf.withColumn('time_column', F.to_timestamp(F.col('time_column')))

指定期間で絞る

import pyspark.sql.functions as F

start = '2020-11-01 00:00:00'
end = '2021-02-01 00:00:00'

# 複数条件ある時に時々変な不具合防ぐために条件ごとに括弧で括るの推奨
spdf = spdf.filter( (F.col('Createdat') >= start) & (F.col('Createdat') <= end) )
antyuntyunantyuntyun

時間間隔を取りたいとき

longにしてあげて計算する

df = '201912'
df.withColumn('timestamp_interval'+dt,\
                    F.to_timestamp(F.lit(dt), 'yyyyMMdd').cast(T.LongType()) \
                    - F.col('create_time').cast(T.LongType()))\
        .withColumn('year_interval_from_join_to'+dt\
                    ,F.ceil(F.col('timestamp_interval_from_join_to_'+dt)/(365*24*3600)))\
        .groupBy('year_interval_from_join_to'+dt)\
``
antyuntyunantyuntyun

UTCになってしまうのをJSTになるように調整

spdf = spdf\
    .withColumn('time_column', F.to_timestamp(F.col('time_column')))\
    .withColumn('time_column', F.col('time_column') + F.expr("INTERVAL 9 HOURS"))\
antyuntyunantyuntyun

文字列化

mだけ大文字になるの注意

spdf= spdf\
    .withColumn(time_column', F.date_format(F.col('time_column'), 'yyyyMM'))
antyuntyunantyuntyun
        .withColumn('t_cleaned', F.from_unixtime(F.col('t'), 'yyyy-MM-dd HH:mm:ss') )
        .withColumn('t_cleaned', F.to_timestamp(F.col('t_cleaned')) )
        .withColumn('t_cleaned', F.from_utc_timestamp(F.col('t_cleaned'),"Asia/Tokyo") )
antyuntyunantyuntyun

集計値をサクッと出す

column_name = 'Createdat'
spdf.select(column_name).agg(\
    F.min(F.col(column_name))\
    , F.max(F.col(column_name))\
    , F.avg(F.col(column_name))\
    , F.sum(F.col(column_name))\
    ).show()
antyuntyunantyuntyun

timestampのとき

    .agg(\
          F.min(F.col(agg_column_name)).alias(f'max_{agg_column_name}')\
        , F.max(F.col(agg_column_name)).alias(f'min_{agg_column_name}')\
        , F.to_timestamp(F.from_unixtime(F.avg(agg_column_name))).alias(f'avg_{agg_column_name}') \
    )
antyuntyunantyuntyun

前後間隔をWindow関数で計算したい

行動ログ等でインターバルを見たいときに使う。
window関数とlag()で前行を引っ張て来て、差分を計算し色々な単位に変換。

import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import Window

column_prefix = "prev_Timestamp"
time_column_name = "Timestamp"
cid = 'id'


spdf= \
    spdf.withColumn(time_column_name, F.to_timestamp(F.col(time_column_name)))
w = Window.partitionBy(cid).orderBy(time_column_name)
spdf= \
    spdf\
    .withColumn(column_prefix, F.lag(spdf[time_column_name]).over(w))\
    .withColumn(column_prefix + "_DiffInSeconds", F.col(time_column_name).cast(T.LongType()) - F.col(column_prefix).cast(T.LongType()))\
    .withColumn(column_prefix + "_DiffInMinutes",F.round(F.col(column_prefix + "_DiffInSeconds")/60))\
    .withColumn(column_prefix + "_DiffInHours",F.round(F.col(column_prefix + "_DiffInSeconds")/3600))\
    .withColumn(column_prefix + "_DiffInDays",F.round(F.col(column_prefix + "_DiffInSeconds")/(24*3600)))
antyuntyunantyuntyun

保存/分割

保存

分割csv保存、上書き可、ヘッダあり

%%time
folder_path = '/mnt/share/data/save_folder'
spdf.write.mode('overwrite').csv(folder_path, header=True)

単一ファイルとして保存

%%time
folder_path = '/mnt/share/data/save_folder'
spdf.coalesce(1).write.mode('overwrite').csv(folder_path, header=True)

パーティション基準列を指定して保存

%%time
folder_path = '/mnt/share/data/save_folder'
spdf.write.partitionBy('yyyymm').mode('overwrite').format("csv").save(folder_path, header=True)

パーティション基準列を指定して単一ファイルとして保存
sql書くときに見たことあってもスルーしてたが、coalesceの読み方ちゃんと調べてみたら”コアレス”と読むらしい。コアレス!

%%time
folder_path = '/mnt/share/data/save_folder'
spdf.coalesce(1).write.partitionBy('yyyymm').mode('overwrite').format("csv").save(folder_path, header=True)

読み込み

%%time
folder_path = '/mnt/share/data/save_folder'
spdf= spark.read.csv(folder_path, inferSchema=True, header=True)
# spdf.printSchema()
antyuntyunantyuntyun

case when がしたいとき

df= df.withColumn('label',\
    F.when(F.col('a') <= 1, 'label_1')\
     .when(F.col('a') <= 2, 'label_2')\
     .otherwise('label_3')\
)
antyuntyunantyuntyun

なぜかすぐ忘れる集約関数

df.groupBy('id').agg(F.sum(F.col('price')).alias('sum_of_price'))

df.groupBy('category')\
.agg({'id': 'count', 'price': 'sum'})\
.withColumnRenamed('sum(price)', 'sum_of_price')\
.withColumnRenamed('count(yid)', 'count')\
.select('category', 'sum_of_price')

antyuntyunantyuntyun

定数列追加

spdf = spdf.\
    .withColumn('string_column', F.lit(None).cast(T.StringType()))\
    .withColumn('int_columns', F.lit(None).cast(T.IntegerType()))\
antyuntyunantyuntyun

文字列操作

substring

開始位置(1始まり)と文字数を指定

spdf = spdf\
    .withColumn('substr', F.substring(F.col('string_column'),1,1))
antyuntyunantyuntyun

S3保存

parquet

spdf.write.mode("overwrite").parquet("s3a://bucket/spdf.parquet")

test = spark.read.parquet("s3a://bucket/spdf.parquet",inferSchema=True, header=True)

csv

spdf.write.option("quoteAll", "true").mode("overwrite").csv("s3a://bucket/spdf.split.csv", header=True)

test = spark.read.csv("s3a://bucket/spdf.split.csv", header=True, inferSchema=True)

tsv (引用符付き)

spdf.coalesce(1).write.option("quoteAll", "true").option('sep', '\t').mode("overwrite").csv("s3a://bucket/spdf.tsv", header=True)

test = spark.read.csv("s3a://bucket/spdf.tsv", sep=r'\t', header=True, inferSchema=True)
antyuntyunantyuntyun

カラムまとめて削除

columns_to_drop=['xxx','yyy','zzz']
spdf.drop(*columns_to_drop)
antyuntyunantyuntyun

欠損値の扱い

Mllibが便利そう

medianで埋める

from pyspark.ml.feature import Imputer

imputer = Imputer(
    inputCols=fill_by_average_columns, 
    outputCols=fill_by_average_columns
    ).setStrategy("median")
spdf_f1f2list_feature = imputer.fit(spdf_f1f2list_feature).transform(spdf_f1f2list_feature)
antyuntyunantyuntyun

横方向に足し算

spdf.withColumn('total', sum(spdf[col] for col in spdf.columns)).select('total').show()
antyuntyunantyuntyun

PySparkで特定のカラムが全体の最大値であるレコードを取得する

やまっぷさん記事参照
https://yamap55.hatenablog.com/entry/2019/07/19/090000

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

df = spark.createDataFrame(
  [['a', '201906'], ['a', '201907'], ['b', '201906'], ['b', '201907'], ['c', '201907']],
  ['name', 'date']
)
df.show()

result_df = (
  df
    .withColumn('max_date', F.max('date').over(W.partitionBy()))
    .filter(f.col('date') == F.col('max_date'))
    .drop('max_date')
)
result_df.show()
antyuntyunantyuntyun

array column同士の比較

df = spark.createDataFrame(
    pd.DataFrame(
        data=[
            [["hello"], ["world"]],
            [["hello", "world"], ["world"]],
            [["sample", "overflow", "text"], ["sample", "text"]],
        ],
        columns=["A", "B"],
    )
)
df.show(truncate=False)

# array_exceptでAにあってBにない要素を抽出
df2 = dff.withColumn('difference', F.array_except('A', 'B'))
df2.show()
# array_intersectで共通する要素抽出
df3 = dff.withColumn('difference', F.array_intersect('A', 'B')).withColumn(
    'size_difference', F.size('difference')
)
dff3.show()
antyuntyunantyuntyun

json文字列の解析

# json_columnが[{key1:value01, ket2:value02}, {key1:value11, ket2:value12},,, ]である場合
    .withColumn(
        # array of dictのstructの列を追加
        'from_json_column',
        F.from_json(
            'json_column',
            T.ArrayType(
                T.StructType(
                    [
                        T.StructField("key1", T.IntegerType()),
                        T.StructField("key2", T.StringType()),
                    ]
                )
            ),
        ),
    )