📊

Patchworklibの紹介: MatplotlibのSubplotを簡単に。

2022/04/23に公開

MatplotlibのSubplotをより簡単に。

Matplotlib は Pythonで最も利用されている可視化パッケージですが、論文やプレゼンにつかえるような綺麗なプロットを作成するためにはな複雑な構文をマスターする必要があります。特に、サブプロットの機能は十分に洗練されているとは言い難く、プロットが互いに重ならないように配置したり、プロットの枠を揃える、テキストの位置を揃える、いくつかのサブプロットに対して共通のタイトルを設定する、といった自身の理想とするレイアウトを達成するためには、時に何百行ものコードを書く必要が出てきます。

さらに、matplotlibに実装されているサブプロットの関数は、事前にFigure全体のレイアウト決定することを強制するため、Jupyter-lab のような対話型プログラミング環境でもサブプロットのレイアウトを逐次、確認、修正したりするようなことはできません(matplotlib v3.4から実装されているsubplot_mosiac関数を使えばsubplotのレイアウトを簡単に定義できますが、それでも一度決めたレイアウトを変更することは容易ではありません。)。

また、Seabornplotnineのように、数行のPythonコードで美しいプロットを作成してくれる素晴らしいデータ可視化パッケージが存在しますが、両パッケージが生成するプロットの中には、matplotlib のサブプロットとして扱えないものがあります(両パッケージは matplotlib をベースに実装されいるにも関わらずです)。

そのため、複数のプロットの配置する場合、結局subplotの機能は使わずに、keynote や Powerpoint、イラストレータなどの別のGUI ソフトを使用して手動で配置しているの方も多いのではないでしょうか?しかし、このような作業は非常に骨がおれる作業です。また、手動で編集する部分が多くなるほど、visualizaitonの再現性は低下していきますし、レイアウトを修正するのも大変になります。

こうした問題を解決するために、私は最近、複数の matplotlib プロットを|/演算子だけを使って素早く配置できる新しいサブプロットマネージャ、patchworklib の開発に取り組んできました。

Patchworklibの使い方

patchworklib モジュールは matplotlib.axes.Axes クラスのサブクラスとして実装されたBrick クラスを提供します。各Brickクラスオブジェクトは、他のBrickクラスオブジェクトと|や|演算子で結合することができます。従って、patchworklib を使用すると、次のように簡単な Python コードで 2 つのサブプロットをすばやく配置することができます。

import patchworklib as pw
ax1 = pw.Brick(figsize=(3,3))
ax2 = pw.Brick(figsize=(1,3)) 
ax1.set_title("ax1")
ax2.set_title("ax2") 
ax12 = ax1|ax2
ax12.savefig()

めちゃくちゃ簡単です。ax1とax2の位置を入れ替えたレイアウトもすぐに試すことができます。

ax12 = ax2|ax1
ax12.savefig()

そして、matplotlibでやろうとすると地味にむずかしい、subplotに共通xlabel, ylabelなんかも簡単につけられます。

ax12.set_supspine("bottom")
ax12.set_supxlabel("hoge")
ax12.savefig()

もちろん、こうして作ったFigureは別のFigureと簡単に結合することも可能です。

ax3 = pw.Brick(figsize=(2,1))
ax4 = pw.Brick(figsize=(2,2))
ax3.set_title("ax3")
ax4.set_title("ax4") 
ax34 = ax3/ax4
(ax12|ax34).savefig()

このとき二つのFigの大きさは互いの枠線が揃うように自動的に調節されます。ただ人によっては、枠線ではなくて、2つのFigureの文字の端を揃えたいと思う人もいるかもしれません。それも簡単です。

(ax12.outline|ax34.outline).savefig()

図の枠線ではなくて、外側の文字の端が揃うように図の大きさが調節されているのがわかるでしょうか?こんな感じでpatchworklibはちょーーー簡単にsubplotのレイアウトを調整することを可能にします。

Seaborn plotを並べる。

もう少し複雑な図にしたら、図が重なったりするのじゃないのかって?そんなことはありません。
では、simpleなseabornのplotを幾つか並べてみましょう。

import seaborn as sns
import patchworklib as pw

#ax1
ax1 = pw.Brick("ax1", figsize=(3,2))
fmri = sns.load_dataset("fmri")
sns.lineplot(x="timepoint", y="signal", hue="region", style="event", data=fmri, ax=ax1)
ax1.move_legend(new_loc='upper left', bbox_to_anchor=(1.05, 1.0))
ax1.set_title("ax1")

#ax2
ax2 = pw.Brick("ax2", figsize=(1,2))
titanic = sns.load_dataset("titanic")
sns.barplot(x="sex", y="survived", hue="class", data=titanic, ax=ax2)
ax2.move_legend(new_loc='upper left', bbox_to_anchor=(1.05, 1.0))
ax2.set_title("ax2")

#ax3
ax3 = pw.Brick("ax3", (5,2))
diamonds = sns.load_dataset("diamonds")
sns.histplot(diamonds, x="price", hue="cut", multiple="stack", palette="light:m_r", edgecolor=".3", linewidth=.5, log_scale=True, ax = ax3)
ax3.move_legend(new_loc='upper left', bbox_to_anchor=(1.0, 1.0))
ax3.set_title("ax3")

#ax4
ax4 = pw.Brick("ax4", (6,2))
tips = sns.load_dataset("tips")
sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker",split=True, inner="quart", linewidth=1, palette={"Yes": "b", "No": ".85"}, ax=ax4)
ax4.move_legend("upper left", bbox_to_anchor=(1.02, 1.0))
ax4.set_title("ax4")

#ax5
ax5    = pw.Brick("ax5", (5,2))
rs     = np.random.RandomState(365)
values = rs.randn(365, 4).cumsum(axis=0)
dates  = pd.date_range("1 1 2016", periods=365, freq="D")
data   = pd.DataFrame(values, dates, columns=["A", "B", "C", "D"])
data = data.rolling(7).mean()
sns.lineplot(data=data, palette="tab10", linewidth=2.5, ax=ax5)
ax5.set_xlabel("date")
ax5.set_ylabel("value")
ax5.move_legend("upper left", bbox_to_anchor=(1.02, 1.0))
ax5.set_title("ax5")

#patchwork
ax35421 = (ax3/ax4)|(ax5/(ax2|ax1))
ax35421.savefig()

結果はご覧の通り。ちゃんとlegendの位置も検知して、図が重ならないように、そして枠線は揃うように自動的に配置してくれます。

Figure-levelのSeaborn plotを並べる。

ぶっちゃけた話、上記のようなレイアウトはpathckworlibを使わなくてもMatplotlibとSeabornを駆使すれば作れないことはありません。なぜなら、上記のseaborn plotは全てAxes-levelのplotというものでMatplotlibのSubplotとして扱えるように実装されていれるからです。しかし、SeabornのplotにはFigure全体を利用することが前提でつくられたFigure-levelのplotというものが存在します。SeabornにおけるFigure-level、Axes-levelの違いを知りたい人は、英語になりますが、以下のページを参考にしてください。
https://seaborn.pydata.org/tutorial/function_overview.html

こうしたFigure-levelのplotを並べる方法は、これまで基本的に手動でやる以外にはありませんでした。しかし、PatchworklibはFigure-levelのplotをAxes-levelのplotへと変換する機能を有しているので、以下のようなFigureの作成もお茶の子さいさいです。

import seaborn as sns
import patchworklib as pw
pw.overwrite_axisgrid() #When you use pw.load_seagorngrid, the 'overwrite_axisgrid' method should be executed.
#g1
df = sns.load_dataset("penguins")
g1 = sns.pairplot(df, hue="species")
g1 = pw.load_seaborngrid(g1)
g1.move_legend("upper left", bbox_to_anchor=(0.17,1.01))
#g2
planets = sns.load_dataset("planets")
cmap    = sns.cubehelix_palette(rot=-.2, as_cmap=True)
g2      = sns.relplot(data=planets, x="distance", y="orbital_period", hue="year", size="mass", palette=cmap, sizes=(10, 200))
g2.set(xscale="log", yscale="log")
g2.ax.xaxis.grid(True, "minor", linewidth=.25)
g2.ax.yaxis.grid(True, "minor", linewidth=.25)
g2.despine(left=True, bottom=True)
g2 = pw.load_seaborngrid(g2)
#g3
penguins = sns.load_dataset("penguins")
g3 = sns.jointplot(data=penguins,x="bill_length_mm", y="bill_depth_mm", hue="species", kind="kde")
g3 = pw.load_seaborngrid(g3, labels=["joint","marg_x","marg_y"])
#patchwork
((g2.outline/g3.outline)|g1).savefig()

ここまでくると、patchworlibの作成者(筆者)は天才の可能性があります。しかし残念ながら無職です。(嘘です。ちょっと盛りました。無限回廊にハマった博士課程x年生の学生です。)

plotnine plotを並べる

ここまで読んでくれた方は気づいているかもしれませんが、そうです。patchworklibはggplotのsubplotライブラリpatchworkのパクリです。ぶっちゃけた話、python,matplotlibでなくてR,ggplotを使って解析してるひとは本家patchworkを使った方が100倍幸せになれます。
しかし、Jupyterlabやら、Googlecolabやら世の中はpython以外をあまり受け入れてくれない(Rも一応使えるけどね)環境がたくさんあります。ggplotが大好きな皆さんもmatplotlibを使わざる得なかったことがあるのではないのでしょうか。そんなあなたの救世主。それがplotnine。matplotlibをベースにggplotのsyntaxを完全実装してしまった恐ろしいライブラリです。しかし、一つ大きな欠点がありました。それが、本家のpatchworkにあたるライブラリがなく、subplotの機能が全く実装されていないこと、でした。
でも、もう問題ありません。patchworkのパクリpatchworklibを使えば、plotnineのplotも以下のように簡単に並べられます。なんなら、matplotlibやseabornのplotとだって並べられます。

import patchworklib as pw 
from plotnine import * 
from plotnine.data import *  
g1 = (ggplot(mtcars) + geom_point(aes("mpg", "disp"))) 
g1 = pw.load_ggplot(g1, figsize=(2,3))
g2 = (ggplot(mtcars) + geom_boxplot(aes("gear", "disp", group="gear"))) 
g2 = pw.load_ggplot(g2, figsize=(2,3))
g3 = (ggplot(mtcars, aes('wt', 'mpg', color='factor(gear)')) + geom_point() + stat_smooth(method='lm') + facet_wrap('~gear')) 
g3 = pw.load_ggplot(g3, figsize=(3,3))
g4 = (ggplot(data=diamonds) + geom_bar(mapping=aes(x="cut", fill="clarity"), position="dodge"))  
g4 = pw.load_ggplot(g4, figsize=(5,2))
#patchwork
g1234 = (g1|g2|g3)/g4 
g1234.savefig()

ここまで、読んでくれた方ありがとうごいます。Githubのレポジトリにstarくれると喜びます。
そして、誰かドキュメントの作成手伝ってくれたりする方がいたら喜びます。
正直、実装だけしてあきた

ここで紹介してコードは、以下のGooglecolab上で実際に動かすことが可能です。気になった人は触ってみてください。

Discussion