Closed7

Stable Diffusionについて自分なりのメモ

ijiwarunahello

環境

  • macOS Monterey
  • MacBook Pro (16-inch, 2019)
  • 2.6GHz 6コア Intel Core i7
  • 16GB 2667 MHz DDR4
  • Intel UHD Graphics 630 1536 MB
❯ sw_vers
ProductName:	macOS
ProductVersion:	12.6
BuildVersion:	21G115
ijiwarunahello

環境セットアップ

まず Anaconda をインストールする。

https://www.anaconda.com/products/distribution

Anaconda Distribution をインストールしたあと、シェル（zsh）向けに conda を初期化する。

❯ ~/opt/anaconda3/bin/conda init zsh

シェルを再起動するか、設定を再読み込み（`source ~/.zshrc`）すると conda にパスが通る。

❯ conda --version
conda 22.9.0

リポジトリをクローン

https://github.com/CompVis/stable-diffusion

git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion

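environment.yaml を編集する（macOS では CUDA が使えないため、cudatoolkit の指定を外し、PyTorch / torchvision はバージョンを固定せず pytorch-nightly チャンネルから入れる）。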

diff --git a/environment.yaml b/environment.yaml
index 025ced8..f2a9004 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,13 +1,12 @@
 name: ldm
 channels:
-  - pytorch
+  - pytorch-nightly
   - defaults
 dependencies:
   - python=3.8.5
   - pip=20.3
-  - cudatoolkit=11.3
-  - pytorch=1.11.0
-  - torchvision=0.12.0
+  - pytorch
+  - torchvision
   - numpy=1.19.2
   - pip:
     - albumentations==0.4.3

conda環境を作成して有効化

conda env create -f environment.yaml
conda activate ldm
ijiwarunahello

トレーニング

クローンしたリポジトリ直下に以下のスクリプトを作成する。これは Stable Diffusion 自体の学習ではなく、ResNet-18 を CIFAR-10 で学習・評価して、CPU と MPS の処理時間を比較するための動作確認用スクリプト。

pytorch_m1_macbook.py
# -*- coding: utf-8 -*-
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as tt
from torchvision.models import resnet18

import os
from argparse import ArgumentParser
import time

def main(device, *, n_epoch=5, batch_size=512, momentum=0.9, lr=0.01,
         weight_decay=0.00005, num_workers=8):
    """Train a ResNet-18 on CIFAR-10 on ``device``, evaluate it, and print timings.

    CIFAR-10 is downloaded (if not already present) into the directory that
    contains this script. The hyperparameters are keyword-only and default to
    the values this benchmark has always used, so ``main(device)`` behaves as
    before.

    Args:
        device: ``torch.device`` used for training and evaluation
            (the command-line entry point passes ``cpu`` or ``mps``).
        n_epoch: Number of training epochs.
        batch_size: Mini-batch size for both the train and test loaders.
        momentum: SGD momentum.
        lr: SGD learning rate.
        weight_decay: SGD weight decay.
        num_workers: Worker processes per ``DataLoader``.

    Note:
        The printed losses are the *sum* of the per-batch mean losses over an
        epoch (or over the test set), not a per-sample average, so they scale
        with the number of batches.
    """
    # Prepare the training data (augmentation + normalization) and the
    # test-time transform (normalization only).
    mean = (0.491, 0.482, 0.446)
    std = (0.247, 0.243, 0.261)
    train_transform = tt.Compose([
        tt.RandomHorizontalFlip(p=0.5),
        tt.RandomCrop(size=32, padding=4, padding_mode='reflect'),
        tt.ToTensor(),
        tt.Normalize(mean=mean, std=std)
    ])
    test_transform = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])
    root = os.path.dirname(os.path.abspath(__file__))
    train_set = CIFAR10(root=root, train=True,
                        download=True, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)

    # Build ResNet-18 with its final fully-connected layer replaced by a
    # 10-way classifier for the CIFAR-10 classes.
    resnet = resnet18()
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)

    # Train.
    criterion = CrossEntropyLoss()
    optimizer = SGD(resnet.parameters(), lr=lr,
                    momentum=momentum, weight_decay=weight_decay)
    # Initialised here so the summary print below is defined even when
    # n_epoch == 0.
    train_loss = 0.0
    train_start_time = time.time()
    resnet.to(device)
    resnet.train()
    for epoch in range(1, n_epoch + 1):
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = resnet(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            train_loss += loss.item()
            del loss  # to save memory
            optimizer.step()
        # The elapsed time is cumulative since the start of training.
        print('Epoch {} / {}: time = {:.2f}[s], loss = {:.2f}'.format(
            epoch, n_epoch, time.time() - train_start_time, train_loss))
    print('Train time on {}: {:.2f}[s] (Train loss = {:.2f})'.format(
        device, time.time() - train_start_time, train_loss))

    # Evaluate.
    test_set = CIFAR10(root=root, train=False, download=True,
                       transform=test_transform)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    test_loss = 0.0
    test_start_time = time.time()
    resnet.eval()
    # Gradients are not needed for evaluation; without no_grad() every forward
    # pass would record an autograd graph, wasting memory and time.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = resnet(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
    print('Test time on {}: {:.2f}[s](Test loss = {:.2f})'.format(
        device, time.time() - test_start_time, test_loss))


if __name__ == '__main__':
    # Command-line entry point: pick the device and run the benchmark.
    parser = ArgumentParser(
        description='Train and evaluate ResNet-18 on CIFAR-10 on the CPU '
                    'or the MPS (Metal) backend and print the timings.')
    parser.add_argument('--device', type=str, default='mps',
                        choices=['cpu', 'mps'],
                        help='device to run on (default: mps)')
    args = parser.parse_args()
    # Fail early with a readable message, instead of a low-level error from
    # resnet.to(device) after CIFAR-10 has already been downloaded, when the
    # Metal backend cannot be used on this machine / PyTorch build.
    if args.device == 'mps' and not torch.backends.mps.is_available():
        parser.error('MPS backend is not available on this machine; '
                     'rerun with --device cpu')
    device = torch.device(args.device)
    main(device)

トレーニング実行

# CPUを使う場合
python pytorch_m1_macbook.py --device cpu
# MPS(Metal Performance Shaders)←こちらの方が速いらしい
python pytorch_m1_macbook.py --device mps
ijiwarunahello

txt2img実行

実行の前に、CUDA（NVIDIA GPU）で実行するようになっている箇所を CPU で実行するように変更する（Mac では CUDA が使えないため）。

diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index d4effe5..8b7ceb9 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -68,3 +68,5 @@ model:
 
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        device: cpu
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 78eeb10..b1c9fd2 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -17,8 +17,8 @@ class PLMSSampler(object):
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device("cpu"):
+                attr = attr.to(torch.device("cpu")).contiguous()
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39..27d5744 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -209,6 +209,7 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
+        x = x.contiguous()
         x = self.attn1(self.norm1(x)) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
@@ -258,4 +259,4 @@ class SpatialTransformer(nn.Module):
             x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
\ No newline at end of file
+        return x + x_in
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 59c16a1..b7bbb1b 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -60,7 +60,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    model.to('cpu')
     model.eval()
     return model
 
@@ -239,7 +239,7 @@ def main():
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device("cpu") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
     if opt.plms:
@@ -279,7 +279,7 @@ def main():
 
     precision_scope = autocast if opt.precision=="autocast" else nullcontext
     with torch.no_grad():
-        with precision_scope("cuda"):
+        with precision_scope("cpu"):
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()

いよいよ実行する（事前に学習済みモデルの重みを models/ldm/stable-diffusion-v1/model.ckpt に配置しておくこと）。CPU で推論するため、画像が出力されるまで 2 時間弱かかる。

python scripts/txt2img.py --prompt "a photographic of the cyberpunk judo" --plms --precision full

このスクラップは2023/02/23にクローズされました