TensorFlowの3つのモデル構築法を理解する:Sequential / Functional / Subclassing
はじめに
TensorFlow(Keras)でニューラルネットワークを構築する方法は、大きく3つに分かれます。
- Sequential API
- Functional API
- Subclassing API
どの方法も最終的には同じ「モデル」を作りますが、
柔軟性・可読性・拡張性の面で大きく異なります。
この記事では、それぞれの特徴と書き方を整理し、
さらに PyTorch での対応方法も比較して理解を深めます。
1. Sequential API:最もシンプルな構造
Sequential は「層を順番に積み上げる」ための最も基本的な方法です。
入力から出力までが一本道のシンプルなネットワークに向いています。
from tensorflow.keras import models, layers
model = models.Sequential([
layers.Dense(64, activation='relu', input_shape=(100,)),
layers.Dense(10, activation='softmax')
])
特徴
- コードが短く、直感的に書ける
- 入力・出力が1対1の単純な構造に限定される
- 複雑な接続(分岐・結合など)は扱えない
主な用途
シンプルな分類モデル(MLP, 小規模CNNなど)
2. Functional API:入出力を柔軟に接続できる
Functional API では、レイヤーを「関数のように呼び出して」接続します。
これにより、複数の入力や出力、スキップ接続を持つ複雑なモデルが書けます。
from tensorflow.keras import layers, models, Input
inputs = Input(shape=(100,))
x = layers.Dense(64, activation='relu')(inputs)
x1 = layers.Dense(32, activation='relu')(x)
x2 = layers.Dense(32, activation='relu')(x)
merged = layers.concatenate([x1, x2])
outputs = layers.Dense(10, activation='softmax')(merged)
model = models.Model(inputs, outputs)
特徴
-
Input()
で入力を明示的に定義 - レイヤーを「関数呼び出しのように」つなげる
- モデル構造を視覚的に把握しやすい
- 分岐・結合・マルチ入力/出力が可能
主な用途
AutoEncoder、ResNet、Transformerなど複雑な構造
3. Subclassing API:クラスで自由に定義する
Subclassing API は、tf.keras.Model
を継承して
自分で __init__
(初期化)と call()
(順伝播処理)を定義する方法です。
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.fc1 = tf.keras.layers.Dense(64, activation='relu')
self.fc2 = tf.keras.layers.Dense(10, activation='softmax')
def call(self, inputs):
x = self.fc1(inputs)
return self.fc2(x)
model = MyModel()
model.build(input_shape=(None, 100))
model.summary()
特徴
-
__init__
にレイヤーを定義 -
call()
に順伝播処理を記述 - 条件分岐や動的な処理も書ける
-
super().__init__()
で親クラスを初期化
主な用途
RNN, GAN, 強化学習など、柔軟な制御が必要なモデル
4. 3つのAPI比較表
項目 | Sequential | Functional | Subclassing |
---|---|---|---|
書き方 | 層を順番に積む | 入出力を関数的に接続 | クラス定義で柔軟に書く |
柔軟性 | 低 | 中 | 高 |
入出力 | 1入力1出力のみ | 多入力・多出力OK | 完全自由 |
分岐・結合 | × | ○ | ○ |
可視化 (plot_model ) |
◎ | ◎ | △ |
学習 (fit ) |
○ | ○ | ○ |
主な用途 | MLP, CNN | AutoEncoder, ResNet | RNN, GANなど |
5. PyTorchでの対応関係
TensorFlow での3パターンは、PyTorchでは次のように対応します。
概念 | TensorFlow (Keras) | PyTorchでの対応 |
---|---|---|
直列構造 | Sequential([...]) |
nn.Sequential([...]) |
関数的接続 | Functional API (Model(inputs, outputs) ) |
forward() 内で柔軟に記述 |
クラス定義 | Subclassing (class MyModel(tf.keras.Model) ) |
class MyModel(nn.Module) (PyTorchの標準構文) |
6. PyTorchでのコード例
Sequential(直列構造)
import torch.nn as nn
model = nn.Sequential(
nn.Linear(100, 64),
nn.ReLU(),
nn.Linear(64, 10),
nn.Softmax(dim=1)
)
Functional風(柔軟な接続)
import torch
import torch.nn as nn
import torch.nn.functional as F
class FunctionalNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(100, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 10)
def forward(self, x):
x1 = F.relu(self.fc1(x))
x2 = F.relu(self.fc2(x))
x = torch.cat([x1, x2], dim=1)
return F.softmax(self.fc3(x), dim=1)
Subclassing(完全自由)
class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(100, 64)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(64, 10)
def forward(self, x):
x = torch.relu(self.fc1(x))
if self.training:
x = self.dropout(x)
return torch.softmax(self.fc2(x), dim=1)
7. TensorFlowとPyTorchの哲学の違い
観点 | TensorFlow | PyTorch |
---|---|---|
構文レイヤー | 3段階(Sequential / Functional / Subclassing) | 基本はすべて nn.Module ベース |
コード実行 | 静的 → 動的(v2で統一) | 動的グラフが標準 |
学習ループ |
model.fit() で自動 |
for ループ+loss.backward() で明示 |
主な用途 | 教育・実務・デプロイ | 研究・実験・Kaggle |
設計思想 | 宣言的(定義→実行) | 命令的(実行しながら構築) |
8. 同じモデルを3つのAPIで書いたときのイメージ
3つとも構造は同じですが、
- Sequential:リスト的に並べる
- Functional:データの流れを明示
- Subclassing:処理を自分で記述
という違いがあります。
まとめ
学習順としては以下のようなイメージでしょうか。私が実務で扱うことはあまりないので、実務に照らし合わせた解像度は高くありませんがSequentialで実装する例は現場だとそうないのかなと思います。
学び方のステップ | TensorFlowでの理解 | PyTorchでの対応 |
---|---|---|
Step1 | Sequentialで構造を掴む |
nn.Sequential で同様に記述 |
Step2 | Functionalで分岐・結合を学ぶ |
forward() 内で再現 |
Step3 | Subclassingで柔軟な構造を作る |
nn.Module を継承して自由に記述 |
終わりに
TensorFlowは「学びやすく構築しやすい」
PyTorchは「柔軟で実験的な開発に強い」的な印象です。流派の問題もあります。
どちらも根本は同じ計算グラフベースのフレームワーク。
違うのは、「どのレベルまで自分で制御したいか」です。
3つの構築法を理解しておくことで、
どんなフレームワークでも自然に対応できるようになります。
Discussion