# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
#
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # make torch.utils.checkpoint.checkpoint available below
from einops import pack, rearrange, repeat
from omegaconf import DictConfig

from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from .matcha_transformer import BasicTransformerBlock


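# mask_to_bias turns a boolean attention mask into an additive bias: positions
# that may be attended to become 0.0 and masked positions become the most
# negative finite value of the target dtype, so they are suppressed by the
# softmax inside attention.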
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * torch.finfo(dtype).min
    return mask


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
    # actually this is not needed after we have inference cache implemented, will remove it later
    pos_idx = torch.arange(size, device=device)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


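# NOTE: in this file, add_optional_chunk_mask is only called (from
# ConditionalDecoder.forward) with use_dynamic_chunk=False and a positive
# static_chunk_size, so the result is simply the static chunk mask combined
# with the padding mask.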
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding,
            only used when it is greater than 0 and use_dynamic_chunk is false;
            if use_dynamic_chunk is true, this parameter is ignored.
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


# Causal building blocks (the Causal* modules below never look at future
# frames; Transpose is a small helper used inside them).
class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


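# CausalConv1d pads (kernel_size - 1) zeros on the left only, so every output
# frame depends solely on current and past input frames; this is what lets the
# causal decoder variant run chunk by chunk without seeing future context.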
class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x


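# BASECFM is the conditional flow matching (CFM) base module, following the
# Matcha-TTS recipe: sample Gaussian noise, integrate the learned vector field
# with a fixed-step Euler solver over t in [0, 1], and train the estimator with
# an optimal-transport CFM regression loss.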
class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for the ODE.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # Intermediate solutions are stored so they can be inspected (e.g. from
        # a debugger); a return_all_steps flag may be added in the future.
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

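    # compute_loss implements the optimal-transport CFM objective: for data x1
    # and noise z, the point on the probability path at time t is
    #     y = (1 - (1 - sigma_min) * t) * z + t * x1
    # and the regression target for the estimator is the constant velocity
    #     u = x1 - (1 - sigma_min) * z.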
    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


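# ConditionalDecoder is the 1-D U-Net estimator of the vector field: each
# down/mid/up stage applies an (optionally causal) ResNet block followed by
# transformer blocks, with skip connections between mirrored down and up
# stages; mu, the speaker embedding and cond are concatenated to the input
# along the channel axis.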
class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        causal=False,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        gradient_checkpointing=True,
    ):
        """
        This decoder requires an input with the same shape as the target, so if your text content
        is shorter or longer than the output, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        self.static_chunk_size = 2 * 25 * 2  # 2*input_frame_rate*token_mel_ratio
        self.gradient_checkpointing = gradient_checkpointing

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else
                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

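        # Up path: the channel tuple is reversed (with the first width repeated)
        # so the decoder mirrors the down path; each up block takes 2x input
        # channels because the skip connection from the matching down block is
        # concatenated along the channel axis in forward().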
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            ) if self.causal else ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

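    # forward() rebuilds the attention mask at every stage: the padding mask
    # (halved in length after each downsample) is combined with a static chunk
    # mask by add_optional_chunk_mask, then converted into an additive bias by
    # mask_to_bias before being handed to the transformer blocks.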
    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): shape (batch_size, in_channels, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.

        Returns:
            torch.Tensor: output of shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t)
        t = t.to(x.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        mask = mask.to(x.dtype)
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                if self.gradient_checkpointing and self.training:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(transformer_block),
                        x,
                        attn_mask,
                        t,
                    )
                else:
                    x = transformer_block(
                        hidden_states=x,
                        attention_mask=attn_mask,
                        timestep=t,
                    )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                if self.gradient_checkpointing and self.training:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(transformer_block),
                        x,
                        attn_mask,
                        t,
                    )
                else:
                    x = transformer_block(
                        hidden_states=x,
                        attention_mask=attn_mask,
                        timestep=t,
                    )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                if self.gradient_checkpointing and self.training:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(transformer_block),
                        x,
                        attn_mask,
                        t,
                    )
                else:
                    x = transformer_block(
                        hidden_states=x,
                        attention_mask=attn_mask,
                        timestep=t,
                    )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


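# ConditionalCFM extends BASECFM with an optional cosine warping of the
# timestep schedule and classifier-free guidance: conditioning is randomly
# dropped at training_cfg_rate during training, and at inference the
# conditional and unconditional vector fields are blended with
# inference_cfg_rate.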
class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate

    @torch.inference_mode()
    def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            # the cosine warp places integration points more densely near t = 0
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for the ODE.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # Intermediate solutions are stored so they can be inspected (e.g. from
        # a debugger); a return_all_steps flag may be added in the future.
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                cfg_dphi_dt = estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    cond=cond
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

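    # Same OT-CFM objective as BASECFM.compute_loss, with two additions:
    # conditioning is randomly dropped (classifier-free guidance training), and
    # the MSE is evaluated in float32 before being cast back to the original
    # dtype for stability under mixed-precision training.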
    def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        org_dtype = x1.dtype

        b, _, t = mu.shape
        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            if spks is not None:
                spks = spks * cfg_mask.view(-1, 1)
            if cond is not None:
                cond = cond * cfg_mask.view(-1, 1, 1)

        pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
        pred = pred.float()
        u = u.float()
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        loss = loss.to(org_dtype)
        return loss, y