新しいVocoderであるVocosをJSUTで学習させてみる
概要
Vocoderの新たな仲間としてVocos
(paper)が加わりました。
Vocoder内でUpsamplingを一切行わず、ISTFTによりフレームレベル特徴量からサンプルレベルの音声波形に戻すことで高速な動作が期待できます。
(自分の知識の中で)似たモデルとして、ISTFTNetが挙げられますが、異なる点としてはISTFTNetでは最初にConvTransposeによる8倍x2のupsampleをした後にISTFTを適用しますが、Vocosではupsamplingを一切しないことです。
正直upsampleせずとも品質が出るのか疑問だったので、今回は実験していこうと思います。
著者実装(link)も公開されていましたが、共通の枠組みで他のモデルと比較したかったので自分で実装しました。
モデル構造
モデル構造は非常にシンプルで、Conv1d版のConvNeXtのレイヤーをそのままstackしたような構造をしています。
なので、実装も非常に簡単で以下のように実装ができます。
今回は読みやすさを重視してLayerNormを独自に定義していますが、上のGiHubの方のコードではLayerNormとしてnn.LayerNormを使用しているため一部実装が異なります。
コード
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import InverseSpectrogram
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = torch.nn.Parameter(torch.ones(1, channels, 1))
self.beta = torch.nn.Parameter(torch.zeros(1, channels, 1))
def forward(self, x: torch.Tensor):
mean = torch.mean(x, dim=1, keepdim=True)
variance = torch.mean((x - mean)**2, dim=1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
x = x * self.gamma + self.beta
return x
class ConvNeXtLayer(nn.Module):
def __init__(self, channel, h_channel, scale):
super().__init__()
self.dw_conv = nn.Conv1d(channel, channel, kernel_size=7, padding=3, groups=channel)
self.norm = LayerNorm(channel)
self.pw_conv1 = nn.Conv1d(channel, h_channel, 1)
self.pw_conv2 = nn.Conv1d(h_channel, channel, 1)
self.scale = nn.Parameter(torch.full(size=(1, channel, 1), fill_value=scale), requires_grad=True)
def forward(self, x):
res = x
x = self.dw_conv(x)
x = self.norm(x)
x = self.pw_conv1(x)
x = F.gelu(x)
x = self.pw_conv2(x)
x = self.scale * x
x = res + x
return x
class Vocos(nn.Module):
def __init__(self, in_channel, channel, h_channel, out_channel, num_layers, istft_config):
super().__init__()
self.pad = nn.ReflectionPad1d([1, 0])
self.in_conv = nn.Conv1d(in_channel, channel, kernel_size=7, padding=3)
self.norm = LayerNorm(channel)
scale = 1 / num_layers
self.layers = nn.ModuleList(
[
ConvNeXtLayer(channel, h_channel, scale)
for _ in range(num_layers)
]
)
self.norm_last = LayerNorm(channel)
self.out_conv = nn.Conv1d(channel, out_channel, 1)
self.istft = InverseSpectrogram(**istft_config)
def forward(self, x):
x = self.pad(x)
x = self.in_conv(x)
x = self.norm(x)
for layer in self.layers:
x = layer(x)
x = self.norm_last(x)
x = self.out_conv(x)
mag, phase = x.chunk(2, dim=1)
mag = torch.exp(mag)
s = mag * (phase.cos() + 1j * phase.sin())
o = self.istft(s).unsqueeze(1)
return o # [B, 1, T]
実験概要
本来であれば他のモデルとしっかり比較すべきで、論文のようにHiFi-GAN, iSTFTNet、BigVGANと比較すべきですが、今回はVocosの結果のみ載せます。
-
データセット:JSUT/basic5000
-
1000epoch, 612000 steps
-
batch size = 16
-
24kHz
-
n_fft = 1024
-
hop_length = 256
Loss
-
Mel
:生成音声をメルスペクトログラムに変換し、GTのメルスペクトログラムとのL1 loss
loss | plot |
---|---|
Mel |
- Melが下がっているから良しと思っている。
合成音声
音声は貼れないようなので、いくつかGitHubにサンプルファイルを上げときます。
正直アップサンプルせずにistftだけでどういった音声が出るか気になっていたが、想像していたより良い品質の音声が出てきた。
概ね良好なのだが、一部音声において音声の揺れが確認できたため、シンプルに学習が足りないか、別の工夫が必要だと思われる。
Discriminatorが異なる以外は同条件で学習させたHiFi-GANと比較してもHiFi-GANにはこの揺れがなかったため,改善方法を考えたい.
感想
学習についてモデルはフレームレベルのみしか扱わないため非常に高速に感じました。
同じ条件でのHiFi-GAN V1に比べ2倍以上の速度感(同じ条件とはGeneratorやDiscriminatorのアーキテクチャ以外の部分)
また、想像していたよりは良い品質だったため、今後パフォーマンスを意識したモデリングの際には候補の1つとしたいと思います。
とりあえずTTSモデルとしてVITSのDecoderをVocosにして学習させてみようと思います。
Discussion