🍣

【次世代動画生成】Open-Sora徹底解説【OSS版Sora?】

2024/04/17に公開

今回は論文というわけではないですが、以前解説記事を書いたSoraのOSS版(技術レポートに基づいた再現実装)であるOpen-Soraの詳細な技術紹介ができればと思います。

https://github.com/hpcaitech/Open-Sora/tree/main

Open-Soraとは

Open-Soraは、高品質なビデオを効率的に制作し、そのモデル、ツール、コンテンツを誰もが利用できるようにするための取り組みとの事。
名前通り、一般向けに公開されていないSoraの再現を目指しているオープンソース活動です。

現状いくつかSoraの再現を目指すOSS活動がありますが、その中で一番スター数を集めているリポジトリになります。

すでに訓練済みのモデルも公開されており、短い時間の動画であれば(VRAMの要求が少し厳しいですが)個人でも動画が生成できるようになっています。

公開されているモデル

現時点(2024/03末)で作られているモデルは以下になります。

Resolution Data #iterations Batch Size GPU days (H800) URL
16×256×256 366K 80k 8×64 117 https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth
16×256×256 20K HQ 24k 8×64 45 https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth
16×512×512 20K HQ 20k 2×64 35 https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth

いずれも16フレームで学習が行われているため、長時間の動画生成はできないようです(大体2秒程度)

実際に生成してみた結果

  • プロンプト
    • Waves are crashing on the sandy beach. The ocean glows orange in the setting sun, and the sun is about to set on the horizon.
  • モデル
Resolution Data #iterations Batch Size GPU days (H800) URL
16×256×256 366K 80k 8×64 117 https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth

アーキテクチャ

さて、このOpen-Soraですがどのような構造になっているのでしょうか?
SoraはDiTを使っているなどある程度の情報が公開されていますが、全てのアーキテクチャが公開されているわけではありません。
そのため、Open-Soraも完全にSoraと同じというわけではなくある程度オリジナルの構造を持っています。

では、どのような構造を持っているのでしょうか?
Open-Soraも技術レポートを出しているため、その内容を確認していきます。

レポートの全文はこちら
https://github.com/hpcaitech/Open-Sora/blob/main/docs/report_v1.md

Open-Soraのアーキテクチャ

モデル構造

  • Open-Soraは LattePixArt-αがベースとなっています。

    • LatteはDiTを使った動画生成モデルのOSSで、現在出ているDiTベースの動画生成は大体Latteをベースとしています
      • Open-SoraはLatteの考え方をベースにしつつも、モデル自体はPixArt-αを改造したような構造になっている
    • PixArt-αは、DiTを使った画像生成モデルのOSSで、訓練済みの高品質な画像生成モデルが提供されています
      • また、Open-Soraでは高い品質で動画を生成するために、事前学習としてPixArt-αの訓練済みの重さを学習時に読みんでいます
  • VAEはStabilityAIの「sd-vae-ft-mse-original」を使用しています

  • モデル構造はこちら

といってモデル図を見せられても何をやっているのか分かりづらいと思うので、少しずつ解説していきます。
まずVAEについて。
VAEは、Stability AIが公開している2D(画像用の縦×横)の物を使っています。
まずここで、動画なのに2DのVAE?という疑問が浮かんできます。
Soraでは3Dの動画用の3D VAE(時間×縦×横)を一から訓練して作っていました(そのための工夫なども技術レポートで公開されています)
しかし、Open-SoraではVAEの訓練は行わず、すでにある品質が高い2D VAEを流用することで動画を生成しているようです。

続いて、モデル構造を見ていきます。
先ほどのモデル構造の前に、まずこちらの画像を見てください。

これは、PixArt-αのモデル構造になります。
Open-Soraのモデル構造とよく似ているのが分かります。
実は、OpenSoraのアーキテクチャはPixArt-αのモデル構造に「Temporal Self-Attention」のブロックを追加しただけで、ほぼPixArt-αの構造をそのまま使っています。
ちなみにここで追加しているTemporal Self-Attentionは、時系列(動画で言うとフレーム数)の変化を捉えて学習する機構になります。

さてここで二つ目の疑問が出てきます。
Open-Soraは訓練時にPixArt-αの重みを利用すると書きましたが「Temporal Self-Attention」なんていう余計なブロックをくっつけられた構造で、PixArt-αの重みの取り込みなどできるのでしょうか?

Open-Sora徹底解説

この二つの疑問を解決するためには、Open-Sora内部でデータがどのように加工されているか?というのをコードレベルで理解する必要があります。
(なんせ技術レポートではすごくざっくりしか話されていないので)

ここからが今日の記事の本番になります。
それではOpen-Soraコードレベル解説、いってみましょう。

入力データ

この後は、入力されたデータがどのように加工されていくか?に注目してコードを交えて解説していきますので、まずは入力されるデータ形式について理解しておく必要があります。

入力データの形式は以下のようになっています。

  • Video :16Frame 512×512のビデオファイル
  • batch_size = 8
  • 入力時のテンソル形式は以下
    • バッチ×チャンネル×フレーム数×縦×横(B×C×T×H×W)

ここまでは、よくある動画データの持ち方ですね。
では、このデータを加工していきましょう。

VAE

まずはVAEです。
さきほど紹介した通り、Open-SoraではVAEに2D用の「sd-vae-ft-ema」を使っています。
ここで、一つ目の疑問である2D VAEでどうやって動画データを潜在空間に落とし込むのか?を見ていきます。

sd-vae-ft-ema自体は画像生成用のモデルのため、このVAEを使うためにはテンソルの形式を(B×C×H×W)にする必要があります。

動画データは(B×C×T×H×W)のため、sd-vae-ft-emaでLatentに変換するため以下のようにテンソルを(B×C×H×W)の形に変換することになります。

(ちなみに、sd-vae-ft-emaの設定値は以下になります)

{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.4.2",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 256,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}

(B×C×T×H×W)から(B×C×H×W)に変換したことで、2DのVAEを利用する事ができるようになりました。
これが、一つ目の疑問である動画データに対してどのようにSD VAEを使っているか?の回答となります。

このようにVAEに通すことで、入力データはこのような形になります。

VAEを通した後は、2Dの形式にしておく必要がないので、再び元の形に戻します。

テキストデータの入力

ここまで画像の話をしてきましたが、Open-Soraはtext2Videoモデルのため、当然生成時のヒントとなるテキスト(プロンプト)を入力する必要があります。
Open-SoraはテキストエンコーダーにCLIPではなくT5を使っています。
これによって、動画内で文字の生成なども可能になっています。

T5については、他に解説してくださっている記事があるので、こちらをご覧ください。
https://www.ogis-ri.co.jp/otc/hiroba/technical/similar-document-search/part7.html

ちなみに、Open-Soraのデフォルト設定ではテキストエンコーダーは以下のようになっています。

text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
)

STDiT Block

ここからがOpen-Sora本体の話になります。
改めて、Open-Soraのモデル構造を見てみましょう。

このモデルは、コード上ではこのように定義されています。

@MODELS.register_module()
class STDiT(nn.Module):
    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

                self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=approx_gelu,
            token_num=model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn,
                    enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal,
                    d_s=self.num_spatial,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        # init model
        self.initialize_weights()
        self.initialize_temporal()
        if freeze is not None:
            assert freeze in ["not_temporal", "text"]
            if freeze == "not_temporal":
                self.freeze_not_temporal()
            elif freeze == "text":
                self.freeze_text()

        # sequence parallel related configs
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if enable_sequence_parallelism:
            self.sp_rank = dist.get_rank(get_sequence_parallel_group())
        else:
            self.sp_rank = None

    def forward(self, x, timestep, y, mask=None):
        """
        Forward pass of STDiT.
        Args:
            x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
            timestep (torch.Tensor): diffusion time steps; of shape [B]
            y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
            mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]

        Returns:
            x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
        """

        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for i, block in enumerate(self.blocks):
            if i == 0:
                if self.enable_sequence_parallelism:
                    tpe = torch.chunk(
                        self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
                    )[self.sp_rank].contiguous()
                else:
                    tpe = self.pos_embed_temporal
            else:
                tpe = None
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
        # x.shape: [B, N, C]

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

実際に入力されたデータが加工されていくのはForward関数になりますので、そちらを細かく見ていきます。

埋め込み処理(トークン化)

まず、VAEを通して得られた「Video Latent Reprecentation」をTransformerが扱いやすいテンソル形式に再度変換します。

コードとしてはこのあたりですね。

# embedding
x = self.x_embedder(x)  # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")

VisionTransformerを、動画も取り扱えるように拡張した最初期のモデルであるVivit(A video vision transformer)では、動画をTransformerで扱うための方法として以下の二つの手法が提案されています。

Open-Soraでは後者の埋め込み方法を利用しており、具体的には以下のような変換を行っています。

VAE変換直後に得られる「Video Latent Reprecentation」の形

Transformerが読み込めるように、Token化された「Video Latent Reprecentation」の形

また、続く

x = x + self.pos_embed

の部分は、ViTのPatch +Position Embedding や VivitのPositional Token Embeddingと同じような操作を行っています。

次は、DiffusionタイムステップとプロンプトをTransformerが取り扱える形に変更していきます。
コードはこんな感じ

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

ではコードを見ていきましょう。
まず、t_embeddingでは拡散過程/逆拡散過程において、いまどのステップを進めているのかを表す埋め込み表現を作成しています。
ここでいうtimestepとは、デノイジングステップなどとも呼ばれるもので、拡散モデルにおいて、ノイズをかける/ノイズを取り除くという工程を1ステップとし、それが何回行われたか?を表す数となります。
ステップ回数が少ないうちはまだノイズだらけの画像で、このステップが進むと徐々にノイズが取り除かれていくイメージを持っていただければと思います。

一つ注意する点としては、このタイムステップはFrame数などの動画における時間推移を表しているわけではないという点で、Open-Soraでは時系列情報が拡散/逆拡散過程の軸と動画における時間推移の二つの軸があるので混同しないようにする必要があります。

最後にy_embedderについてですが、これは動画につけられたキャプションの埋め込み表現を作成しています。

実体はVision Transformerで定義されているMLPです。

pytorch-image-models/timm/layers/mlp.py at main · huggingface/pytorch-image-models

このときy_embedderに渡されるキャプションは、文字情報そのままではなくTextEncoder(T5)によってEncodeされたものになります。

Diffusion Transformer部分

このモデルの実体部分についてです。
コード上ではこの部分ですね

 # blocks
        for i, block in enumerate(self.blocks):
            if i == 0:
                if self.enable_sequence_parallelism:
                    tpe = torch.chunk(
                        self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
                    )[self.sp_rank].contiguous()
                else:
                    tpe = self.pos_embed_temporal
            else:
                tpe = None
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
        # x.shape: [B, N, C]

ただし、ここは定義されたブロックに対しデータを入力しているだけなので、処理内容を追うにはもう一段深くコードを読む必要があります。
self.blocksは以下のように定義されていますので、深堀すべきはSTDiT Blockとなります。

self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn,
                    enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal,
                    d_s=self.num_spatial,
                )
                for i in range(self.depth)
            ]
        )

STDiT Blockは以下のように定義されています。

class STDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        d_s=None,
        d_t=None,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t

        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=self.enable_flashattn,
        )

    def forward(self, x, y, t, mask=None, tpe=None):
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)

        # spatial branch
        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_s)

        # temporal branch
        x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
        if tpe is not None:
            x_t = x_t + tpe
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x

こちらもForward関数を中心に見ていきましょう。

Open-Soraのモデル画像を見てもわかる通り、STDiT Blockは、大きく分けて「Spatial Self Attention, Temporal Self Attention, Prompt Cross Attention, Point Wise Feed Forward」 の4つのブロックで構成されています。

前処理

各ブロックで処理を行う前に、もうひと段階「Video Latent Representation」を加工しています。

具体的には、拡散過程/逆拡散過程において現在のデノイジングステップ数に応じて、スケーリング(データの尺度を変更する操作)とシフト(データに一定の値を加算し、データの中心を変更する操作)を用いて、入力テンソルの各要素を調整しています。

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)

調整は線形モジュレーションを通して行われており、具体的な処理は以下のように定義されています。

def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift

これにより、拡散過程/逆拡散過程において異なるステップ間でも関連性を見失わずに学習/推論することができるわけですね。
このあたり、何でこんなことをしているのか詳しく知りたい方は、以下の記事がわかりやすかったのでお勧めです。
https://qiita.com/nishiha/items/ea46bfbe9ae47c823182

Spatial Self Attention

では一つ目のブロックである「Spatial Self Attention」を見ていきましょう。

Spatial Self Attentionは名前通り空間的な注意(アテンション)機構で、調整済みの「Video Latent Representation」はまずここを通ります。

# spatial branch
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)

B×(T×S)×Cで表されている「Video Latent Representation」を(B×T)×S×Cに組み換え、空間方向にアテンションをかけています。

これにより、空間的な特徴を学習することが可能となっています。
さて、この時のデータ形式に注目してもらうと面白い事がわかります。
この時、テンソル形式はH×W×CがT×B分存在している、という形になります。
これは、PixArt-αのデータ入力形式(H×W×C)×Bとバッチの数が違うだけで、取り扱うデータの形そのものは全く同じになるわけですね。
そのため、このSpatial Self AttentionではPixArt-αで事前訓練した重みをそのまま利用できる、というのが二つ目の疑問の回答になります。

Temporal Self Attention

続いて、時間的な特徴を学習する「Temporal Self Attention」です。
なお、ここでいう時間とはFrame数の推移を表しています。

注意すべき点としてはSpatial Self Attentionとは、アテンションをかける方向が異なります。
具体的には以下の通り。

    # temporal branch
    x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
    if tpe is not None:
        x_t = x_t + tpe
    x_t = self.attn_temp(x_t)
    x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
    x = x + self.drop_path(gate_msa * x_t)


これにより、時間的な特徴を学習することが可能となります。

また、Temporal Self Attentionについては、データ構造的にPixArt-αの重みを利用する事はできないので、このブロックに関しては重み0で初期化されています。

Prompt Cross Attention

3つ目のブロックで、キャプションの埋め込み表現はここで取り込んでいます。

Spatial Self Attention, Temporal Self Attentionをかけ終えた「Video Latent Representation」に対し、y_embedderから得られたキャプションの埋め込み表現をCross Attentionしています

# cross attn
x = x + self.cross_attn(x, y, mask)

この辺りは普通の拡散モデルもよく見る構造ですね。

Point Wise Feed Forward

最後のブロックです。

x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

Transformerで使われているFFN(Feed-Forward Network)と大体同じなので省略します。
このあたり詳しく知りたい人は以下の記事などを参考にしていただければと!
https://zenn.dev/attentionplease/articles/1a01887b783494

後処理

いよいよ最後の処理です。
後処理は

  • 現在のデノイジングステップ数に応じて出力を調整するFinalLayer
  • 出力されたテンソルをVAEで画像に変換可能なデータ形式に戻すunpatchify
    の二つにわかれています。

まずは、FinalLayerから見ていきましょう。

  • FinalLayer
class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, hidden_size, num_patch, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

処理内容は前処理部分で行っていた、デノイジングステップ数毎に出力分布を調整する機構とほぼ同じで、現在のデノイジングステップ数に応じてシフト値とスケール値を決め、正規化された出力結果を調整し、線形結合層を最後に通して出力結果を得ています。

  • unpatchify
    続いて、unpatchifyについて。
def unpatchify(self, x):
	"""
	Args:
	x (torch.Tensor): of shape [B, N, C]   
	Return:
	    x (torch.Tensor): of shape [B, C_out, T, H, W]
	"""
	
	N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
	T_p, H_p, W_p = self.patch_size
	x = rearrange(
	    x,
	    "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
	    N_t=N_t,
	    N_h=N_h,
	    N_w=N_w,
	    T_p=T_p,
	    H_p=H_p,
	    W_p=W_p,
	    C_out=self.out_channels,
	)
	return x

FinalLayerを通して得られた出力結果を、VAEで画像に変換できるデータ形式に変える処理をしています。
Finallayerを通した時点ではデータ形式が以下のようになっておりVAEでEncodeした時とは形式が変わっているのがわかります。


テンソル形式をVAEで変換可能な元の形に戻すためには、二種類の数字を出す必要があります。

まず一種類目

N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]

この値はパッチ化されている潜在空間を何個あつめれば元のVAEで変換可能な形に戻せるか?を計算しています。
N_Tは元データを時間方向に何個分割したか(パッチを作ったか)、H_Tは縦方向に何個分割したか(パッチを作ったか)、W_Tは横方向に何個分割したか(パッチを作ったか)を表しているわけですね。

ではこのパッチはどのように表されるのでしょうか?
それが二種類目の数字になります。

T_p, H_p, W_p = self.patch_size

これはもう見てそのままですが、ひとつのパッチがT_p×H_p×W_pで表されることを示しています。

つまり、VAEで変換可能な元データは(N_t×N_h×N_w)×(T_p×H_p×W_p)で表すことができます。

画像であらわすとこのような形。

さて。この情報がわかればrearrangeを使ったテンソル操作でVAEで変換可能な元のデータ形式に戻すことが可能です。

	x = rearrange(
	    x,
	    "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
	    N_t=N_t,
	    N_h=N_h,
	    N_w=N_w,
	    T_p=T_p,
	    H_p=H_p,
	    W_p=W_p,
	    C_out=self.out_channels,
	)
	return x

変換後は以下のような形になります。

これで元のVAEで変換可能な形に戻ってきましたね。
あとはEncode時と同様、Latent毎に2D VAEを通して画像を生成し、T(Frame数)方向に揃えて連続表示させることで動画を作り出す、というのがOpen-Soraの一連の処理になります。

終わりに

今回はOpen-Soraの処理を、かなり細かい部分まで見ながら解説していきました。
今年は動画生成が熱くなる一年だと思っていますので、今後も動画生成に関する情報発信を積極的にしていければと思います!

ここまで読んでくださりありがとうございました!

Discussion