[Qwen3-Next][GDN] Fix a CUDA graph capture bug in GDN metadata and a stride bug in causal_conv1d (#25743)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_istate_seq: tl.constexpr,
|
||||
stride_istate_dim: tl.constexpr,
|
||||
stride_istate_token: tl.constexpr,
|
||||
stride_cache_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
||||
|
||||
# single-sequence id
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0))
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
|
||||
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
||||
|
||||
# BLOCK_N elements along the feature-dimension (channel)
|
||||
@@ -91,8 +92,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
|
||||
tl.int64)
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = idx_seq
|
||||
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
|
||||
stride_o_seq = out.stride(0)
|
||||
stride_o_dim = out.stride(1)
|
||||
stride_o_token = out.stride(2)
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
0) if cache_indices is not None else 0
|
||||
|
||||
if validate_data:
|
||||
assert x.dim() == 2
|
||||
@@ -595,6 +599,7 @@ def causal_conv1d_fn(
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_cache_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
|
||||
Reference in New Issue
Block a user