Add support for the Qwen3-Next model (a hybrid attention model). (#24526)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -70,6 +70,15 @@ class MambaStateDtypeCalculator:
|
||||
model_dtype)
|
||||
return (conv_state_dtype, )
|
||||
|
||||
@classmethod
def gated_delta_net_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
    """Resolve the torch dtype used by both gated-delta-net state tensors.

    The configured mamba cache dtype is resolved against the model dtype
    (via get_kv_cache_torch_dtype), and the same resolved dtype is used
    for both the conv state and the recurrent state.
    """
    resolved_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
    return resolved_dtype, resolved_dtype
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
|
||||
@@ -163,3 +172,31 @@ class MambaStateShapeCalculator:
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
|
||||
@classmethod
def gated_delta_net_state_shape(
    cls,
    tp_world_size: int,
    num_k_heads: int,
    num_v_heads: int,
    head_k_dim: int,
    head_v_dim: int,
    conv_kernel_size: int,
    num_spec: int = 0,
    use_v1: bool = True,
):
    """Compute per-rank state shapes for the gated delta net layer.

    Returns a pair (conv_state_shape, temporal_state_shape). The conv
    channels cover two key projections plus one value projection, split
    across tensor-parallel ranks; the temporal (recurrent) state is
    per-v-head with a (head_k_dim, head_v_dim) matrix each.
    """
    # Conv channel count: Q/K projections (2 * k heads) plus V projection.
    total_conv_channels = 2 * head_k_dim * num_k_heads + head_v_dim * num_v_heads
    channels_per_rank = divide(total_conv_channels, tp_world_size)
    # Sliding window of history tokens, extended for speculative tokens.
    token_window = conv_kernel_size - 1 + num_spec

    # In V0 the conv_state shape was swapped at allocation time in
    # MambaCacheManager; in V1 the swapped (tokens, channels) layout must
    # be produced here at the calculation level.
    if use_v1:
        conv_state_shape = (token_window, channels_per_rank)
    else:
        conv_state_shape = (channels_per_rank, token_window)

    temporal_state_shape = (
        divide(num_v_heads, tp_world_size),
        head_k_dim,
        head_v_dim,
    )
    return conv_state_shape, temporal_state_shape
|
||||
|
||||
@@ -464,7 +464,9 @@ def causal_conv1d_fn(
|
||||
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
|
||||
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
|
||||
num_cache_lines = conv_states.size(0)
|
||||
assert (num_cache_lines, dim, width - 1) == conv_states.shape
|
||||
assert (num_cache_lines == conv_states.shape[0]
|
||||
and dim == conv_states.shape[1]
|
||||
and width - 1 <= conv_states.shape[2])
|
||||
stride_istate_seq = conv_states.stride(0)
|
||||
stride_istate_dim = conv_states.stride(1)
|
||||
stride_istate_token = conv_states.stride(2)
|
||||
@@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
|
||||
conv_state_ptr,
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
@@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
@@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
@@ -663,8 +668,9 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
|
||||
tl.int64)
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
if USE_PAD_SLOT: # noqa
|
||||
@@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
if IS_SPEC_DECODING:
|
||||
# The rolling of conv state:
|
||||
#
|
||||
# Before forward, the conv_state is:
|
||||
# [history1, history2, ..., historyM].
|
||||
#
|
||||
# After forward, the conv_state becomes:
|
||||
# [history2, ..., historyM, draft1, draft2, ..., draftN].
|
||||
#
|
||||
# After acceptance, it becomes:
|
||||
#
|
||||
# - accept 1 tokens: [history2, ..., historyM, draft1]
|
||||
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
|
||||
# - and so on.
|
||||
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
|
||||
1)
|
||||
else:
|
||||
conv_state_token_offset = 0
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
prior_tokens = conv_states_base
|
||||
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
|
||||
if KERNEL_WIDTH >= 2:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
@@ -695,10 +720,14 @@ def _causal_conv1d_update_kernel(
|
||||
# STEP 2: assume state_len > seqlen
|
||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# The conv_state updates works in a sliding window manner,
|
||||
# at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
|
||||
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
@@ -820,6 +849,7 @@ def causal_conv1d_update(
|
||||
activation: Union[bool, str, None] = None,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
@@ -890,10 +920,11 @@ def causal_conv1d_update(
|
||||
) # X (batch, dim, seqlen)
|
||||
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
)
|
||||
state_len = width - 1
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
0) if conv_state_indices is not None else 0
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
@@ -910,6 +941,7 @@ def causal_conv1d_update(
|
||||
conv_state,
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@@ -926,6 +958,7 @@ def causal_conv1d_update(
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_state_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
@@ -936,6 +969,7 @@ def causal_conv1d_update(
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
BLOCK_N=256,
|
||||
|
||||
Reference in New Issue
Block a user