[Bugfix][Kernel] Fix negative memory offset in GDN Triton kernel (#33326)

Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
This commit is contained in:
CarstyYou
2026-01-30 02:40:11 +08:00
committed by GitHub
parent 0493d897c4
commit 23591e631e

View File

@@ -106,13 +106,14 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
p_h0 = (
h0
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
# Load state index and check for PAD_SLOT_ID (-1)
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
* stride_init_state_token
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
@@ -149,13 +150,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
p_ht = (
ht
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
* stride_final_state_token
)
# Load state index and check for PAD_SLOT_ID (-1)
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]