Chapter 29無料公開

🏍️3️⃣ 訓練済みモデルで予測 ➔ クラスタリング結果の確認

Ryota Chijimatsu
Ryota Chijimatsu
2024.09.24に更新

7.【 コミュニティクラスターの予測 】

訓練済みモデルを通して、クラスター割り当て行列を取得する。この時、ノード水増し分のデータは廃棄している。

pred_loader = DenseDataLoader(dataset,batch_size=1, shuffle=False)

model.cpu()
results = []

for data in tqdm(pred_loader):
    # 予測
    x,s = model(data.x, data.adj, data.mask)
    # ノードの水増し分を除く
    x = x[data.mask,:]
    s = s[data.mask,:]
    
    results.append([x,s])

181サンプル分の予測は6秒で終了した。


クラスタリング結果の取得

モデルの出力の1要素目はノードと隠れ層特徴量の行列で、2要素目にノードのクラスター割り当て行列が入っている。この割り当て行列は確率値に変換していないものであるが、変換しようがしまいが最も値が高いindexをクラスター番号にすればよい。torch.argmax()を使用する。

# デモ用に1サンプル分の予測クラスター割り当て結果を取り出し
result = results[0][1]

# クラスター割り当て行列で最も高い値を持つindexをクラスター番号とする
clusters = torch.argmax(result,dim=-1)


クラスタリング結果の可視化

元の座標系にクラスター番号や細胞ラベルを反映させたplotで視覚的に確認する。座標情報や細胞ラベル情報はDataインスタンスに入れていなかったので、またローカルのファイルを読み込んで対応する。

import pandas as pd

# 座標見込み
GraphCoord_filename = "MERFISH_Brain_KNNgraph_Input/10_-0.04_Coordinates.txt"
x_y_coordinates = pd.read_csv(GraphCoord_filename, sep="\t",
                                  header=None, names=["x", "y"])

# 細胞ラベル読み込み
CellType_filename = "MERFISH_Brain_KNNgraph_Input/10_-0.04_CellTypeLabel.txt"
cell_type_label = pd.read_csv(CellType_filename, sep="\t",
                                  header=None, names=["cell_type"] )

# 細胞名とラベル番号の対応表
cell_labels = ['Astrocyte', 'Endothelial', 'Ependymal', 'Excitatory', 'Inhibitory', 'Microglia', 'OD Immature', 'OD Mature', 'Pericytes']
cell_label_dict = dict(zip(cell_labels,range(len(cell_labels))))

cell_label = [ cell_label_dict[label] for label in cell_type_label["cell_type"]]
plot機能

colormapはmatplotlibで指定可能なcolormap名かラベルと色の対応表(dict型)も受け付ける。

def plot_cluster(
    df,
    cluster=None, num_cluster=None,
    celltype=None, num_celltype=None,
    cell_labels = ['Astrocyte', 'Endothelial', 'Ependymal', 'Excitatory', 'Inhibitory', 'Microglia', 'OD Immature', 'OD Mature', 'Pericytes'],
    pt_size=5,
    figsize=(14, 6),
    celltype_cmap="tab20", cluster_cmap="jet"
):

    plt.figure(figsize=figsize)
    
    # 左側のプロット
    if celltype is not None:
        if num_celltype is None:
            num_celltype = len(np.unique(celltype))
            
        plt.subplot(121)
        if isinstance(celltype_cmap, dict):
            # クラスタ番号に応じて色を取得
            colors = [celltype_cmap[c] for c in celltype]
            scatter = plt.scatter(x=df["x"], y=df["y"],
                                  c=colors, s=pt_size)
            
            # 凡例のためのハンドルを手動で作成
            handles = [plt.Line2D([0], [0], marker='o', color='w', 
                                  markerfacecolor=celltype_cmap[i], markersize=6) 
                       for i in celltype_cmap.keys()]
        else:
            scatter = plt.scatter(x=df["x"],
                                  y=df["y"],
                                  c=celltype,
                                  s=pt_size, cmap=celltype_cmap,
                                  vmin=0, vmax=num_celltype-1)

            # 固定された全ラベルを表示する
            cmap_obj = plt.get_cmap(celltype_cmap)
            handles = [plt.Line2D([0], [0], marker='o', color='w', 
                                  markerfacecolor=cmap_obj(i/(num_celltype-1)), markersize=6) 
                       for i in range(num_celltype)]
            
        plt.title("Cell Type")
        plt.legend(handles=handles,
                   labels=cell_labels,
                   title='Max label', bbox_to_anchor=(-0.75, 1), loc='upper left')
    
    # 右側のプロット
    if cluster is not None:
        if num_cluster is None:
            num_cluster = len(np.unique(cluster))
            
        plt.subplot(122)
        if isinstance( cluster_cmap, dict ):
            # クラスタ番号に応じて色を取得
            colors = [cluster_cmap[c] for c in cluster]
            scatter = plt.scatter(x=df["x"], y=df["y"],
                                  c=colors, s=pt_size)
            
            # 凡例のためのハンドルを手動で作成
            handles = [plt.Line2D([0], [0], marker='o', color='w', 
                                  markerfacecolor=cluster_cmap[i], markersize=6) 
                       for i in cluster_cmap.keys()]
        else:
            scatter = plt.scatter(x=df["x"], y=df["y"], c=cluster,
                                  s=pt_size, cmap=cluster_cmap,
                                  vmin=0, vmax=num_cluster-1)
            
            # 固定された全クラスター値のラベルを表示する
            cmap_obj = plt.get_cmap(cluster_cmap)
            handles = [plt.Line2D([0], [0], marker='o', color='w', 
                                  markerfacecolor=cmap_obj(i/(num_cluster-1)), markersize=6) 
                       for i in range(num_cluster)]
        
        plt.title("Cluster")
        labels = range(num_cluster)
        plt.legend(handles=handles, labels=labels, title="Cluster",
                   bbox_to_anchor=(1.25, 1), loc='upper right')

    
    # レイアウト調整
    plt.subplots_adjust(right=0.8)
    plt.show()
cell_labels = ['Astrocyte', 'Endothelial', 'Ependymal', 'Excitatory', 'Inhibitory', 'Microglia', 'OD Immature', 'OD Mature', 'Pericytes']

plot_cluster(
    df=x_y_coordinates, 
    cluster= clusters, num_cluster=9,
    celltype= cell_label, num_celltype=9,
    cell_labels = cell_labels,
    pt_size=5,
    figsize=(12, 5),
    celltype_cmap="tab10", cluster_cmap="jet"
)

plot機能。もっと簡略版

file_id=引数でファイル番号を指定するだけでplotが描ける。その他にもplot_cluster()機能の引数が指定可能。

# デモデータフォルダへのパス
InputFolderName = "MERFISH_Brain_KNNgraph_Input/"

# ファイル名一覧が記載されたファイルを読み込んで、ファイル名リストを作成
Region_filename = InputFolderName + "ImageNameList.txt"
region_name_list = pd.read_csv(Region_filename,sep="\t",header=None)
region_name_list = list(region_name_list[0])

def plot_cluster2(file_id=None, **kwargs):
    plot_kwargs = {
        "df": None, 
        "cluster": None,
        "num_cluster": None,
        "celltype": None,
        "num_celltype": None,
        "cell_labels": None,
        "pt_size": 5,
        "figsize": (12, 5),
        "celltype_cmap": "tab10",
        "cluster_cmap": "jet",
    }
    plot_kwargs.update(kwargs)

    if plot_kwargs["cluster"] is None:
        plot_kwargs["cluster"] = torch.argmax(results[file_id][1],dim=-1)
    
    if plot_kwargs["num_cluster"] is None:
        plot_kwargs["num_cluster"] =  np.unique(plot_kwargs["cluster"]).max()
    
    if plot_kwargs["df"] is None: 
        name = region_name_list[file_id]
        InputFolderName = "MERFISH_Brain_KNNgraph_Input/"
        GraphCoord_filename = InputFolderName + name + "_Coordinates.txt"
        plot_kwargs["df"] = pd.read_csv(GraphCoord_filename, sep="\t",
                                      header=None, names=["x", "y"])
        
    if plot_kwargs["cell_labels"] is None:
        # 細胞名とラベル番号の対応表
        plot_kwargs["cell_labels"] = ['Astrocyte', 'Endothelial', 'Ependymal', 'Excitatory', 'Inhibitory', 'Microglia', 'OD Immature', 'OD Mature', 'Pericytes']
        
    if plot_kwargs["celltype"] is None:
        CellType_filename = InputFolderName + name + "_CellTypeLabel.txt"
        plot_kwargs["celltype"] = pd.read_csv(CellType_filename, sep="\t", header=None).iloc[:,0].to_numpy()
        
        cell_label_dict = dict(zip( plot_kwargs["cell_labels"],range(len(plot_kwargs["cell_labels"]))))

        # 細胞ラベルと色の対応表が無ければ細胞ラベルをラベル番号に変換
        if not isinstance( plot_kwargs["celltype_cmap"], dict ):
            plot_kwargs["celltype"] = [ cell_label_dict[label] for label in plot_kwargs["celltype"] ]

    if plot_kwargs["num_celltype"] is None:
        plot_kwargs["num_celltype"] = len( np.unique(plot_kwargs["celltype"]) )
    
    
    plot_cluster(
        **plot_kwargs
    )
plot_cluster2(file_id=46)