2907 lines
108 KiB
Python
2907 lines
108 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
# activation_checkpointing.py
|
|
"""helper function for activation checkpointing"""
|
|
|
|
from typing import Union, Dict, Callable
|
|
from functools import partial
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
|
checkpoint_wrapper,
|
|
offload_wrapper,
|
|
CheckpointImpl,
|
|
)
|
|
|
|
|
|
# utils.py
|
|
"""cascade basic blocks"""
|
|
|
|
import math
|
|
import backoff
|
|
import random
|
|
import numpy as np
|
|
from typing import Optional, Tuple, Union
|
|
import torch
|
|
from torch import nn
|
|
from torch import Tensor
|
|
import torch.nn.functional as F
|
|
|
|
|
|
# conformer_encoder.py
|
|
"""ConformerEncoder Module"""
|
|
|
|
from typing import Optional, Tuple, List, Literal
|
|
import abc
|
|
import math
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch import nn, Tensor
|
|
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
|
|
|
|
|
|
# activation_checkpointing.py
|
|
def validate_checkpointing_config(activation_checkpointing):
    """Check that an activation-checkpointing config is well formed.

    Args:
        activation_checkpointing: either a str, which must be one of
            "", "checkpoint" or "offload", or a dict whose optional "module"
            entry must be "transformer" (the default) or "attention".

    Raises:
        AssertionError: if the str value, or the dict's "module" entry, is not
            one of the allowed values. These checks are ``assert`` statements,
            so they are skipped when Python runs with ``-O``.
        ValueError: if the config is neither a str nor a dict.
    """
    if not isinstance(activation_checkpointing, (str, dict)):
        raise ValueError("activation_checkpointing has to be a str or dict.")

    if isinstance(activation_checkpointing, dict):
        target_module = activation_checkpointing.get("module", "transformer")
        assert target_module in ("transformer", "attention"), (
            "module in activation_checkpointing has to be in ('transformer', 'attention')."
        )
        return

    assert activation_checkpointing in ("", "checkpoint", "offload"), (
        "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')."
    )
|
|
|
|
|
|
def embedding_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
) -> Callable:
    """Pick the activation-checkpointing wrapper for the encoder embedding.

    Args:
        activation_checkpointing: str or dict config, first validated with
            ``validate_checkpointing_config``.
            * str: "" means no wrapping, "offload" selects ``offload_wrapper``,
              and any other accepted value selects ``checkpoint_wrapper`` with
              its default ``checkpoint_impl``.
            * dict: wrapping happens only when ``"embed"`` is truthy. Then a
              truthy ``"offload"`` selects ``offload_wrapper``; otherwise
              ``checkpoint_wrapper`` is used, with a reentrant implementation
              only when ``"reentrant"`` is truthy (default: non-reentrant).

    Returns:
        A callable that takes a module and returns it, possibly wrapped.

    Raises:
        ValueError: if the config is neither a str nor a dict.
    """
    validate_checkpointing_config(activation_checkpointing)

    def _identity(module):
        return module

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return _identity
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if isinstance(activation_checkpointing, dict):
        if not activation_checkpointing.get("embed", False):
            return _identity
        if activation_checkpointing.get("offload", False):
            return offload_wrapper
        use_reentrant = activation_checkpointing.get("reentrant", False)
        impl = CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT
        return partial(checkpoint_wrapper, checkpoint_impl=impl)

    raise ValueError("Invalid activation_checkpointing config")
|
|
|
|
|
|
def encoder_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
    layer_cls: type,
    idx: int = 0,
) -> Callable:
    """Pick the activation-checkpointing wrapper for one encoder layer.

    Args:
        activation_checkpointing: str or dict config, first validated with
            ``validate_checkpointing_config``.
            * str: "" means no wrapping, "offload" selects ``offload_wrapper``,
              and any other accepted value selects ``checkpoint_wrapper`` with
              its default ``checkpoint_impl``.
            * dict: "module" ("transformer" or "attention", case-insensitive)
              picks which layer class names are wrapped. "interval" (default 1)
              wraps only layers whose ``idx`` is a multiple of it. "offload"
              selects ``offload_wrapper``. "reentrant" (default True) selects
              the reentrant checkpoint implementation.
        layer_cls: class of the layer, matched by its ``__name__``.
        idx: index of the layer within the encoder.

    Returns:
        A callable that takes a module and returns it, possibly wrapped.

    Raises:
        ValueError: if the config is neither a str nor a dict.
    """
    validate_checkpointing_config(activation_checkpointing)

    def _identity(module):
        return module

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return _identity
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    module_name = activation_checkpointing.get("module", "transformer")
    layer_names_by_module = {
        "transformer": ("EncoderLayer", "ConformerEncoderLayer"),
        "attention": ("MultiHeadedAttention", "MultiHeadAttention"),
    }
    # Unknown module names fall back to the raw string, so the membership
    # test below becomes a substring check on it.
    target_layer_names = layer_names_by_module.get(module_name.lower(), module_name)
    interval = activation_checkpointing.get("interval", 1)
    offload = activation_checkpointing.get("offload", False)
    if activation_checkpointing.get("reentrant", True):
        impl = CheckpointImpl.REENTRANT
    else:
        impl = CheckpointImpl.NO_REENTRANT

    if not (idx % interval == 0 and layer_cls.__name__ in target_layer_names):
        return _identity
    if offload:
        return offload_wrapper
    return partial(checkpoint_wrapper, checkpoint_impl=impl)
|
|
|
|
|
|
def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]:
    """Return the checkpointing config to hand to the i-th attention layer.

    Args:
        activation_checkpointing: str or dict config.
        i: index of the layer.

    Returns:
        The dict config itself when its "module" (default "transformer") is
        "attention" and ``i`` is a multiple of its "interval" (default 1).
        Otherwise "" (no checkpointing). A str config always yields "".

    Raises:
        ValueError: if the config is neither a str nor a dict.
    """
    if isinstance(activation_checkpointing, str):
        return ""
    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    targets_attention = activation_checkpointing.get("module", "transformer") == "attention"
    interval = activation_checkpointing.get("interval", 1)
    # The modulo is evaluated only for attention configs (short-circuit).
    if targets_attention and i % interval == 0:
        return activation_checkpointing
    return ""
|
|
|
|
|
|
# utils.py
|
|
class Block(nn.Module):
    """Abstract base module that records its input and output sizes.

    Args:
        input_size: stored as ``self.input_size``.
        output_size: stored as ``self.output_size``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
|
|
|
|
def get_activation(name="relu"):
    """Build an activation module from its case-insensitive name.

    Args:
        name: str
            "relu" (an in-place ``nn.ReLU``), "gelu", "swish" or "sigmoid";
            default "relu". Any other name yields ``nn.Identity``.

    Returns:
        A newly constructed ``nn.Module``.
    """
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        # Lambda defers the lookup of Swish, which is defined later in this module.
        "swish": lambda: Swish(),
        "sigmoid": torch.nn.Sigmoid,
    }
    return factories.get(name.lower(), nn.Identity)()
|
|
|
|
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """Build the chunk-wise attention mask for Transformer Transducer streaming.

    Frames are grouped into chunks ``[start_k, start_{k+1})``. There is an
    implicit chunk starting at 0, and the last chunk ends at ``x_len``. Frame
    ``i`` may attend to frame ``j`` iff ``j`` lies between the start of the
    chunk ``left_window`` chunks before ``i``'s chunk and the end of the chunk
    ``right_window`` chunks after it (clipped to the first / last chunk).

    Args:
        x_len (int): sequence length.
        chunk_start_idx (sequence of int, or integer tensor): first index of
            each chunk, e.g. [0, 18, 36, 48]. Adaptive (unequal) chunk sizes
            are supported, e.g. [0, 10, 15, 45]. Assumed to be sorted
            ascending (TODO confirm with callers).
        left_window (int): how many chunks to the left can be seen.
        right_window (int): how many chunks to the right can be seen. It is
            used for chunk overlap models.

    Returns:
        torch.Tensor: boolean mask of shape (x_len, x_len). For example,
        x_len=4 and chunk_start_idx=[0, 2] give::

            tensor([[ True,  True, False, False],
                    [ True,  True, False, False],
                    [False, False,  True,  True],
                    [False, False,  True,  True]])
    """
    # Build as long directly. This avoids a float32 round-trip and also
    # accepts integer tensors.
    chunk_start_idx = torch.as_tensor(chunk_start_idx, dtype=torch.long)
    # Chunk k spans [start_pad[k], end_pad[k]). For [0, 18, 36, 48]:
    # start_pad = [0, 0, 18, 36, 48] and end_pad = [0, 18, 36, 48, x_len].
    start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))
    end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len)

    seq_range = torch.arange(0, x_len).unsqueeze(-1)  # [x_len, 1]
    # Chunk index of every frame: [x_len]
    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
    seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)  # [x_len, x_len]

    # Leftmost visible chunk, clipped at the first chunk.
    idx_left = (idx - left_window).clamp(min=0)
    mask_left = seq_range_expand >= start_pad[idx_left].unsqueeze(-1)

    # Rightmost visible chunk, clipped at the last chunk.
    idx_right = (idx + right_window).clamp(max=len(chunk_start_idx))
    mask_right = seq_range_expand < end_pad[idx_right].unsqueeze(-1)

    return mask_left & mask_right
|
|
|
|
class Swish(nn.Module):
    """Swish activation module: ``x * sigmoid(x)``.

    See https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` element-wise by its own sigmoid.

        Args:
            x: torch.Tensor
                Input.
        """
        gate = self.act_fn(x)
        return x * gate
|
|
|
|
class GLU(nn.Module):
    """Gated Linear Unit (GLU) module.

    Splits the input into two halves along ``dim`` and returns
    ``first_half * act(second_half)``.

    Args:
        dim: dimension along which the input is split; default -1.
        act_name: gate activation, one of "relu", "gelu", "swish" or
            "sigmoid" (case-insensitive, default "sigmoid"). Any other name
            uses the identity, i.e. a plain bilinear product of the halves.
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            # Not in-place: the gate is a view (x.chunk) of the caller's
            # input, so an in-place ReLU would overwrite part of that input.
            self.act_fn = nn.ReLU()
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Multiplies the first half of ``x`` by the activation of the second
        half. The input tensor is not modified.

        Args:
            x: torch.Tensor
                Input; its size along ``self.dim`` should be even.

        Returns:
            ``a * act(b)`` where ``a, b = x.chunk(2, dim=self.dim)``.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)
|
|
|
|
# TODO: Abdel, this can be improved using GLU module
|
|
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module used in the conformer convolution module.

    A Conv1d produces ``2 * output_dim`` channels. These are split into halves
    ``a`` and ``b`` and combined as ``(a + b1) * act(b + b2)``, or as
    ``(a + b1) * (b + b2)`` for ``glu_type="bilinear"``. The biases ``b1`` and
    ``b2`` are present only when ``bias_in_glu`` is True.
    For more details see: https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size.
        glu_type: str, optional
            gate activation, one of "sigmoid", "relu", "gelu", "swish", or
            "bilinear" (no activation); default "sigmoid".
        bias_in_glu: bool, optional
            add learnable biases to both halves before gating; default True.
        causal: bool, optional
            if True, the convolution is padded by ``kernel_size - 1`` on both
            sides, so the time axis grows by ``kernel_size - 1``. The caller
            must drop the last ``kernel_size - 1`` frames to get a causal
            output. Default False: ``(kernel_size - 1) // 2`` padding on
            both sides.

    Raises:
        ValueError: if ``glu_type`` is not supported.
    """

    def __init__(
        self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = kernel_size - 1 if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim, output_dim * 2, kernel_size, 1, padding=padding
        )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "bilinear":
            # forward() multiplies the two halves directly and never calls
            # glu_act in this mode.
            self.glu_act = nn.Identity()
        else:
            # Report the requested type; self.glu_act does not exist yet here.
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """Apply the gated point-wise convolution.

        Args:
            x: torch.Tensor
                3-D input with channels in the last dimension, consistent
                with GLULinear, i.e. (batch, time, input_dim).

        Returns:
            torch.Tensor of shape (batch, time', output_dim). ``time'`` follows
            from the Conv1d padding described in the class docstring.
        """
        # Conv1d expects channels first.
        x = self.ext_pw_conv_1d(x.permute([0, 2, 1]))
        value = x[:, 0 : self.output_dim, :]
        gate = x[:, self.output_dim : self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2
        if self.glu_type != "bilinear":
            gate = self.glu_act(gate)
        x = value * gate
        return x.permute([0, 2, 1])
|
|
|
|
|
|
class DepthWiseSeperableConv1d(nn.Module):
    """Depth-wise separable Conv1d used by the conformer ConvModule.

    For more details see: https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            output channels of the second (point-wise, kernel 1) Conv1d. When
            0, that second convolution is skipped.
        kernel_size: int
            kernel size of the depth-wise Conv1d.
        depthwise_multiplier: int
            number of filters per input channel. The depth-wise Conv1d
            outputs ``input_dim * depthwise_multiplier`` channels.
        padding: int, optional
            padding of the depth-wise Conv1d, default 0.
    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()
        hidden_channels = input_dim * depthwise_multiplier

        # groups=input_dim: every input channel gets its own filters.
        self.dw_conv = nn.Conv1d(
            input_dim, hidden_channels, kernel_size, 1, padding=padding, groups=input_dim
        )
        self.pw_conv = (
            nn.Conv1d(hidden_channels, depthwise_seperable_out_channel, 1, 1, 0)
            if depthwise_seperable_out_channel != 0
            else nn.Identity()
        )
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """Run the depth-wise and then (if enabled) the point-wise convolution.

        Args:
            x: torch.Tensor
                input laid out as nn.Conv1d expects, channels in dim 1.
        """
        hidden = self.dw_conv(x)
        if self.depthwise_seperable_out_channel == 0:
            return hidden
        return self.pw_conv(hidden)
|
|
|
|
|
|
class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.

    For more details see: https://arxiv.org/pdf/2005.08100v1.pdf

    Pipeline:
      1. LayerNorm.
      2. Gated point-wise projection (GLU), or two learnable scalar affine
         maps when ext_pw_out_channel == 0.
      3. Depth-wise separable Conv1d.
      4. Optional BatchNorm and the activation.
      5. A second point-wise Conv1d (or a scalar affine map).
      6. Dropout.

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, channel size of the point-wise convolutions around the
            depth-wise one. If 0, they are replaced by learnable scalar
            affine maps.
        depthwise_seperable_out_channel: int
            if different from 0, it is used as the output channel count of the
            second conv1d layer. If 0, the second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the point-wise convolution of the conformer.
        kernel_size: int
            kernel size of the depth-wise convolution.
        depthwise_multiplier: int
            number of input_dim channel duplications. This value is used to
            compute the hidden channels of the depth-wise Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, the convolutions have no access to future frames.
            Default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before the activation.
            Default False.
        chunk_se: int, optional
            0 for offline SE, 1 for streaming SE with the mean computed over
            the accumulated history, 2 for streaming SE with the mean computed
            over the current chunk only. Accepted but not used by this
            implementation.
        chunk_size: int, optional
            chunk size for cnn, default 18. Accepted but not used by this
            implementation.
        activation: str, optional
            activation function used in ConvModule, default "relu".
        glu_type: str, optional
            activation function used for the glu, default "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use an additive bias in the GLU module.
        linear_glu_in_convm: bool, optional
            if set to True, use the GLULinear module; otherwise use the
            GLUPointWiseConv module. Default False.
        export: bool, optional
            if set to True (causal mode only), the depth-wise padding is 0. This
            is for inference or onnx export. Typically this is set by the export
            program or the decoder program, and it is not in the config file.
            Default False.
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        # Reads the attributes set above, so it must run after them.
        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            # Replaces the Identity placeholder set in _add_ext_pw_layer.
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            if export:  # Inference only.
                padding = 0  # A cache is concatenated to the left. No padding in the kernel.
            else:
                # Training only. Padding will be added symmetrically on both sides.
                # After convolution, clip off kernel_size-1 points on the right.
                padding = kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # ln2 projects the depth-wise output channels back to input_dim when
        # they differ; forward() checks for it with hasattr.
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of __init__ and is dedicated to
        creating the point-wise convolution / GLU layers of the conformer.
        """
        # Placeholders so that every attribute always exists (jit hacks).
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity()  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                # Padded by ext_pw_kernel_size - 1 on both sides; forward()
                # trims the extra right frames when fix_len1 is set.
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            # ln1 maps ext_pw_out_channel back to input_dim; forward() applies
            # the same ln1 after both the GLU and ext_pw_conv_1d.
            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            # Indices 0 and 1 replace the GLU; index 2 replaces ext_pw_conv_1d.
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                3-D input whose last dimension is input_dim (it goes through
                nn.LayerNorm(input_dim)). Presumably (batch, time, input_dim),
                since dims 1 and 2 are swapped around the Conv1d layers.

        Returns:
            torch.Tensor in the same (batch, time, channel) layout as the input.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # Only the convolutional GLU (GLUPointWiseConv) pads the time axis
            # by ext_pw_kernel_size - 1 in causal mode. GLULinear keeps the
            # length, so trimming its output would drop real frames.
            if (
                self.causal
                and self.ext_pw_kernel_size > 1
                and not self.linear_glu_in_convm
            ):
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # Channels first for the Conv1d layers.
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        # NOTE(review): with export=True the depth-wise conv has padding=0, yet
        # this trim still applies. Presumably the export/decoder program feeds
        # a left cache so the lengths work out -- confirm.
        if self.causal and self.kernel_size > 1:
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x
|
|
|
|
class GLULinear(nn.Module):
    """Linear projection followed by a GLU.

    Projects the last dimension from ``input_dim`` to ``2 * output_dim``,
    then returns ``a * act(b)`` for the two halves ``a, b`` of the projection.

    Args:
        input_dim: int
            input size (last dimension of the input).
        output_dim: int
            output size.
        glu_type: str, optional
            activation applied to the gate half, passed to ``GLU``;
            default "sigmoid".
        bias_in_glu: bool, optional
            whether the linear projection has an additive bias. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x):
        """GLULinear forward.

        Args:
            x: torch.Tensor
                input tensor whose last dimension is ``input_dim``.

        Returns:
            torch.Tensor whose last dimension is ``output_dim``.
        """
        x = self.linear(x)
        return self.glu_act(x)
|
|
|
|
class FeedForward(nn.Module):
    """FeedForward Module.

    For more details see the Conformer paper:
    https://arxiv.org/pdf/2005.08100.pdf

    Computes ``Dropout(Linear(Dropout(GLULinear(LayerNorm(x)))))``. That is a
    pre-LayerNorm, a GLU-gated projection from d_model to d_inner, and a linear
    projection back to d_model, with dropout after each projection.

    Args:
        d_model: int
            input and output size.
        d_inner: int
            hidden size of the GLU-gated projection.
        dropout_rate: float
            dropout rate.
        activation: str, optional
            gate activation, passed to GLULinear as its ``glu_type``;
            default "sigmoid".
        bias_in_glu: bool, optional
            whether the GLULinear projection has an additive bias;
            default True.
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """FeedForward forward function.

        Args:
            x: torch.Tensor
                input tensor whose last dimension is ``d_model``.

        Returns:
            torch.Tensor whose last dimension is ``d_model``.
        """
        out = self.net(self.layer_norm(x))

        return out
|
|
|
|
#### positional encoding starts here
|
|
def _pre_hook(
|
|
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
|
):
|
|
"""Perform pre-hook in load_state_dict for backward compatibility.
|
|
|
|
Note:
|
|
We saved self.pe until v.0.5.2 but we have omitted it later.
|
|
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
|
|
|
"""
|
|
k = prefix + "pe"
|
|
if k in state_dict:
|
|
state_dict.pop(k)
|
|
|
|
class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section 2.1 of
    the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based on the
    relative position of the query and key. B is HxNxN, where H is the number of
    heads and N is the sequence length.

    Modifications to the original T5 bias:
    - The bucketing step is skipped. Original T5 bias converted relative position
      distances into logarithmically increasing buckets, which is meant to help with
      length generalization.
    - The relative position index is used directly to select bias values. Length
      generalization is not needed (40s max is good enough for an ASR encoder), and
      this keeps ONNX export simple.
    - Biases can be asymmetric; the default T5 implementation treats L->R and R->L
      the same. Asymmetric biases were found to yield better results.

    Args:
        num_heads: int
            Number of attention heads.
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the size of
            the learnable bias parameter. Bucketing is not yet supported, so this
            defaults to -1, which means no bucketing is used (max_distance determines
            the size of the bias param).
        max_distance: int
            Maximum distance to use for relative attention bias. With num_buckets=-1,
            this directly controls the size of the bias parameter. Relative positions
            are clipped to [-max_distance, max_distance - 1], and in symmetric mode
            their absolute value is clipped to max_distance - 1. Once num_buckets > 0
            is supported, this will control the maximum distance for logarithmic
            bucketing, after which all positions are in the same bucket.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses 2x the
            number of bias params to distinguish L->R from R->L. This was found to be
            better for the encoder.

    Raises:
        NotImplementedError: if num_buckets >= 0 (bucketing is not supported yet).
    """

    def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested")
        if not self.symmetric:
            # Separate rows for negative and non-negative relative positions.
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x):
        """Return the relative position bias for ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 1 is the sequence length L. Only its size
                and device are used.

        Returns:
            torch.Tensor of shape [1, H, L, L], where entry [0, h, i, j] is the
            learned bias of head h for key position j relative to query i.
        """
        # instantiate bias compatible with shape of x
        maxpos = x.size(1)
        context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None]
        memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        # clipping to a maximum distance using ops that play well with ONNX export
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
            # |-max_distance| == max_distance would index one past the last
            # embedding row; clip it to the last row (ONNX-friendly op).
            bias_idx = bias_idx.masked_fill(bias_idx > self.num_buckets - 1, self.num_buckets - 1)
        else:
            # Shift [-num_buckets/2, num_buckets/2 - 1] to [0, num_buckets - 1].
            bias_idx += self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position):
        # This is a placeholder (not tested, likely buggy) using the HuggingFace
        # implementation as a reference. It is currently unreachable because
        # __init__ rejects bucketing. It also needs to be extended to support
        # asymmetric +/- positions.
        num_buckets = self.num_buckets
        relative_buckets = 0
        # No `causal` attribute is defined on this module; default to the
        # bidirectional (encoder) behaviour.
        if not getattr(self, "causal", False):
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
|
|
|
|
class AbsolutePositionalEncoding(nn.Module):
    """Absolute Positional encoding module.

    This module implements absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf

    Args:
        d_model: int
            Input embedding size (odd sizes are supported).
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Length of the table precomputed at construction, Default 5000.
            Longer inputs extend the table on the fly.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Plain attribute, not a buffer, so it is not saved in the state_dict.
        # _pre_hook drops the "pe" key that older checkpoints contain.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings.

        Rebuilds ``self.pe`` (shape [1, x.size(1), d_model]) when it is missing
        or shorter than ``x.size(1)``. Otherwise it is only cast to the dtype
        and device of ``x``.

        Args:
            x: torch.Tensor
                tensor whose dim 1 is the required length.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x: torch.Tensor
                Input tensor. shape is (batch, time, ...)

        Returns:
            torch.Tensor: ``dropout(x * sqrt(d_model) + pe[:, :time])``, with the
            same shape as the input.
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
|
|
|
|
#### forward embedding layers starts here
|
|
|
|
@backoff.on_exception(backoff.expo, Exception, max_tries=10)
def np_loadtxt_with_retry(filepath):
    """Load a float32 text array with ``np.loadtxt``, retrying on failure.

    Any exception triggers a retry with exponential back-off (the ``backoff``
    decorator), for at most 10 attempts in total.

    Args:
        filepath: str
            path of the text file holding the numpy array.

    Returns:
        numpy array with dtype "f" (float32).
    """
    return np.loadtxt(filepath, dtype="f")
|
|
|
|
class MeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Computes ``(input - global_mean) * global_invstd``, i.e. subtracts the mean
    and multiplies by the inverted standard deviation. Typically used as the
    very first layer of a model.

    Args:
        input_size: int
            number of statistics entries. Both 1-D buffers have this length
            and broadcast against the input's trailing dimension.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # Identity transform until statistics are loaded.
        self.register_buffer("global_mean", torch.zeros(input_size))
        self.register_buffer("global_invstd", torch.ones(input_size))
        self.global_mean: Optional[Tensor]
        self.global_invstd: Optional[Tensor]

    def forward(self, input_: Tensor) -> Tensor:
        """Normalize ``input_`` with the stored statistics.

        Args:
            input_: torch.Tensor
                input tensor.
        """
        centered = input_ - self.global_mean
        return centered * self.global_invstd

    def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False):
        """Load the feature mean and inverted standard deviation from text files.

        Args:
            mean_file: str
                path to the feature mean statistics file.
            invstd_file: str
                path to the feature inverted standard deviation statistics file.
            cuside_features: bool
                True when CUSIDE is used. The CUSIDE features reuse the
                statistics of the normal features, so each loaded vector is
                concatenated with itself.
        """
        self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file))
        self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file))

        if not cuside_features:
            return
        for stats in (self.global_mean, self.global_invstd):
            stats.data = torch.cat((stats.data, stats.data), 0)
|
|
|
|
class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step has limited access to
    locations on its right or left. All arguments are the same as nn.Conv1d
    except padding.

    If padding is None, the paddings are set automatically (left:
    kernel_size - 1, right: stride - 1). This makes the convolution causal, so
    no location sees any step on its right.

    If padding is an int, it is used as both the left and the right padding.

    If padding is a list of size 2, padding[0] is the left padding and
    padding[1] the right padding. This controls how many steps on each side
    are accessible; padding[0] + padding[1] must equal kernel_size - 1.

    With stride > 1, only padding=None or padding == kernel_size - 1 is
    accepted.

    The parent Conv1d is built with padding=0, and the padding is applied
    explicitly in ``update_cache``. When ``forward`` gets a ``cache`` tensor,
    the cache replaces the left padding and ``(output, next_cache)`` is
    returned. ``cache_drop_size`` (default None, treated as 0) trailing steps
    are excluded when building ``next_cache``.

    Raises:
        ValueError: on an invalid padding / stride combination.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Optional[Union[int, List[int]]] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (
                isinstance(padding, list)
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        """Pad ``x`` (or prepend ``cache``) and compute the next cache.

        Args:
            x: 3-D input, time in the last dimension.
            cache: optional tensor prepended along the last dimension in place
                of the left padding.

        Returns:
            (padded input, next cache). The next cache is None when ``cache``
            is None; otherwise it has the same length as ``cache``.
        """
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            return new_x, None

        new_x = F.pad(x, pad=(0, self._right_padding))
        new_x = torch.cat([cache, new_x], dim=-1)
        # cache_drop_size defaults to None; treat it as "drop nothing".
        drop = self.cache_drop_size or 0
        if drop > 0:
            next_cache = new_x[:, :, :-drop]
        else:
            next_cache = new_x
        next_cache = next_cache[:, :, -cache.size(-1) :]
        return new_x, next_cache

    def forward(self, x, cache=None):
        """Causal convolution of ``x``.

        Returns the output alone when ``cache`` is None, otherwise
        ``(output, next_cache)``.
        """
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache
|
|
|
|
|
|
class CausalConv2D(nn.Conv2d):
    """
    A causal version of nn.Conv2d where each location in the 2D matrix has no access to locations
    on its right or down.
    All arguments are the same as nn.Conv2d except padding, which must be None (the default).
    Paddings are set automatically to left = kernel_size - 1 and right = stride - 1 and applied
    manually in forward; the underlying nn.Conv2d uses padding=0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Optional[Union[str, int]] = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # Padding is computed internally. The default is None because any explicit value
        # is rejected here.
        if padding is not None:
            raise ValueError("Argument padding should be set to None for CausalConv2D.")
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            0,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(
        self,
        x,
    ):
        """Pad causally, then apply the convolution.

        The last dimension is always padded with (left, right). Dimension -2 is padded the
        same way only in training mode; in eval mode it is left unpadded. As used by
        NemoConvSubsampling, dim -2 is the time axis. NOTE(review): presumably eval relies
        on left context being supplied externally; confirm against the inference path.
        """
        if self.training:
            x = F.pad(
                x,
                pad=(
                    self._left_padding,
                    self._right_padding,
                    self._left_padding,
                    self._right_padding,
                ),
            )
        else:
            x = F.pad(
                x,
                pad=(self._left_padding, self._right_padding, 0, 0),
            )
        x = super().forward(x)
        return x
|
|
|
|
|
|
class NemoConvSubsampling(torch.nn.Module):
    """Convolutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

    Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for
    Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506)

    Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach,
    and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce
    FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy.

    `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions
    after the first layer, whereas the former does not.

    Args:
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling_factor (int): Time reduction factor. Must be a power of 2 (>= 2), since it is
            realised as log2(subsampling_factor) stride-2 convolution stages.
        subsampling (str): The subsampling technique, choose from
            {"dw_striding", "striding", "striding_conv1d", "dw_striding_conv1d"}
        conv_channels (int): Number of channels for the convolution layers, default is 256.
        subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking)
            1 (auto) or a power of 2. Default is 1
        activation (Module): activation function. The default (None) creates a new nn.ReLU().
            The same instance is reused after every convolution stage.
        is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access
            to locations on its right or left. Ignored by "dw_striding_conv1d".
    """

    def __init__(
        self,
        feat_in,
        feat_out,
        subsampling_factor=4,
        subsampling="dw_striding",
        conv_channels=256,
        subsampling_conv_chunking_factor=1,
        activation=None,
        is_causal=False,
    ):
        super().__init__()
        if activation is None:
            # A fresh module per instance instead of a shared mutable default.
            activation = nn.ReLU()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        # Each stage halves the time axis, so only exact powers of 2 are realisable. Other even
        # values (e.g. 6) would subsample by a different factor than the one used for the mask.
        if subsampling_factor < 2 or not self._is_power_of_two(subsampling_factor):
            raise ValueError("Sampling factor should be a power of 2!")
        self._sampling_num = int(subsampling_factor).bit_length() - 1
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d")

        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

        layers = []

        if subsampling == "dw_striding":
            self._init_geometry(3, subsampling_factor, causal=self.is_causal)
            # Layer 1: a regular (non-depthwise) conv, kept to avoid degrading accuracy
            layers.append(self._make_conv2d(1, conv_channels))
            layers.append(activation)
            for _ in range(self._sampling_num - 1):
                # depthwise conv followed by a pointwise (1x1) conv
                layers.append(self._make_conv2d(conv_channels, conv_channels, groups=conv_channels))
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=conv_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    )
                )
                layers.append(activation)

        elif subsampling == "striding":
            self._init_geometry(3, subsampling_factor, causal=self.is_causal)
            in_channels = 1
            for _ in range(self._sampling_num):
                layers.append(self._make_conv2d(in_channels, conv_channels))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            self._init_geometry(5, subsampling_factor, causal=self.is_causal)
            in_channels = feat_in
            for i in range(self._sampling_num):
                is_last = i + 1 == self._sampling_num
                layers.append(self._make_conv1d(in_channels, feat_out if is_last else conv_channels))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            # this variant is always non-causal
            self._init_geometry(5, subsampling_factor, causal=False)
            in_channels = feat_in
            for i in range(self._sampling_num):
                is_last = i + 1 == self._sampling_num
                layers.extend(
                    [
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=feat_out if is_last else conv_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                    ]
                )
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        self.conv2d_subsampling = subsampling in ("dw_striding", "striding")
        if self.conv2d_subsampling:
            # The frequency axis is reduced by the same convolutions; size the projection to match.
            out_length = calc_length(
                lengths=torch.tensor(feat_in, dtype=torch.float),
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
        else:
            self.out = None

        self.conv = torch.nn.Sequential(*layers)

    @staticmethod
    def _is_power_of_two(value):
        """Return True if ``value`` is an integral power of 2 (1, 2, 4, ...)."""
        return value >= 1 and int(value) == value and int(value) & (int(value) - 1) == 0

    def _validate_chunking_factor(self, factor):
        """Raise ValueError unless ``factor`` is -1 or a power of 2 (1 included)."""
        if factor != -1 and not self._is_power_of_two(factor):
            raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")

    def _init_geometry(self, kernel_size, subsampling_factor, causal):
        """Set the stride/kernel/padding attributes shared by every subsampling variant."""
        self._stride = 2
        self._kernel_size = kernel_size
        self._ceil_mode = False
        if causal:
            self._left_padding = kernel_size - 1
            self._right_padding = self._stride - 1
            self._max_cache_len = subsampling_factor + 1
        else:
            self._left_padding = (kernel_size - 1) // 2
            self._right_padding = (kernel_size - 1) // 2
            self._max_cache_len = 0

    def _make_conv2d(self, in_channels, out_channels, groups=1):
        """Build a strided 2D conv, causal or symmetric-padded depending on ``is_causal``."""
        if self.is_causal:
            return CausalConv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
                padding=None,
                groups=groups,
            )
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            stride=self._stride,
            padding=self._left_padding,
            groups=groups,
        )

    def _make_conv1d(self, in_channels, out_channels):
        """Build a strided 1D conv, causal or symmetric-padded depending on ``is_causal``."""
        if self.is_causal:
            return CausalConv1D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
                padding=None,
            )
        return torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            stride=self._stride,
            padding=self._left_padding,
        )

    def _split_threshold(self):
        """Input numel above which the convs are run in chunks.

        Works around the 2**31 tensor indexing limit,
        see https://github.com/pytorch/pytorch/issues/80020
        """
        return 2**31 / self._conv_channels * self._stride * self._stride

    def get_sampling_frames(self):
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self):
        return [0, self.subsampling_factor + 1]

    def forward(self, x, mask):
        """
        Forward method for NeMo subsampling.

        Args:
            x[Batch, Time, Filters]: torch.Tensor
                input tensor
            mask: torch.Tensor or None
                input mask; ``mask.sum(1)`` is used as the per-sequence input length

        Returns:
            x: torch.Tensor
                Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out)
            pad_mask: torch.Tensor or None
                tensor of padded hidden state sequences (B, 1, T // time_reduction_factor),
                or None when ``mask`` is None
        """
        if self.conv2d_subsampling:
            x = x.unsqueeze(1)  # add a channel axis: (B, 1, T, F)
        else:
            x = x.transpose(1, 2)  # channel-first for Conv1d: (B, F, T)

        # split inputs if chunking_factor is set
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
            if self.subsampling_conv_chunking_factor == 1:
                # factor 1 means "auto": split only if the conv output could overflow indexing
                need_to_split = torch.numel(x) > self._split_threshold()
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        if self.conv2d_subsampling:
            # Flatten channel and frequency axes, then project to feat_out
            b, _, t, _ = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        else:
            x = x.transpose(1, 2)  # back to channel-last

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = (
            torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1)
            < padding_length.unsqueeze(1)
        )
        return x, pad_mask.unsqueeze(1)

    def reset_parameters(self):
        # initialize weights
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2) ** -0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                # layers after the first are [depthwise, pointwise, activation] triples
                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x):
        """Tries to split input by batch, run conv and concat results.

        Returns:
            (output, True) on success, or (x unchanged, False) if the batch cannot be split.
        """
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            p = math.ceil(math.log(torch.numel(x) / self._split_threshold(), 2))
            cf = 2 ** max(p, 0)  # never a fractional chunking factor

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True

    def conv_split_by_channel(self, x):
        """For dw convs, split depthwise convs by channel and pointwise convs by time."""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2 ** max(p, 0)  # a factor below 1 would just mean "one chunk"

            new_c = max(int(c // cf), 1)
            new_t = max(int(t // cf), 1)

            x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x)  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat(
                [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(self, conv, chunk_size, x):
        """Apply the depthwise ``conv`` to ``x`` in channel chunks of ``chunk_size`` and concat."""
        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]

            if self.is_causal:
                # Mirror CausalConv2D.forward so chunked and unchunked paths agree: the last axis
                # is always padded causally, dim 2 only in training mode.
                dim2_pad = (self._kernel_size - 1, self._stride - 1) if conv.training else (0, 0)
                chunk = nn.functional.pad(
                    chunk,
                    pad=(self._kernel_size - 1, self._stride - 1) + dim2_pad,
                )
                conv_padding = 0
            else:
                conv_padding = self._left_padding

            ch_out = nn.functional.conv2d(
                chunk,
                conv.weight[ind : ind + step, :, :, :],
                bias=conv.bias[ind : ind + step],
                stride=self._stride,
                padding=conv_padding,
                groups=step,
            )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
        """Validate and set a new chunking factor (-1, 1, or a power of 2)."""
        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
|
|
|
|
|
|
def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
    """Calculates the output length of a Tensor passed through a convolution or max pooling layer.

    Applies ``(length + all_paddings - kernel_size) / stride + 1`` ``repeat_num`` times,
    rounding with ceil (``ceil_mode=True``) or floor after each step, in float arithmetic.

    Returns:
        the resulting lengths as a torch.int tensor.
    """
    add_pad: float = all_paddings - kernel_size
    for _ in range(repeat_num):
        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
        lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
    return lengths.to(dtype=torch.int)
|
|
|
|
#### multihead attention starts here
|
|
class AttModule(nn.Module):
    """Base attention module carrying an ``export_mode`` flag and an identity forward."""

    def __init__(self):
        super().__init__()
        # Toggled through set_export(); starts disabled.
        self.export_mode = False

    def set_export(self, mode=True):
        """Set ``export_mode`` to ``mode`` (True by default)."""
        self.export_mode = mode

    def forward(
        self,
        x: Tensor,
        memory: Optional[Tensor] = None,
        pos_emb: Optional[Tensor] = None,
        att_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Return all inputs unchanged, as a 4-tuple.

        Args:
            x: torch.Tensor
                input tensor.
            memory: torch.Tensor, optional
                memory tensor.
            pos_emb: torch.Tensor, optional
                positional encoder embedding.
            att_mask: torch.Tensor, optional
                attention mask tensor.

        Returns:
            (x, memory, pos_emb, att_mask)
        """
        passthrough = (x, memory, pos_emb, att_mask)
        return passthrough
|
|
|
|
|
|
class AttBlock(Block, AttModule):
    """Combines Block and AttModule so one class supports both interfaces."""

    def memory_dims(self, max_len=False):
        """Return the memory shape ``(1, self.input_size)``; ``max_len`` is not used."""
        dims = (1, self.input_size)
        return dims
|
|
|
|
def masked_softmax(
    scores,
    mask: Optional[Tensor],
):
    """Softmax over the last dim of ``scores`` with optional masking.

    Args:
        scores: attention logits, (batch, head, time1, time2).
        mask: optional (batch, time1, time2) tensor; positions equal to 0 are masked out.

    Returns:
        attention probabilities (batch, head, time1, time2). Masked positions are exactly 0,
        including rows where every position is masked (which would otherwise be NaN).
    """
    if mask is None:
        return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

    blocked = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2); broadcast over heads
    probs = torch.softmax(scores.masked_fill(blocked, -torch.inf), dim=-1)
    return probs.masked_fill(blocked, 0.0)
|
|
|
|
|
|
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer with optional relative position embedding.

    Args:
        n_head: int
            the number of heads.
        n_feat: int
            input size features.
        dropout_rate: float
            dropout rate applied to the attention probabilities (training mode only).
        attention_inner_dim: int, optional
            the attention dimension used in the class,
            it can be different from the input dimension n_feat.
            default: -1 (equal to n_feat).
        glu_type: str, optional
            accepted for backward compatibility; not used by this class.
        bias_in_glu: bool, optional
            accepted for backward compatibility; not used by this class.
        use_pt_scaled_dot_product_attention: bool, optional
            if set True, use pytorch scaled dot product attention. NOTE: this will NOT
            be used when TorchScript-scripting (e.g. for ONNX export). In that case, we use the
            original attention implementation, which shows no regression.
            default: False.
        n_value: int, optional
            if set to values other than -1, use a different dimension for value. With the default
            value (i.e. -1), it is backward compatible.
        group_size: int, optional. must divide `n_head`
            if group_size > 1: GQA
            if group_size = 1: MHA
            if group_size = n_head: MQA
    """

    inv_sqrt_d_k: torch.jit.Final[float]
    h: torch.jit.Final[int]
    h_k: torch.jit.Final[int]
    g: torch.jit.Final[int]

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        attention_inner_dim=-1,
        glu_type="swish",
        bias_in_glu=True,
        use_pt_scaled_dot_product_attention=False,
        n_value=-1,
        group_size: int = 1,
    ):
        super().__init__()
        if n_value == -1:
            n_value = n_feat
        if attention_inner_dim == -1:
            attention_inner_dim = n_feat
        assert attention_inner_dim % n_head == 0

        # We assume d_v always equals d_k
        self.d_k = attention_inner_dim // n_head
        self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
        self.h = n_head
        assert n_head % group_size == 0, "group_size must divide n_head"
        self.g = group_size
        self.h_k = n_head // group_size  # number of key/value heads

        self.linear_q = nn.Linear(n_feat, attention_inner_dim)
        self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
        self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
        self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)

        self.attn = torch.jit.Attribute(None, Optional[Tensor])
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
        self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention

        if use_pt_scaled_dot_product_attention and group_size > 1:
            raise ValueError("Cannot use PT Scaled Attention with GQA")

        # Torchscript eager quantization. Note that these functions below are
        # NOOPs and have very little impact on performance unless quantization is
        # enabled.
        self.quant_q = torch.ao.quantization.QuantStub()
        self.quant_x = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.ffunc = torch.ao.nn.quantized.FloatFunctional()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_k: Tensor,
        pos_v: Tensor,
        mask: Optional[Tensor],
        relative_attention_bias: Optional[Tensor] = None,
    ):
        """Compute 'Scaled Dot Product Attention'.

        Args:
            query: torch.Tensor
                query tensor (batch, time1, size)
            key: torch.Tensor
                key tensor (batch, time2, size)
            value: torch.Tensor
                value tensor (batch, time2, size)
            pos_k: torch.Tensor
                key tensor used for relative positional embedding (ignored on the SDPA path).
            pos_v: torch.Tensor
                value tensor used for relative positional embedding (ignored on the SDPA path).
            mask: torch.Tensor
                mask tensor (batch, time1, time2)
            relative_attention_bias: torch.Tensor
                bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2)

        Returns:
            torch.Tensor: (batch, time1, n_value)
        """
        n_batch = query.size(0)

        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)  # (b, t, h, d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k)  # (b, t, h_k, d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
        # SDPA scales internally, so pre-scale q only on the manual path.
        q = (
            q.transpose(1, 2)
            if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
            else q.transpose(1, 2) * self.inv_sqrt_d_k
        )
        k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)

        if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
            attn_mask = None
            if mask is not None:
                # NOTE(review): the mask (plus any bias) is cast to q.dtype and therefore used by
                # SDPA as an ADDITIVE float mask, whereas masked_softmax treats nonzero entries
                # as "keep". A 0/1 or boolean mask would not mask anything here; confirm which
                # mask format callers pass on this path.
                mask = mask.unsqueeze(1)
                if relative_attention_bias is not None:
                    attn_mask = mask + relative_attention_bias
                else:
                    attn_mask = mask
                if mask.dtype != q.dtype:
                    attn_mask = attn_mask.to(q.dtype)

            with torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=True, enable_mem_efficient=True
            ):
                x = torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    # SDPA applies dropout whenever dropout_p > 0 and has no notion of
                    # train/eval, so disable it explicitly outside training.
                    dropout_p=self.dropout_rate if self.training else 0.0,
                )
        else:
            if self.h != self.h_k:
                q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
                # g is absent from the output spec, so scores are summed over the g query groups
                A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
            else:
                A = torch.matmul(q, k.transpose(-2, -1))
            if pos_k is not None:
                if self.h != self.h_k:
                    B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
                else:
                    reshape_q = (
                        q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1)
                    )  # (t1,nh,dk)
                    B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))  # pos_k: (t1,dk,t2)
                    B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
                scores = A + B
            else:
                scores = A

            if relative_attention_bias is not None:
                scores = scores + relative_attention_bias

            attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)

            self.attn = attn

            p_attn = self.dropout(attn)
            x = torch.matmul(p_attn.to(v.dtype), v)  # (batch, head, time1, d_k)
            if pos_v is not None:
                reshape_attn = (
                    p_attn.contiguous()
                    .view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
                    .transpose(0, 1)
                )  # (t1, bh, t2)

                attn_v = (
                    torch.matmul(reshape_attn, pos_v)
                    .transpose(0, 1)
                    .contiguous()
                    .view(n_batch, self.h, pos_v.size(0), self.d_k)
                )
                x = x + attn_v
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)
|
|
|
|
|
|
def unfold_tensor(xs_pad, max_seq_len):
    """
    Split a (N, T, D) tensor into non-overlapping windows of ``max_seq_len`` frames.

    The result has shape (N * T', max_seq_len, D) with T' = T // max_seq_len; window
    ``n * T' + j`` holds frames ``[j * max_seq_len, (j + 1) * max_seq_len)`` of batch item
    ``n``. Trailing frames that do not fill a whole window are dropped.

    Args:
        xs_pad: N, T, D
    """
    _, _, feat_dim = xs_pad.shape
    # (N, T, D) -> (N, D, 1, T) so F.unfold can slide a 1 x max_seq_len window over time
    windows = F.unfold(
        xs_pad.transpose(-1, -2).unsqueeze(-2),
        kernel_size=(1, max_seq_len),
        stride=(1, max_seq_len),
    )  # (N, D * max_seq_len, T')
    batch, _, num_windows = windows.shape
    windows = windows.view(batch, -1, max_seq_len, num_windows)  # (N, D, max_seq_len, T')
    windows = windows.permute(0, 3, 2, 1).contiguous()  # (N, T', max_seq_len, D)
    return windows.view(-1, max_seq_len, feat_dim)  # (N * T', max_seq_len, D)
|
|
|
|
# conformer_encoder.py
|
|
class MultiSequential(torch.nn.Sequential):
    """torch.nn.Sequential whose stages take and return multiple positional values."""

    @torch.jit.ignore
    def forward(self, *args):
        """Pass ``args`` through every stage, unpacking each stage's output into the next."""
        outputs = args
        for stage in self:
            outputs = stage(*outputs)
        return outputs
|
|
|
|
def repeat(repeat_num, module_gen_fn):
    """Build a MultiSequential of ``repeat_num`` modules.

    :param int repeat_num: number of modules to create
    :param function module_gen_fn: called with the index 0..repeat_num-1, returns a module
    :return: the generated modules, in index order
    :rtype: MultiSequential
    """
    generated = map(module_gen_fn, range(repeat_num))
    return MultiSequential(*generated)
|
|
|
|
class ConformerEncoderLayer(nn.Module):
|
|
"""ConformerEncoder Layer module.
|
|
for more details see conformer paper:
|
|
https://arxiv.org/abs/2005.08100
|
|
This module implement the Conformer block layer.
|
|
|
|
Args:
|
|
d_model: int
|
|
attention dim.
|
|
ext_pw_out_channel: int
|
|
if > 0, ext_pw_out_channel is a dim channel size
|
|
for the last pointwise conv after swish activation.
|
|
depthwise_seperable_out_channel: int
|
|
if set different to 0, the number of depthwise_seperable_out_channel
|
|
will be used as a channel_out of the second conv1d layer.
|
|
otherwise, it equal to 0, the second conv1d layer is skipped.
|
|
depthwise_multiplier: int
|
|
number of input_dim channels duplication. this value
|
|
will be used to compute the hidden channels of the Conv1D.
|
|
n_head: int
|
|
the number of heads for multihead attention module.
|
|
d_ffn: int
|
|
output size of the feed_forward blocks.
|
|
ext_pw_kernel_size: int
|
|
kernel size of the conv pointwise of the conformer.
|
|
kernel_size: int
|
|
kernel size.
|
|
dropout_rate: float
|
|
dropout rate.
|
|
causal: bool, optional
|
|
if set to True, convolution have no access
|
|
to future frames. default False.
|
|
batch_norm: bool, optional
|
|
if set to True, apply batchnorm before activation
|
|
in ConvModule layer of the conformer.
|
|
default False
|
|
activation: str, optional
|
|
activation function name,
|
|
one of ["relu", "swish", "sigmoid"],
|
|
sigmoid activation is only used with "glu_in_fnn=True",
|
|
default "relu".
|
|
chunk_se: int, optional
|
|
0 for offline SE.
|
|
1 for streaming SE, where mean is computed
|
|
by accumulated history until current chunk_se.
|
|
2 for streaming SE, where mean is computed
|
|
by only the current chunk.
|
|
default 0.
|
|
chunk_size: int, optional
|
|
chunk_size for cnn. default 18
|
|
conv_activation: str, optional
|
|
activation function used in ConvModule part
|
|
of the conformer, default "relu".
|
|
conv_glu_type: str, optional
|
|
activation function used for the glu inside
|
|
the ConvModule part of the conformer.
|
|
default: "sigmoid".
|
|
bias_in_glu: bool, optional
|
|
if set to True, use additive bias in the weight module
|
|
before GLU.
|
|
linear_glu_in_convm: bool, optional
|
|
if set to True, use GLULinear module,
|
|
otherwise, used GLUPointWiseConv module.
|
|
default to False.
|
|
attention_innner_dim: int, otional
|
|
if equal to -1, attention dim for linears k/q/v is
|
|
equal to d_model. otherwise attention_innner_dim is used.
|
|
default -1.
|
|
attention_glu_type: str, optional
|
|
activation function for glu used in the multihead attention,
|
|
default "swish".
|
|
activation_checkpointing: str, optional
|
|
a dictionarry of {"module","interval","offload"}, where
|
|
"module": str
|
|
accept ["transformer", "attention"] to select
|
|
which module should do activation checkpointing.
|
|
"interval": int, default 1,
|
|
interval of applying activation checkpointing,
|
|
interval = 1 means that we apply checkpointing
|
|
on every layer (if activation), otherwise,
|
|
we apply it every x interval.
|
|
"offload": bool, default False,
|
|
if set to True, we offload activation to cpu and
|
|
reload it during backward, otherwise,
|
|
we recalculate activation in backward.
|
|
default "".
|
|
export: bool, optional
|
|
if set to True, it remove the padding from convolutional layers
|
|
and allow the onnx conversion for inference.
|
|
default False.
|
|
use_pt_scaled_dot_product_attention: bool, optional
|
|
if set to True, use pytorch's scaled dot product attention implementation in training.
|
|
attn_group_sizes: int, optional
|
|
the number of groups to use for attention, default 1 (Multi-Head Attention),
|
|
1 = typical Multi-Head Attention,
|
|
1 < attn_group_sizes < attention_heads = Grouped-Query Attention
|
|
attn_group_sizes = attenion_heads = Multi-Query Attention
|
|
"""
|
|
|
|
def __init__(
    self,
    d_model=512,
    ext_pw_out_channel=0,
    depthwise_seperable_out_channel=256,
    depthwise_multiplier=1,
    n_head=4,
    d_ffn=2048,
    ext_pw_kernel_size=1,
    kernel_size=3,
    dropout_rate=0.1,
    causal=False,
    batch_norm=False,
    activation="relu",
    chunk_se=0,
    chunk_size=18,
    conv_activation="relu",
    conv_glu_type="sigmoid",
    bias_in_glu=True,
    linear_glu_in_convm=False,
    attention_innner_dim=-1,
    attention_glu_type="swish",
    activation_checkpointing="",
    export=False,
    use_pt_scaled_dot_product_attention=False,
    attn_group_sizes: int = 1,
):
    """Build the sub-modules of one Conformer block.

    Sub-modules are created (and therefore registered) in this order:
    ``feed_forward_in``, ``self_attn``, ``conv``, ``feed_forward_out``,
    ``layer_norm_att``, ``layer_norm``. See the class docstring for the
    meaning of each argument.
    """
    super().__init__()

    # The input-side and output-side feed-forward modules are configured
    # identically; build the keyword set once and reuse it for both.
    ffn_kwargs = dict(
        d_model=d_model,
        d_inner=d_ffn,
        dropout_rate=dropout_rate,
        activation=activation,
        bias_in_glu=bias_in_glu,
    )
    self.feed_forward_in = FeedForward(**ffn_kwargs)

    # Resolve the (optional) checkpointing wrapper first, then build the
    # attention module it wraps.
    wrap_attention = encoder_checkpoint_wrapper(
        activation_checkpointing,
        MultiHeadedAttention,
    )
    attention = MultiHeadedAttention(
        n_head,
        d_model,
        dropout_rate,
        attention_innner_dim,
        attention_glu_type,
        bias_in_glu,
        use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
        group_size=attn_group_sizes,
    )
    self.self_attn = wrap_attention(attention)

    # ConvModule takes these positionally; the order below must match it.
    conv_args = (
        d_model,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal,
        batch_norm,
        chunk_se,
        chunk_size,
        conv_activation,
        conv_glu_type,
        bias_in_glu,
        linear_glu_in_convm,
    )
    self.conv = ConvModule(*conv_args, export=export)

    self.feed_forward_out = FeedForward(**ffn_kwargs)

    # Pre-attention norm, then the block's final output norm.
    self.layer_norm_att = nn.LayerNorm(d_model)
    self.layer_norm = nn.LayerNorm(d_model)
|
|
|
|
def forward(
    self,
    x,
    pos_k,
    pos_v,
    mask,
    relative_attention_bias: Optional[Tensor] = None,
):
    """Run one Conformer block over ``x``.

    The sequence is: half-weighted input feed-forward residual, pre-normed
    self-attention residual, convolution residual, half-weighted output
    feed-forward residual, then a final layer norm.

    Args:
        x: torch.Tensor
            input feature of shape (batch, max_time_in, size)
        pos_k: torch.Tensor
            positional key embedding (forwarded to self-attention).
        pos_v: torch.Tensor
            positional value embedding (forwarded to self-attention).
        mask: torch.Tensor
            attention mask forwarded to self-attention; presumably the
            (batch, time, time) streaming/padding mask built by the
            encoder — TODO confirm.
        relative_attention_bias: Optional[torch.Tensor]
            bias added to attention logits w.r.t. relative positions
            (1, n_head, time1, time2)

    Returns:
        tuple: ``(out, pos_k, pos_v, mask)`` where ``out`` is the normalized
        block output; the other three are returned unchanged so blocks can
        be chained.
    """
    half = 0.5
    hidden = x + half * self.feed_forward_in(x)

    # Query, key and value all come from the same normalized input.
    attn_in = self.layer_norm_att(hidden)
    attn_out = self.self_attn(
        attn_in,
        attn_in,
        attn_in,
        pos_k,
        pos_v,
        mask,
        relative_attention_bias=relative_attention_bias,
    )
    hidden = hidden + attn_out

    hidden = hidden + self.conv(hidden)
    hidden = hidden + half * self.feed_forward_out(hidden)

    return self.layer_norm(hidden), pos_k, pos_v, mask
|
|
|
|
class TransformerEncoderBase(abc.ABC, nn.Module):
    """The Base class for Transformer based encoders

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        input_layer: str, optional
            input layer type before Conformer. Only "nemo_conv" is supported
            by this implementation; any other value raises ValueError.
            default "nemo_conv"
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1. (accepted for compatibility; not used by this class)
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False. (accepted for compatibility; not used by this class)
        time_reduction: int, optional
            time reduction factor
            default 4
        dropout_rate: float, optional
            dropout rate. default 0.0 (not used by this class)
        padding_idx: int, optional
            padding index for input_layer=embed
            default -1 (not used by this class)
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention (Q*K^T + B)
            implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see transformer_base.py)
            Only {"type": "t5"} is implemented here; any other value
            (including None) raises NotImplementedError.
        positional_dropout_rate: float, optional
            dropout rate after positional encoding. default 0.0
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default None
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True).
            if True or feat_time, the extra padding is added into non full
            supraframe utts in batch.
            Default: none (not used by this class)
        attention_group_size: int, optional
            the number of groups to use for attention, default 1 (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query Attention
            attention_group_size = attention_heads = Multi-Query Attention
    """

    def __init__(
        self,
        input_size,
        chunk_size,
        left_chunk,
        attention_dim=256,
        attention_heads=4,
        input_layer="nemo_conv",
        cnn_out=-1,
        cnn_layer_norm=False,
        time_reduction=4,
        dropout_rate=0.0,
        padding_idx=-1,
        relative_attention_bias_args=None,
        positional_dropout_rate=0.0,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.input_layer = input_layer
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.attention_dim = attention_dim
        self.num_heads = attention_heads
        self.attention_group_size = attention_group_size
        self.time_reduction = time_reduction
        self.nemo_conv_settings = nemo_conv_settings
        self.encoder_embedding_config = encoder_embedding_config

        if self.input_layer == "nemo_conv":
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.time_reduction,
                "feat_in": input_size,
                "feat_out": attention_dim,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                default_nemo_conv_settings.update(nemo_conv_settings)
                # These three are derived from the encoder's own arguments and
                # must not be overridden through the NeMo dictionary.
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert (
                        i not in nemo_conv_settings
                    ), f"{i} should be specified outside of the NeMo dictionary"

            self.embed = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            raise ValueError(f"unknown input_layer: {input_layer}")

        self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate)

        self.relative_attention_bias_type = (
            relative_attention_bias_args.get("type") if relative_attention_bias_args else None
        )
        if self.relative_attention_bias_type == "t5":
            assert (
                self.num_heads % self.attention_group_size == 0
            ), "attention_group_size must divide n_head"
            # One bias per key/value head group.
            self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
                self.num_heads // self.attention_group_size,
                max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000),
                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
            )
        else:
            # Only the T5-style bias is implemented; report the offending type.
            raise NotImplementedError(
                f"relative_attention_bias type {self.relative_attention_bias_type!r} "
                "is not supported; only 't5' is implemented"
            )

    def post_init(self, init_model_config):
        """Optionally load pretrained encoder weights, then set up input normalization.

        Args:
            init_model_config: mapping read with ``.get``; the keys consulted are
                ``pretrained_speech_encoder_path``, ``mean_file`` and ``invstd_file``.
        """
        pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None)
        if pretrained_speech_encoder_path:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu")
            # Keep every key that mentions "encoder." and drop that substring.
            encoder_state_dict = {
                k.replace("encoder.", ""): v for k, v in model_state.items() if "encoder." in k
            }

            # The normalization layer is removed before the strict load
            # (presumably because the checkpoint does not carry it) and is
            # recreated below.
            if hasattr(self, "encoder_embedding"):
                del self.encoder_embedding
            self.load_state_dict(encoder_state_dict)

        if not hasattr(self, "encoder_embedding"):
            self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"])

        mean_file = init_model_config.get('mean_file', None)
        invstd_file = init_model_config.get('invstd_file', None)
        if mean_file is not None and invstd_file is not None:
            self.encoder_embedding.load_mean_invstd(mean_file, invstd_file)

    def compute_lens_change(self, feature_lens):
        """feature_lens: int, float or Tensor
        return updated feature lens.

        This used to return a different lambda function for each case that computed
        the right thing. That does not work within Torchscript. If you really
        need this to be faster, create nn.Module()-s for all the cases and return
        one of them. Torchscript does support that.
        """
        if self.input_layer == "nemo_conv":
            # ``nemo_conv_settings`` defaults to None: treat that as "no
            # overrides" so the lookups fall back to the __init__ defaults.
            settings = self.nemo_conv_settings or {}
            # Handle the special causal case
            subsampling_causal_cond = settings.get("subsampling", "dw_striding") in [
                "dw_striding",
                "striding",
                "striding_conv1d",
            ]
            is_causal = settings.get("is_causal", False)
            if is_causal and subsampling_causal_cond:
                lens_change = (
                    torch.ceil(feature_lens / self.time_reduction).long()
                    if isinstance(feature_lens, Tensor)
                    else math.ceil(feature_lens / self.time_reduction)
                )
                feature_lens_remainder = feature_lens % self.time_reduction
                if isinstance(feature_lens, Tensor):
                    lens_change[feature_lens_remainder != 1] += 1
                elif feature_lens_remainder != 1:
                    lens_change += 1
                return lens_change
            # torch.ceil only accepts tensors; everything else goes through math.ceil.
            ceil_func = torch.ceil if isinstance(feature_lens, Tensor) else math.ceil
            return ceil_func(feature_lens / self.time_reduction)
        # Other input layers are rejected in __init__.
        return None

    @abc.abstractmethod
    def forward(self):
        """Abstract forward method implementation."""

    def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
        """If chunk size is a list, we will randomly select a chunk size.

        Returns:
            tuple: (chunk_size, left_chunk) to use for this call.

        Raises:
            ValueError: if ``chunk_size`` is a list and ``left_chunk`` is not a
                list of the same length.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if left_chunk is None:
            left_chunk = self.left_chunk
        if isinstance(chunk_size, list):
            # Variable chunk size during training
            chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,)))
            chunk_size_train_eff = chunk_size[chunk_size_index]
            if not isinstance(left_chunk, list):
                raise ValueError("Since chunk_size is a list, left_chunk must be a list")
            if len(left_chunk) != len(chunk_size):
                raise ValueError(
                    "The length of left_chunk must be the same as length of chunk_size."
                )
            left_chunk_train_eff = left_chunk[chunk_size_index]
        else:
            chunk_size_train_eff = chunk_size
            left_chunk_train_eff = left_chunk

        return chunk_size_train_eff, left_chunk_train_eff

    def _get_embed_class(self, embed):
        """Return the module underneath any activation-checkpoint / FSDP wrappers.

        Wrappers may be nested (e.g. FSDP around a CheckpointWrapper), so they
        are peeled off one at a time until a plain module remains.
        """
        # pylint: disable=protected-access
        embed_class = embed
        while True:
            if isinstance(embed_class, CheckpointWrapper):
                embed_class = embed_class._checkpoint_wrapped_module
            elif isinstance(embed_class, FullyShardedDataParallel):
                embed_class = embed_class.module
            else:
                return embed_class

    def _forward_embeddings_core(self, input_tensor, masks):
        """Run the (possibly wrapped) NeMo subsampling layer.

        Returns:
            tuple: (subsampled input, subsampled masks) as returned by ``self.embed``.
        """
        embed_class = self._get_embed_class(self.embed)
        assert isinstance(embed_class, NemoConvSubsampling)
        input_tensor, masks = self.embed(input_tensor, masks)
        return input_tensor, masks

    def _position_embedding(self, input_tensor):
        """Return ``(pos_k, pos_v)``; both are always None in this implementation."""
        pos_k = None
        pos_v = None
        if self.relative_attention_bias_layer is None:
            # NOTE(review): the embedded tensor is assigned to a local and never
            # returned, so this has no effect on the caller. The branch is
            # currently unreachable because __init__ always sets a bias layer
            # (or raises).
            input_tensor = self.pos_emb(input_tensor)  # default to add abs sinusoid embedding

        return pos_k, pos_v

    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        """Build the chunked streaming attention mask, expanded over the batch."""
        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
            chunk_size, left_chunk
        )

        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0,18,36,....]
        chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
        # avoid randomness when run evaluation or decoding
        if self.training and np.random.rand() > 0.5:
            # Either first or last chunk is not complete.
            # If only the last one is not complete, EOS is not effective
            chunk_start_idx = seq_len - chunk_start_idx
            chunk_start_idx = chunk_start_idx[::-1]
            chunk_start_idx = chunk_start_idx[:-1]
            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)

        enc_streaming_mask = (
            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff)
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None):
        """Forwarding the inputs through the top embedding layers

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input mask
            chunk_size_nc: (optional, default is None) chunk size for non-causal layers
            left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers

        Returns:
            ``(input_tensor, pos_k, pos_v, hs_mask, masks)``, plus ``hs_mask_nc``
            as a sixth element when ``chunk_size_nc`` is given.

        Raises:
            ValueError: if the sequence length after time reduction is <= 0.
        """
        # pylint: disable=R0915
        # get new lens.
        seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
        if seq_len <= 0:
            raise ValueError(
                f"""The squence length after time reduction is invalid: {seq_len}.
                Your input feature is too short. Consider filtering out the very
                short sentence from data loader""",
            )

        batch_size = xs_pad.shape[0]
        # Place the masks on the input's own device. (``.cuda()`` would pick
        # the *current* CUDA device, which need not be ``xs_pad.device``.)
        device = xs_pad.device

        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.chunk_size, self.left_chunk
        ).to(device)

        input_tensor, masks = self._forward_embeddings_core(xs_pad, masks)

        hs_mask = masks & enc_streaming_mask if masks is not None else enc_streaming_mask

        if chunk_size_nc is not None:
            enc_streaming_mask_nc = self._streaming_mask(
                seq_len, batch_size, chunk_size_nc, left_chunk_nc
            ).to(device)
            if masks is not None:
                hs_mask_nc = masks & enc_streaming_mask_nc
            else:
                hs_mask_nc = enc_streaming_mask_nc
        else:
            hs_mask_nc = None

        pos_k, pos_v = self._position_embedding(input_tensor)

        if chunk_size_nc is None:
            return input_tensor, pos_k, pos_v, hs_mask, masks
        return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

    def get_offset(self):
        """Returns offset used when retaining inputs for decoding.

        This is essentially, how many additional frames have to be added to
        the front-end CNN input to ensure it can produce a single output.
        So if the "padding" parameter is 0, typically offset will be > 0.
        """
        return get_offset(self.input_layer, self.time_reduction)
|
|
|
|
|
|
def get_offset(input_layer: str, time_reduction: int):
    """Return the frame offset needed to produce one subsampled output frame.

    Only a few (input layer, time reduction) combinations need an offset;
    every other combination yields 0.

    Args:
        input_layer (str): Type of an input layer
        time_reduction (int): time reduction factor for downsampling a feature
    Returns:
        int: offset
    """
    known_offsets = {
        ("conv2d", 4): 3,
        ("nemo_conv", 4): 3,
        ("conv2d", 6): 1,
        ("conv2d", 8): 7,
        ("nemo_conv", 8): 7,
    }
    return known_offsets.get((input_layer, time_reduction), 0)
|
|
|
|
|
|
class ConformerEncoder(TransformerEncoderBase):
    """ConformerEncoder module.
    see original paper for more details:
        https://arxiv.org/abs/2005.08100

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        num_lang: int
            This parameter is used to store the number of languages in the lang_dict,
            only used for multiseed/multilingual models. default None.
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        linear_units:
            the number of units of position-wise feed forward.
            default 2048
        num_blocks:
            number of Transformer layer. default 6
        dropout_rate: float, optional
            dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer. Only "nemo_conv" is supported
            by the base class. default "nemo_conv"
        causal: bool, optional
            if set to True, convolution have no access
            to future frames. default True.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        ext_pw_out_channel: int, optional
            the number of channel for CNN
            before depthwise_seperable_CNN.
            If 0 then use linear. default 0.
        ext_pw_kernel_size: int, optional
            kernel size of N before depthwise_seperable_CNN.
            only work for ext_pw_out_channel > 0.
            default 1
        depthwise_seperable_out_channel: int, optional
            the number of channel for
            depthwise_seperable_CNN.
            default 256.
        depthwise_multiplier: int, optional
            the number of multiplier for
            depthwise_seperable_CNN.
            default 1.
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            default 0.
        kernel_size: int, optional
            the number of kernels for depthwise_seperable_CNN.
            default 3.
        activation: str, optional
            FeedForward block activation.
            one of ["relu", "swish", "sigmoid"]
            default "relu".
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation used use glu in depthwise_seperable_CNN,
            default "sigmoid"
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU. default True
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, use GLUPointWiseConv module.
              default to False.
        attention_glu_type: str
            only work for glu_in_attention !=0
            default "swish".
        export: bool, optional
            if set to True, it removes the padding from convolutional layers
             and allow the onnx conversion for inference.
              default False.
        extra_layer_output_idx: int
            the layer index to be exposed.
        extra_multi_layer_output_idxs: list(int), optional
            stored on the instance; default None (stored as an empty list).
        activation_checkpointing: str or dict, optional
            a dictionary of {"module","interval","offload"}, where
                "module": str
                    accept ["transformer", "attention"] to select
                    which module should do activation checkpointing.
                "interval": int, default 1,
                    interval of applying activation checkpointing,
                    interval = 1 means that we apply checkpointing
                    on every layer (if activation checkpointing is enabled),
                    otherwise, we apply it every x interval.
                "offload": bool, default False,
                    if set to True, we offload activation to cpu and
                    reload it during backward, otherwise,
                    we recalculate activation in backward.
            default "".
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention (Q*K^T + B)
            implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see transformer_base.py)
        time_reduction: int optional
            time reduction factor
            default 4
        use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention
            in training.
            Default: False
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default: None
            usage: nemo_conv_settings=
                {
                    "subsampling":
                    dw_striding/striding/dw_striding_conv1d/striding_conv1d,
                    "conv_channels": int,
                    "subsampling_conv_chunking_factor": int,
                    "is_causal": True/False
                }
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True)
            Default: none
        replication_pad_for_subsample_embedding:  For batched-streaming decoding, use
            "replication" padding for the cache at start of utterance.
            Default: False
        attention_group_size: int, optional
            the number of groups to use for attention, default 1 (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query Attention
            attention_group_size = attention_heads = Multi-Query Attention
    """

    extra_multi_layer_output_idxs: List[int]

    def __init__(  # pylint: disable-all
        self,
        input_size,
        chunk_size,
        left_chunk,
        num_lang=None,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        input_layer="nemo_conv",
        causal=True,
        batch_norm=False,
        cnn_out=-1,
        cnn_layer_norm=False,
        ext_pw_out_channel=0,
        ext_pw_kernel_size=1,
        depthwise_seperable_out_channel=256,
        depthwise_multiplier=1,
        chunk_se=0,
        kernel_size=3,
        activation="relu",
        conv_activation="relu",
        conv_glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        attention_glu_type="swish",
        export=False,
        extra_layer_output_idx=-1,
        extra_multi_layer_output_idxs=None,
        activation_checkpointing="",
        relative_attention_bias_args=None,
        time_reduction=4,
        use_pt_scaled_dot_product_attention=False,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        replication_pad_for_subsample_embedding=False,
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__(
            input_size,
            chunk_size,
            left_chunk,
            attention_dim,
            attention_heads,
            input_layer,
            cnn_out,
            cnn_layer_norm,
            time_reduction,
            dropout_rate=dropout_rate,
            relative_attention_bias_args=relative_attention_bias_args,
            positional_dropout_rate=0.0,
            nemo_conv_settings=nemo_conv_settings,
            conv2d_extra_padding=conv2d_extra_padding,
            attention_group_size=attention_group_size,
            encoder_embedding_config=encoder_embedding_config,
        )
        self.num_blocks = num_blocks
        self.num_lang = num_lang
        self.kernel_size = kernel_size
        self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed)
        self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding
        assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head"
        self.num_heads_k = self.num_heads // attention_group_size

        # ``i`` is the lambda's own parameter, so each layer sees its own index.
        self.encoders = repeat(
            num_blocks,
            lambda i: encoder_checkpoint_wrapper(
                activation_checkpointing, ConformerEncoderLayer, i
            )(
                ConformerEncoderLayer(
                    d_model=attention_dim,
                    ext_pw_out_channel=ext_pw_out_channel,
                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                    depthwise_multiplier=depthwise_multiplier,
                    n_head=attention_heads,
                    d_ffn=linear_units,
                    ext_pw_kernel_size=ext_pw_kernel_size,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    causal=causal,
                    batch_norm=batch_norm,
                    activation=activation,
                    chunk_se=chunk_se,
                    chunk_size=chunk_size,
                    conv_activation=conv_activation,
                    conv_glu_type=conv_glu_type,
                    bias_in_glu=bias_in_glu,
                    linear_glu_in_convm=linear_glu_in_convm,
                    attention_glu_type=attention_glu_type,
                    activation_checkpointing=attn_checkpointing(activation_checkpointing, i),
                    export=export,
                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                    attn_group_sizes=attention_group_size,
                )
            ),
        )
        self.extra_layer_output_idx = extra_layer_output_idx
        # A fresh list per instance (a ``[]`` default would be shared by all instances).
        self.extra_multi_layer_output_idxs = (
            [] if extra_multi_layer_output_idxs is None else extra_multi_layer_output_idxs
        )
        # Make a zeros scalar we can use in get_initial_state to determine
        # the device and the needed dtype:
        self.register_buffer("dev_type", torch.zeros(()), persistent=False)

    def init_relative_attention_bias(self, input_tensor):
        """Return the relative attention bias for ``input_tensor``, or None if no bias layer is set."""
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)
        return None

    def calculate_hs_mask(self, xs_pad, device, mask):
        """Combine the streaming mask with an optional padding mask.

        Args:
            xs_pad: tensor whose dims 0 and 1 give (batch, time).
            device: device for the returned mask.
            mask: optional boolean mask whose per-row sum gives the valid
                length (valid frames are assumed to be a prefix — TODO confirm).

        Returns:
            The streaming mask on ``device``. When ``mask`` is given, positions
            beyond each row's valid length are masked out as well.
        """
        max_audio_length = xs_pad.shape[1]
        batch_size = xs_pad.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.chunk_size, self.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        pad_mask = (
            torch.arange(0, max_audio_length, device=device).expand(feature_lens.size(0), -1)
            < feature_lens.unsqueeze(1)
        )
        pad_mask = pad_mask.unsqueeze(1)
        pad_mask = pad_mask & enc_streaming_mask
        return pad_mask

    @torch.jit.ignore
    def forward(self, xs_pad, masks):
        """Conformer Forward function

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input padding mask passed to the subsampling layer (may be None)

        Returns:
            tuple: (encoder output, masks as returned by ``forward_embeddings``)
        """
        xs_pad = self.encoder_embedding(xs_pad)
        input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks)

        unfolded = False
        ori_bz, seq_len, _ = input_tensor.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
            unfolded = True
            # the unfold op will drop residual frames, pad it to the multiple of max_seq_len
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            else:
                chunk_pad_size = 0
            if chunk_pad_size > 0:
                input_tensor = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0)

            input_tensor = unfold_tensor(input_tensor, max_seq_len)
            if masks is not None:
                # revise hs_mask here because the previous calculated hs_mask did not consider extra pad
                subsampled_pad_mask = masks.squeeze(1)  # [bz, subsampled_unmask_seq_len]
                # extra padding to the pad mask
                extra_padded_subsamlped_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )
                extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
                # unfold the pad mask like we did to the input tensor
                masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len)
                masks_unfold = masks_unfold.squeeze(-1).bool()  # unfold op does not support bool tensor
            else:
                masks_unfold = None
            # calculate hs_mask based on the unfolded pad mask
            hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold)
        # Captured for extra_layer_output_idx; currently not returned (see return below).
        layer_emb = None

        relative_attention_bias = self.init_relative_attention_bias(input_tensor)

        _simplified_path = (
            self.extra_layer_output_idx == -1
            and relative_attention_bias is None
        )

        if _simplified_path:
            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
        else:
            for i, layer in enumerate(self.encoders):
                input_tensor, _, _, _ = layer(
                    input_tensor,
                    pos_k,
                    pos_v,
                    hs_mask,
                    relative_attention_bias=relative_attention_bias,
                )

                if i == self.extra_layer_output_idx:
                    layer_emb = input_tensor
        if unfolded:
            embed_dim = input_tensor.shape[-1]
            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
            # if we ever padded before unfolding, we need to remove the padding
            if chunk_pad_size > 0:
                input_tensor = input_tensor[:, :-chunk_pad_size, :]
        return input_tensor, masks  # , layer_emb

    def gradient_checkpointing_enable(self):
        """No-op; checkpointing is configured through ``activation_checkpointing`` at construction."""
        pass
|