mPLUG-Owl3-2B-241014_a14065.../x_sdpa.py

60 lines
2.1 KiB
Python
Raw Normal View History

2024-12-26 10:53:19 +08:00
import torch.nn.functional as F
from einops import rearrange
from torch import nn
class ScaleDotProductAttention(nn.Module):
def __init__(self, layer_number, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.layer_number = layer_number
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
# Qwen 不需要scale
def forward(self, q, k, v, attn_mask=None, order='sbhd'):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
# (N,...,L,E)
import torch
import torch.nn as nn
import torch.nn.functional as F
if order == 'sbhd':
q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
for x in (q, k, v)]
elif order == 'bhsd':
pass
if attn_mask is not None:
attn_mask = (~attn_mask.clone().bool()).contiguous()
else:
attn_mask = None
# attention mask, True means it will take part in attention B H s_q s_k
if self.training:
# during training q,k,v always have same seqlen
if self.causal:
assert q.shape[-2] == k.shape[-2]
is_causal = self.causal
dropout_p = self.dropout_p
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
if self.causal:
is_causal = q.shape[-2] == k.shape[-2]
else:
is_causal = self.causal
dropout_p = 0.0
# 如果is_causal则无视输入的mask 反之会使用输入的mask
o = F.scaled_dot_product_attention(q, k, v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=self.softmax_scale
)
# B Head L D -> L B (Head D)
o = rearrange(o, 'B Head L D -> L B (Head D)').contiguous()
return o