[Qwen3-Next][GDN] Fix CUDA graph capturing bug in GDN metadata and a stride bug in causal_conv1d (#25743)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
Tao He
2025-09-26 16:18:58 +08:00
committed by GitHub
parent 6e30010d2f
commit 99b3a504c5
3 changed files with 48 additions and 43 deletions

View File

@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
stride_istate_seq: tl.constexpr,
stride_istate_dim: tl.constexpr,
stride_istate_token: tl.constexpr,
stride_cache_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# rather than mixing sequences - to make updating initial_states across sequences efficient
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0))
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
# BLOCK_N elements along the feature-dimension (channel)
@@ -91,8 +92,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
tl.int64)
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices).to(
tl.int64)
else:
# cache_idx
conv_state_batch_coord = idx_seq
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
stride_o_seq = out.stride(0)
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
stride_cache_indices = cache_indices.stride(
0) if cache_indices is not None else 0
if validate_data:
assert x.dim() == 2
@@ -595,6 +599,7 @@ def causal_conv1d_fn(
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_cache_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,