Steel-LLM_a1373794422999859.../softmoe_v3.py

112 lines
4.0 KiB
Python

from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
def softmax(x: torch.Tensor, dim: int | tuple[int, ...]) -> torch.Tensor:
"""
Compute the softmax along the specified dimensions.
This function adds the option to specify multiple dimensions
Args:
x (torch.Tensor): Input tensor.
dims (int or tuple[int]): The dimension or list of dimensions along which the softmax probabilities are computed.
Returns:
torch.Tensor: Output tensor containing softmax probabilities along the specified dimensions.
"""
dtype = x.dtype
x = x.to(torch.float32)
max_vals = torch.amax(x, dim=dim, keepdim=True)
e_x = torch.exp(x - max_vals)
sum_exp = e_x.sum(dim=dim, keepdim=True)
return (e_x / sum_exp).to(dtype)
# Copied from https://github.com/bwconrad/soft-moe
class SteelSoftMoEV3(nn.Module):
    """
    Soft Mixture-of-Experts layer: tokens are softly mixed into per-expert
    slots, each expert processes its slots, and the slot outputs are softly
    mixed back into one output per token.

    Based on "From Sparse to Soft Mixtures of Experts"
    https://arxiv.org/pdf/2308.00951.pdf
    """

    def __init__(
        self,
        config,
        layer: Callable,
    ) -> None:
        """
        Args:
            config: Object providing ``hidden_size`` (feature dimension) and
                ``n_experts`` (number of experts). An optional
                ``slots_per_expert`` attribute sets the number of slots per
                expert; 1 is used when the attribute is absent.
            layer (Callable): Factory invoked as ``layer(config)`` once per
                expert to build that expert's network.
        """
        super().__init__()
        self.dim = config.hidden_size
        self.num_experts = config.n_experts
        self.slots_per_expert = getattr(config, "slots_per_expert", 1)
        # Normalization of input and phi (sec. 2.3 of the paper) is always on.
        self.normalize = True

        # Slot parameters phi of shape [d, n, p], plus the learnable scalar
        # that rescales the normalized phi.
        phi_shape = (self.dim, self.num_experts, self.slots_per_expert)
        self.phi = nn.Parameter(torch.zeros(phi_shape))
        if self.normalize:
            self.scale = nn.Parameter(torch.ones(1))

        # LeCun normal initialization of phi, following
        # https://github.com/google-research/vmoe/blob/662341d007650d5bbb7c6a2bef7f3c759a20cc7e/vmoe/projects/soft_moe/router.py#L49C1-L49C1
        nn.init.normal_(self.phi, mean=0, std=1 / self.dim**0.5)

        # One independently constructed expert network per expert index.
        expert_modules = [layer(config) for _ in range(self.num_experts)]
        self.experts = nn.ModuleList(expert_modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the Soft-MoE layer (algorithm 1 from the paper).

        Args:
            x (torch.Tensor): Input of shape [batch_size, seq_len, input_dim],
                where input_dim must equal ``self.dim``.

        Returns:
            torch.Tensor: Output with one vector per input token, i.e.
            [batch_size, seq_len, input_dim] provided each expert keeps the
            feature dimension of its input.
        """
        assert (
            x.shape[-1] == self.dim
        ), f"Input feature dim of {x.shape[-1]} does not match layer dim of {self.dim}"
        assert (
            x.ndim == 3
        ), f"Input expected to have 3 dimensions but has {len(x.shape)}"

        if self.normalize:
            # L2-normalize tokens over features and phi over its feature axis.
            x = F.normalize(x, dim=2)  # [b, m, d]
            phi = self.scale * F.normalize(self.phi, dim=0)  # [d, n, p]
        else:
            phi = self.phi

        # Token-to-slot affinities: [b, m, n, p].
        logits = torch.einsum("bmd,dnp->bmnp", x, phi)
        # Dispatch weights normalize over tokens (per slot); combine weights
        # normalize jointly over all experts and slots (per token).
        dispatch_weights = softmax(logits, dim=1)
        combine_weights = softmax(logits, dim=(2, 3))

        # Each slot is a dispatch-weighted average of the tokens.
        # NOTE(review): when normalization is on, the slots are built from the
        # normalized x rather than the raw input - confirm this is intended.
        slots = torch.einsum("bmd,bmnp->bnpd", x, dispatch_weights)

        # Expert i processes only its own slots, slots[:, i] of shape [b, p, d].
        per_expert_slots = slots.unbind(dim=1)
        expert_outputs = torch.stack(
            [
                expert(expert_slots)
                for expert, expert_slots in zip(self.experts, per_expert_slots)
            ],
            dim=1,
        )

        # Each output token is a combine-weighted average of all slot outputs.
        return torch.einsum("bnpd,bmnp->bmd", expert_outputs, combine_weights)