🙃

Pytorchのforward, train/evalのワンライナー化

に公開

torch.nnのforward関数

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        for layer in self.layers[:-1]:
            x = self.F.relu(layer(x))
        return self.layers[-1](x)
def forward(self, x): return self.layers[-1](__import__('functools').reduce(lambda a, l: self.F.relu(l(a)), self.layers[:-1], x.view(-1, self.input_dim)))

def forward(self, x): return self.layers[-1](functools.reduce(lambda a, l: self.F.relu(l(a)), self.layers[:-1], x.view(-1, self.input_dim)))

def forward(self, x): return self.layers[-1](functools.reduce(lambda a, l: self.F.relu(l(a)), self.layers[:-1], x.view(-1, self.input_dim)))

Class内でlambdaも使えるが型ヒントが使えない

forward= lambda self, x: self.layers[-1](functools.reduce(lambda a, l: self.F.relu(l(a)), self.layers[:-1], x))

train,eval関数

以下のPytorchtrain,evalスクリプトをワンライナーにすることはできますか?

def train(device,net,loader,optimizer,criterion):
    for inputs, labels in loader:
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
    return loss

def eval(device,net,testloader):
    energies = []
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            energies.append(loss.item()*inputs.shape[0])
    return sum(energies)

ChatGPT5の解答

もちろん可能です。以下は動作は同じで完全に1行に圧縮したPyTorchワンライナー版です。
Python 3.8以降(ウォルラス演算子 := が必要)を前提にしています。


✅ ワンライナー版

def train(device, net, loader, opt, crit): [opt.zero_grad() or (loss := crit(net(inputs.to(device)), labels.to(device))).backward() or opt.step() for inputs, labels in loader]; return loss
def eval(device, net, testloader): 
    return sum((loss := crit(net((x := x.to(device))), (y := y.to(device)))).item() * x.shape[0] for x, y in testloader)

…だとさすがに読めないので、ちゃんと動く形で「実質1行」にするとこうです👇

def eval(device, net, testloader): 
    return sum([ (criterion(net((inputs := inputs.to(device)))), labels := labels.to(device))[0].item()*inputs.shape[0]
                 for inputs, labels in trainloader ])

🔍 説明

  • opt.zero_grad() or (...) という構文は、list comprehension 内で複数文を評価するための Python トリックです。
    or は左辺が None なので右辺を評価します。
  • := で代入して、最後に loss を返しています(代入式(ウォルラス演算子 :=)を使うと「to(device) した結果を2回使う」が1回で済みます。)。
  • evaltorch.no_grad() の文脈を簡略化するなら次のようにも書けます:
def eval(device, net, testloader): 
    with torch.no_grad(): return sum((loss := crit(net((x := x.to(device))), (y := y.to(device)))).item() * x.shape[0] for x, y in testloader)

⚠️ 注意

これらは動作するが極めて読みにくいので、実際の実験コードでは

def train(net, loader, opt, crit):
    for x, y in loader: opt.zero_grad(); (crit(net(x), y)).backward(); opt.step()

def eval(device, net, testloader): 
    with torch.no_grad(): return sum((loss := criterion(net((x := x.to(device))), (y := y.to(device)))).item()*x.shape[0] for x, y in trainloader)

のままにして、短縮するならせいぜいセミコロン区切りの2行構成に留めるのが実用的です。

感想

optimizer.zero_grad()
loss.backward()
は副作用を伴うので分かりづらくワンライナーの利点が薄い

with構文を1行にするにはやりすぎ

コード

https://gist.github.com/xiangze/e3c2d3d5165b6993394544445cd4101a

Discussion