187 lines
7.6 KiB
Python
187 lines
7.6 KiB
Python
from typing import Any, Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch import Tensor
|
||
from flash_attn import flash_attn_varlen_func
|
||
try:
    import deepspeed.comm as dist
except Exception:
    # DeepSpeed is optional: ``dist`` is None when it cannot be imported.
    # Kept broad on purpose (importing deepspeed may fail with more than
    # ImportError), but unlike a bare ``except:`` this no longer swallows
    # KeyboardInterrupt / SystemExit.
    dist = None
|
||
|
||
|
||
try:
    from utils import (
        get_sequence_parallel_group,
        get_sequence_parallel_size,
        get_sequence_parallel_rank,
    )
except ImportError:
    # Sequence-parallel settings come from the project's ``utils`` module; if it
    # cannot be imported, sequence parallelism is treated as disabled.
    # (ModuleNotFoundError is a subclass of ImportError, so it is covered too.)

    def get_sequence_parallel_group():
        """Fallback: no sequence-parallel process group (sequence parallelism off)."""
        return None

    def get_sequence_parallel_size():
        """Fallback: sequence-parallel world size of 1."""
        return 1

    def get_sequence_parallel_rank():
        """Fallback: this process is rank 0 of the trivial sequence-parallel group."""
        return 0
|
||
|
||
|
||
def single_all_to_all(input, scatter_idx, gather_idx, group):
    """Run one all-to-all over ``group``, scattering ``scatter_idx`` and gathering ``gather_idx``.

    Each rank splits its tensor into ``world_size`` equal chunks along
    ``scatter_idx``, sends chunk ``j`` to rank ``j``, and concatenates the
    chunks it receives along ``gather_idx``.

    The post-exchange transpose for ``scatter_idx < 2`` assumes a
    ``[seq_len, batch_size, head_num // sp_size, head_dim]`` input scattered on
    dim 0 (see the shape comment below).

    Args:
        input: tensor to redistribute.
        scatter_idx: dimension split across ranks; its size must be divisible
            by the size of ``group``.
        gather_idx: dimension along which the received chunks are concatenated.
        group: process group for the collective.

    Returns:
        A contiguous tensor whose ``scatter_idx`` size is divided by, and whose
        ``gather_idx`` size is multiplied by, the group size.

    Raises:
        ValueError: if ``input.shape[scatter_idx]`` is not divisible by the
            group size.
    """
    # Prefer deepspeed.comm when it was importable; otherwise use torch.distributed,
    # whose get_world_size / all_to_all_single API it mirrors.
    comm = dist if dist is not None else torch.distributed
    seq_world_size = comm.get_world_size(group)

    inp_shape = list(input.shape)
    if inp_shape[scatter_idx] % seq_world_size != 0:
        # Floor-dividing here would either break the reshape below or, in the
        # head-scatter branch, silently mis-shard the tensor.
        raise ValueError(
            f"single_all_to_all: dimension {scatter_idx} has size {inp_shape[scatter_idx]}, "
            f"which is not divisible by the sequence-parallel world size {seq_world_size}"
        )
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size

    if scatter_idx < 2:
        # Split the leading dim into [world_size, chunk]; chunk j goes to rank j.
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # Transpose groups of heads with the seq-len parallel dimension, so that we can scatter them.
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    comm.all_to_all_single(output, input_t, group=group)

    if scatter_idx < 2:
        # Scattering the seq dim: move the received head groups back next to the heads.
        # [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
        # [seq_len//sp_size, batch_size, sp_size, head_num // sp_size, head_dim]
        output = output.transpose(0, 1).transpose(1, 2).contiguous()

    return output.reshape(
        inp_shape[:gather_idx]
        + [inp_shape[gather_idx] * seq_world_size]
        + inp_shape[gather_idx + 1:]
    ).contiguous()
|
||
|
||
|
||
class _SeqAllToAll(torch.autograd.Function):
    """Autograd wrapper around ``single_all_to_all``.

    Backward runs the same collective with ``scatter_idx`` and ``gather_idx``
    swapped, i.e. the exchange that undoes the forward one.
    """

    @staticmethod
    def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        # Remember the communication layout so backward can run the reverse exchange.
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        return single_all_to_all(input, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Swapped indices: gather where forward scattered and vice versa.
        grad_input = _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx)
        # One entry per forward argument: (group, input, scatter_idx, gather_idx).
        return None, grad_input, None, None
|
||
|
||
|
||
# Adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py,
# with fixes so that the tensor layouts match the dimension ordering used in training.
|
||
class DistributedAttention(nn.Module):
    """Sequence-parallel attention: all-to-all q/k/v, run local attention, all-to-all back.

    Each rank holds a slice of the sequence for all heads. A first all-to-all
    redistributes this so each rank holds the full sequence for a slice of the
    heads; ``local_attention`` runs on that, and a reverse all-to-all restores
    the original sequence-sharded layout.

    Arguments:
        local_attention (Module): local attention with q, k, v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: nn.Module,
        sequence_process_group: 'dist.ProcessGroup',
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
        """Pad the head dimension (dim 1) of q/k/v up to a multiple of the sequence-parallel size.

        Inputs are ``[bs, num_head, seq_len, head_dim]``. Padded query/key heads
        are filled with 0.01 and padded value heads with 0.0; ``forward`` slices
        the padded heads off its output again.
        """
        sp_size = torch.distributed.get_world_size(self.spg)
        remainder = query.size(1) % sp_size
        pad_size = (sp_size - remainder) % sp_size
        if pad_size == 0:
            return query, key, value
        # F.pad lists dims from the last one backwards: (head_dim, seq_len, num_head),
        # so only num_head grows: [bs, h, s, d] -> [bs, h + pad_size, s, d].
        head_pad = (0, 0, 0, 0, 0, pad_size)
        return (
            F.pad(query, head_pad, value=0.01),
            F.pad(key, head_pad, value=0.01),
            F.pad(value, head_pad, value=0.0),
        )

    def _scatter_heads(self, x: Tensor) -> Tensor:
        """Move ``x`` from sequence-sharded to head-sharded layout via one all-to-all.

        With the default ``scatter_idx=2, gather_idx=0``:
        ``[bs, h, s/p, d] -> [s/p, bs, h, d] -> [s, bs, h/p, d] -> [bs, h/p, s, d]``.
        """
        seq_first = x.transpose(1, 2).transpose(0, 1)
        exchanged = _SeqAllToAll.apply(self.spg, seq_first, self.scatter_idx, self.gather_idx)
        return exchanged.transpose(0, 1).transpose(1, 2).contiguous()

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        """Sequence-parallel attention forward.

        Arguments:
            query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args, passed through to the local attention together with kwargs

        Returns:
            * output (Tensor): context output, batch-first, with the padded heads
              removed (only the first ``num_head`` heads are kept)
        """
        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        num_heads = query.size(1)  # real head count, before padding

        # All ranks must issue the collectives in the same order: q, then k, then v.
        q_local, k_local, v_local = [
            self._scatter_heads(t) for t in self.pad_attention_head(query, key, value)
        ]

        # Assumes local_attn returns [bs, s, h/p, d] (as LocalAttention does) — TODO confirm for other modules.
        context = self.local_attn(q_local, k_local, v_local, *args, **kwargs)

        # [bs, s, h/p, d] -> [s, bs, h/p, d] -> reverse all-to-all -> [s/p, bs, h, d]
        seq_first = context.transpose(0, 1).contiguous()
        gathered = _SeqAllToAll.apply(self.spg, seq_first, self.gather_idx, self.scatter_idx)

        # Back to batch-first and drop the padding heads.
        return gathered.transpose(0, 1)[:, :, :num_heads, :]
|
||
|
||
|
||
class LocalAttention(nn.Module):
    """Attention over the heads present in the inputs.

    Uses ``flash_attn_varlen_func`` by default, or PyTorch's
    ``F.scaled_dot_product_attention`` restricted to the math backend when
    ``use_flash=False``.

    Args:
        hidden_size: stored on the module; not used by ``forward``.
        num_heads: stored on the module; ``forward`` reads head counts from its inputs.
        head_dim: per-head size used when flattening tensors for flash-attn.
    """

    def __init__(self, hidden_size, num_heads, head_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim

    @staticmethod
    def _math_sdp_context():
        """Return a context manager that restricts SDPA to the math backend.

        Uses ``torch.nn.attention.sdpa_kernel`` when available and falls back to
        the deprecated ``torch.backends.cuda.sdp_kernel`` on older PyTorch.
        """
        try:
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_math=True, enable_mem_efficient=False
            )
        return sdpa_kernel(SDPBackend.MATH)

    def forward(self, q, k, v, *args, use_flash=True, **kwargs):
        """Compute attention.

        Args:
            q, k, v: ``[batch_size, num_head, seq_len, head_dim]``.
            *args, **kwargs: passed unchanged to the selected attention function.
            use_flash: use ``flash_attn_varlen_func`` (default) instead of SDPA.

        Returns:
            ``[batch_size, seq_len, num_head, head_dim]``.
        """
        if use_flash:
            q_len, num_heads = q.shape[2], q.shape[1]
            # [bs, h, s, d] -> [bs * s, h, d]. Each tensor keeps its own head count,
            # so k/v may carry fewer heads than q (MQA/GQA) instead of being
            # reshaped with q's head count.
            q, k, v = (
                x.transpose(1, 2).reshape(-1, x.shape[1], self.head_dim) for x in (q, k, v)
            )
            out = flash_attn_varlen_func(q, k, v, *args, **kwargs)
            return out.reshape(-1, q_len, num_heads, self.head_dim)

        with self._math_sdp_context():
            attn_output = F.scaled_dot_product_attention(q, k, v, *args, **kwargs)
        # [bs, h, s, d] -> [bs, s, h, d]
        return attn_output.transpose(1, 2)
|
||
|
||
|
||
def create_attention_layer(hidden_size, num_heads, head_dim):
    """Build the attention module for the current sequence-parallel setup.

    Returns a plain ``LocalAttention`` when ``get_sequence_parallel_group()``
    is None (sequence parallelism disabled); otherwise wraps a
    ``LocalAttention`` in ``DistributedAttention`` over that group.
    """
    # Query the group once so the check and the construction see the same value.
    sp_group = get_sequence_parallel_group()
    local_attention = LocalAttention(hidden_size, num_heads, head_dim)
    if sp_group is None:
        return local_attention
    return DistributedAttention(
        local_attention=local_attention,
        sequence_process_group=sp_group,
    )
|
||
|
||
|
||
def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
    """Return this rank's slice of ``tensor`` along ``dim`` for sequence parallelism.

    Args:
        tensor: tensor to slice.
        dim: dimension to split (default 1).
        shift: number of leading elements along ``dim`` to drop before
            splitting. Chunk boundaries are still computed from the unshifted
            size, so with ``0 < shift < chunk_size`` the last rank's chunk is
            ``shift`` elements shorter than the others.

    Returns:
        The (shifted) tensor itself when no sequence-parallel group exists,
        otherwise the ``get_sequence_parallel_rank()``-th chunk of size
        ``tensor.size(dim) // get_sequence_parallel_size()``.

    Raises:
        ValueError: if ``tensor.size(dim)`` is not divisible by the
            sequence-parallel size. (An ``assert`` would be stripped under ``-O``.)
    """
    sp_size = get_sequence_parallel_size()
    original_size = tensor.size(dim)
    if original_size % sp_size != 0:
        raise ValueError(
            f"get_sequence_parallel_chunk: size {original_size} of dim {dim} is not "
            f"divisible by the sequence-parallel size {sp_size}"
        )

    if shift:
        # Drop the first ``shift`` elements along ``dim`` (a view, like split()).
        tensor = tensor.narrow(dim, shift, original_size - shift)

    if get_sequence_parallel_group() is None:
        return tensor

    chunk_size = original_size // sp_size
    return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
|