[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation. (#24957)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
if IS_VARLEN:
|
||||
query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
|
||||
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
|
||||
tl.int64)
|
||||
# revise state_len and seqlen
|
||||
state_len = state_len - (seqlen -
|
||||
(query_end_index - query_start_index))
|
||||
seqlen = query_end_index - query_start_index
|
||||
x_offset = query_start_index * stride_x_token
|
||||
o_offset = query_start_index * stride_o_token
|
||||
else:
|
||||
query_start_index = idx_seq * seqlen
|
||||
query_end_index = query_start_index + seqlen
|
||||
x_offset = idx_seq * stride_x_seq
|
||||
o_offset = idx_seq * stride_o_seq
|
||||
|
||||
if query_start_index == query_end_index:
|
||||
return
|
||||
|
||||
if IS_SPEC_DECODING:
|
||||
# The rolling of conv state:
|
||||
#
|
||||
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
|
||||
# - accept 1 tokens: [history2, ..., historyM, draft1]
|
||||
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
|
||||
# - and so on.
|
||||
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
|
||||
1)
|
||||
conv_state_token_offset = (
|
||||
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
|
||||
else:
|
||||
conv_state_token_offset = 0
|
||||
|
||||
@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
|
||||
if KERNEL_WIDTH >= 4:
|
||||
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH == 5:
|
||||
if KERNEL_WIDTH >= 5:
|
||||
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
|
||||
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH >= 6:
|
||||
conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N]
|
||||
col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
|
||||
# STEP 2: assume state_len > seqlen
|
||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
|
||||
) # [BLOCK_N]
|
||||
x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N]
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
|
||||
if KERNEL_WIDTH >= 4:
|
||||
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 5:
|
||||
w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 6:
|
||||
w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col5 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
|
||||
x_base_1d = x_base # starting of chunk [BLOCK_N]
|
||||
mask_x_1d = idx_feats < dim
|
||||
|
||||
# STEP 5: compute each token
|
||||
for idx_token in tl.static_range(seqlen):
|
||||
for idx_token in tl.range(seqlen):
|
||||
acc = acc_preload
|
||||
|
||||
matrix_w = w_col0
|
||||
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
|
||||
matrix_w = w_col3
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 5:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
matrix_x = col2
|
||||
elif j == 3:
|
||||
matrix_w = w_col3
|
||||
matrix_x = col3
|
||||
elif j == 4:
|
||||
matrix_w = w_col4
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 6:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
matrix_x = col2
|
||||
elif j == 3:
|
||||
matrix_w = w_col3
|
||||
matrix_x = col3
|
||||
elif j == 4:
|
||||
matrix_w = w_col4
|
||||
matrix_x = col4
|
||||
elif j == 5:
|
||||
matrix_w = w_col5
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
|
||||
acc += matrix_x * matrix_w # [BLOCK_N]
|
||||
|
||||
@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
|
||||
col0 = col1
|
||||
col1 = col2
|
||||
col2 = matrix_x
|
||||
elif KERNEL_WIDTH == 5:
|
||||
col0 = col1
|
||||
col1 = col2
|
||||
col2 = col3
|
||||
col3 = matrix_x
|
||||
elif KERNEL_WIDTH == 6:
|
||||
col0 = col1
|
||||
col1 = col2
|
||||
col2 = col3
|
||||
col3 = col4
|
||||
col4 = matrix_x
|
||||
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
mask_1d = (idx_token < seqlen) & (idx_feats < dim
|
||||
) # token-index # feature-index
|
||||
o_ptrs = o_ptr + (
|
||||
idx_seq) * stride_o_seq + idx_token * stride_o_token + (
|
||||
idx_feats * stride_o_dim)
|
||||
o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
|
||||
stride_o_dim)
|
||||
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
@@ -850,14 +920,18 @@ def causal_conv1d_update(
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
|
||||
[shape=2: single token prediction]
|
||||
[shape=3: single or multiple tokens prediction]
|
||||
[shape=2 with num_tokens: continuous batching, where num_tokens is the
|
||||
total tokens of all sequences in that batch]
|
||||
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
@@ -870,13 +944,24 @@ def causal_conv1d_update(
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
num_accepted_tokens: (batch,), dtype int32
|
||||
If not None, it indicates the number of accepted tokens for each
|
||||
sequence in the batch.
|
||||
This is used in speculative decoding, where the conv_state is updated
|
||||
in a sliding window manner.
|
||||
query_start_loc: (batch + 1,) int32
|
||||
If not None, the inputs is given in a varlen fashion and this indicates
|
||||
the starting index of each sequence in the batch.
|
||||
max_query_len: int
|
||||
If query_start_loc is not None, this indicates the maximum query
|
||||
length in the batch.
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||
"""
|
||||
if validate_data:
|
||||
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||
@@ -886,11 +971,17 @@ def causal_conv1d_update(
|
||||
activation = "silu" if activation is True else None
|
||||
elif activation is not None:
|
||||
assert activation in ["silu", "swish"]
|
||||
unsqueeze = x.dim() == 2
|
||||
unsqueeze = query_start_loc is None and x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
if query_start_loc is None:
|
||||
batch, dim, seqlen = x.shape
|
||||
else:
|
||||
assert conv_state_indices is not None
|
||||
batch = conv_state_indices.size(0)
|
||||
dim = x.size(1)
|
||||
seqlen = max_query_len
|
||||
_, width = weight.shape
|
||||
# conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
num_cache_lines, _, state_len = conv_state.size()
|
||||
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
|
||||
out = x
|
||||
stride_w_dim, stride_w_width = weight.stride()
|
||||
|
||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
|
||||
) # X (batch, dim, seqlen)
|
||||
if query_start_loc is None:
|
||||
# X (batch, dim, seqlen)
|
||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
else:
|
||||
# X (dim, cu_seqlen)
|
||||
stride_x_token, stride_x_dim = x.stride()
|
||||
stride_x_seq = 0
|
||||
stride_o_token, stride_o_dim = out.stride()
|
||||
stride_o_seq = 0
|
||||
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
)
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
@@ -945,6 +1043,7 @@ def causal_conv1d_update(
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
query_start_loc,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@@ -971,6 +1070,7 @@ def causal_conv1d_update(
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_VARLEN=query_start_loc is not None,
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
|
||||
Reference in New Issue
Block a user