🐼

PlanetcaleをSQLAlchemyを使ってPandasで読み書きする

2022/03/07に公開

PlanetcaleをSQLAlchemyを使ってPandasで読み書きする

やること

  • Pythonで
  • SQL(planetscale)からpandasへ読み込む
  • 平文でSQLを書くのはアレなのでSQLAlchemyを使う

PlanetScaleを使う理由

なんかアツいらしいからです。(適当)
詳しくは以下の記事を参考にしてください。

PlanetScaleというサーバレスDBが凄く勢いのあるサービスらしいのでQuick Startやってみた - Qiita

なお、今回は上記の記事を読み、ブランチを作成した前提で行きます。

脆弱性の対策

SQLを普通に文字列で書いて渡して使う〜のが普通のやり方なんですが、
それだとSQLインジェクションという脆弱性があるとのことなので、
怖いね〜ってことで対策しつつ行きます。

SQLAlchemyというライブラリを使うとSQLインジェクションを回避できるということで
これを使って準備していきます。

また今回は昨今話題らしい、planetscaleというデータベースを使って接続していきます。

DBへの接続準備

planetscaleのクイックスタートを一部見ながら進めていきます。

クイックスタートはOverviewconnectから閲覧できます。
今回は Pythonで進めていくので言語を設定します。

ターミナルで諸々準備する

  1. pipで必要なライブラリをインストールします。
  2. .envを作り、ここにパスワードとか諸々を打ち込んでいきます。
$ pip install python-dotenv mysqlclient
$ touch .env

.envを編集する

  1. エディタで .envを編集します。
  2. クイックスタートに.envというタブがあるので、ここをクリックして中身をコピペします。
HOST=ほにゃらら
USERNAME=ほにゃらら
PASSWORD=ほにゃらら
DATABASE=ほにゃらら

SQLAlchemyからPlanetscaleへ接続

  1. 諸々インポートします。
  2. load_dotenvで先程書いた .envを読み込み、 os.getenv("HOST")とかで読み出します。
  3. create_engineでDBに接続します、フォーマットなどはコードを参照してください。
    1. SSL接続が必須なので、 ?ssl_mode=VERIFY_IDENTITY",connect_args={"ssl": {"ca": "/etc/ssl/cert.pem を記載しておきましょう。
  4. これを settings.pyで保存しておきます。
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from dotenv import load_dotenv

load_dotenv()
import os

HOST = os.getenv("HOST")
USER = os.getenv("USERNAME")
PASSWD = os.getenv("PASSWORD")
DB = os.getenv("DATABASE")

# データベース接続
ENGINE = create_engine(
    f"mysql://{USER}:{PASSWD}@{HOST}/{DB}?ssl_mode=VERIFY_IDENTITY",
    connect_args={"ssl": {"ca": "/etc/ssl/cert.pem"}},
)
session = scoped_session(sessionmaker(bind=ENGINE))

# modelで使用する
Base = declarative_base()
Base.query = session.query_property()

テーブルを定義する

  1. setting.pyと別のファイルを用意します、名前は user.pyあたりでいいでしょう。
  2. setting.pyから BaseENGINE をimportします、名前を変えてたら該当箇所を適宜変えてください。
  3. class User(Base)でテーブルを定義します。
    1. 今回は適当に全部 stringで設定しています。
    2. primary_key=True は設定しておきましょう。
  4. Base.metadata.create_all(ENGINE)を実行すると、定義したデータでテーブルがDB上に作成されます。
    1. これがSQLで言うところの CREATEにあたります。
from sqlalchemy import Column, String
from setting import Base, ENGINE

class User(Base):
    """
    ユーザモデル
    """

    __tablename__ = "test"
    user_id = Column("user_id", String(767),primary_key=True)
	  data = Column("data", String(767))
    

def main():
    Base.metadata.create_all(ENGINE)

if __name__ == "__main__":
    main()

データベースを読み込む

今回、あらかじめPlanetscale上に10万行ほどデータを挿入してあります。
これを読み込んで pandasDataFrameに変換しましょう。

pandasread_sql を使ってみる(失敗)

from setting import ENGINE
from sqlalchemy import MetaData, Table, func
from sqlalchemy.sql import select
import pandas as pd

# 既存のデータを取得
metadata = MetaData()
event_data = Table("liked", metadata, autoload=True, autoload_with=ENGINE)
df_tweet_id = pd.read_sql_query(
    sql=select([
			event_data,
			func.count("*").label("rows")
		]).group_by(event_data),
    con=ENGINE,
)

エラーが出た

(MySQLdb._exceptions.OperationalError) (1153, 'rpc error: code = ResourceExhausted desc = grpc: received message larger than max (~ vs. ~)')らしいです。

つまり、一度に大量のデータを読み込むんじゃねえぞアホがってことだと思います。

別の方法を考えましょう。

SQLAlchemyでちょっとずつ読み込む(失敗)

以下を参考に 、ちょっとずつ読み込んでいく作戦で行ってみます。

SQLAlchemyで、DBから大量にデータを取ってくる時に一度に全部取得せずちょっとずつ取る - 日々精進

from setting import ENGINE
from sqlalchemy.sql import select

sel = select(User.tweet_id).select_from(User)
con = ENGINE.connect()
res = con.execution_options(stream_results=True).execute(sel)

と書いて、本当はここから forを書いていく予定だったんですが、
そもそも con.execution_options(stream_results=True).execute(sel)の時点で同じエラーが出てしまうという結果に終わってしまいました。

引っ張ってくるデータはちゃんと選定する(成功…と思いきや)

…ここまでの原因として、すべての列を10万行分読み込もうとしているというのがあります。
横着はいけません、ちゃんと欲しい列をしっかりと指定してあげましょう。

from sqlalchemy.orm import sessionmaker
from setting import ENGINE

# セッション作成
SessionClass = sessionmaker(ENGINE)
session = SessionClass()

# SELECT
ids = session.query(User.tweet_id).all()

# DataFrameに変換
df = pd.DataFrame(ids)

追記ここから

10万件を超えるデータはリミットとオフセットをかけてあげる

実際10万件超えてくると流石に駄目と言われてくるので、
LIMITOFFSETを使い、小分けにしながらデータを収集します。

  1. データベースにどれくらいの行があるか見る
  2. データベースを小分けにしながら forで回す
  3. 出てくるのが何重にもネストされた配列なので、解いて一次元配列にする。
# データベース上に何件あるかカウント見る
count_from_db = session.query(User.tweet_id).count()

# データベースのすべての行の1列を取得
db_datas = [
    session.query(User.tweet_id).filter(User.tweet_id).limit(1).offset(i).all()
    for i in tqdm(range(0, count_from_db, 10000), desc="DB取得中", leave=False)
]

# リストの平坦化
db_datas = [x[0] for x in list(itertools.chain.from_iterable(numbers))]

心配なら、データの個数を比較するのもいいかもしれません。

追記ここまで

データベースを書き込む

pandasにto_sqlというDataFrameを簡単に書き込める機能があるのでこれを使っていきます。

読みと同じく、10万件レベルの大量のデータを書き込む前提で考えていきます。

forでちょっとずつ書き込んでいく

  1. taskでだいたい何件ずつ書き込んでいくかを決めていきます。ここは実際動かしながら調整すると良いと思います。多すぎるとエラー吐きます。
  2. forDataFrameの中身を回していきます。
  3. df.ilocDataFrameの位置を決めながら情報を取得していきます。
  4. to_sql で書き込みます。
    1. if_existsはもしテーブルが存在したときにどうするか?を決めます、デフォルトだとエラーを吐くようになっているのでちゃんと設定します。
    2. method, chunksizeを設定して高速化を図ります。
from setting import ENGINE
import pandas as pd

task = 500

for i in tqdm(range(0, len(df), task)):
		w = df.iloc[i : i + task, 0:task]
		w.to_sql(
		    "liked",
		    con=ENGINE,
		    if_exists="append",
		    method="multi",
		    index=False,
		    chunksize=task,
)

おわりに

手探りでSQLalchemyとPandasを触ってみましたが、いかがでしたでしょうか?
とりあえず自分が引っかかったところは網羅できたかなと思います。

参考資料

Discussion