112 lines
4.0 KiB
Python
112 lines
4.0 KiB
Python
|
from typing import Callable
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def softmax(x: torch.Tensor, dim: int | tuple[int, ...]) -> torch.Tensor:
|
||
|
"""
|
||
|
Compute the softmax along the specified dimensions.
|
||
|
This function adds the option to specify multiple dimensions
|
||
|
|
||
|
Args:
|
||
|
x (torch.Tensor): Input tensor.
|
||
|
dims (int or tuple[int]): The dimension or list of dimensions along which the softmax probabilities are computed.
|
||
|
|
||
|
Returns:
|
||
|
torch.Tensor: Output tensor containing softmax probabilities along the specified dimensions.
|
||
|
"""
|
||
|
dtype = x.dtype
|
||
|
x = x.to(torch.float32)
|
||
|
max_vals = torch.amax(x, dim=dim, keepdim=True)
|
||
|
e_x = torch.exp(x - max_vals)
|
||
|
sum_exp = e_x.sum(dim=dim, keepdim=True)
|
||
|
return (e_x / sum_exp).to(dtype)
|
||
|
|
||
|
# copy from https://github.com/bwconrad/soft-moe
|
||
|
class SteelSoftMoEV3(nn.Module):
|
||
|
"""
|
||
|
A wrapper class to create a Soft Mixture of Experts layer.
|
||
|
|
||
|
From "From Sparse to Soft Mixtures of Experts"
|
||
|
https://arxiv.org/pdf/2308.00951.pdf
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
config,
|
||
|
layer: Callable,
|
||
|
) -> None:
|
||
|
"""
|
||
|
Args:
|
||
|
dim (int): Dimensionality of input features.
|
||
|
num_experts (int): Number of experts.
|
||
|
slots_per_expert (int): Number of token slots per expert.
|
||
|
layer (Callable): Network layer of the experts.
|
||
|
normalize (bool): Normalize input and phi (sec. 2.3 from paper)
|
||
|
**layer_kwargs: Additional keyword arguments for the layer class.
|
||
|
"""
|
||
|
super().__init__()
|
||
|
|
||
|
self.dim = config.hidden_size
|
||
|
self.num_experts = config.n_experts
|
||
|
self.slots_per_expert = config.slots_per_expert if hasattr(config, "slots_per_expert") else 1
|
||
|
self.normalize = True
|
||
|
|
||
|
# Initialize phi and normalization scaling factor
|
||
|
self.phi = nn.Parameter(torch.zeros(self.dim, self.num_experts, self.slots_per_expert))
|
||
|
if self.normalize:
|
||
|
self.scale = nn.Parameter(torch.ones(1))
|
||
|
|
||
|
# Initialize phi using LeCun normal initialization
|
||
|
# https://github.com/google-research/vmoe/blob/662341d007650d5bbb7c6a2bef7f3c759a20cc7e/vmoe/projects/soft_moe/router.py#L49C1-L49C1
|
||
|
nn.init.normal_(self.phi, mean=0, std=1 / self.dim**0.5)
|
||
|
|
||
|
# Create a list of expert networks
|
||
|
self.experts = nn.ModuleList(
|
||
|
[layer(config) for _ in range(self.num_experts)]
|
||
|
)
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
"""
|
||
|
Forward pass through the Soft-MoE layer (algorithm 1 from paper).
|
||
|
|
||
|
Args:
|
||
|
x (torch.Tensor): Input tensor of shape [batch_size, seq_len, input_dim].
|
||
|
|
||
|
Returns:
|
||
|
torch.Tensor: Output tensor of shape [batch_size, seq_len, input_dim].
|
||
|
"""
|
||
|
assert (
|
||
|
x.shape[-1] == self.dim
|
||
|
), f"Input feature dim of {x.shape[-1]} does not match layer dim of {self.dim}"
|
||
|
assert (
|
||
|
len(x.shape) == 3
|
||
|
), f"Input expected to have 3 dimensions but has {len(x.shape)}"
|
||
|
|
||
|
phi = self.phi
|
||
|
|
||
|
# Normalize input and phi
|
||
|
if self.normalize:
|
||
|
x = F.normalize(x, dim=2) # [b, m, d]
|
||
|
phi = self.scale * F.normalize(phi, dim=0) # [d, n, p]
|
||
|
|
||
|
# Compute dispatch and combine weights
|
||
|
logits = torch.einsum("bmd,dnp->bmnp", x, phi)
|
||
|
d = softmax(logits, dim=1)
|
||
|
c = softmax(logits, dim=(2, 3))
|
||
|
# tmp = c[0,:,:,0].reshape([c.shape[1],-1])
|
||
|
# print("num:",tmp, "shape:",tmp.shape, "sum:",tmp.sum(dim=1))
|
||
|
# Compute input slots as weighted average of input tokens using dispatch weights
|
||
|
xs = torch.einsum("bmd,bmnp->bnpd", x, d)
|
||
|
|
||
|
# Apply expert to corresponding slots
|
||
|
ys = torch.stack(
|
||
|
[f_i(xs[:, i, :, :]) for i, f_i in enumerate(self.experts)], dim=1
|
||
|
)
|
||
|
|
||
|
# Compute output tokens as weighted average of output slots using combine weights
|
||
|
y = torch.einsum("bnpd,bmnp->bmd", ys, c)
|
||
|
|
||
|
return y
|