[Bugfix][Kernel] Fix negative memory offset in GDN Triton kernel (#33326)
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
This commit is contained in:
@@ -106,13 +106,14 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
p_h0 = (
|
||||
h0
|
||||
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
* stride_init_state_token
|
||||
)
|
||||
# Skip if state index is invalid (PAD_SLOT_ID = -1)
|
||||
if state_idx < 0:
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
@@ -149,13 +150,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
p_ht = (
|
||||
ht
|
||||
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
* stride_final_state_token
|
||||
)
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
final_state_idx = tl.load(
|
||||
ssm_state_indices + i_n * stride_indices_seq + i_t
|
||||
).to(tl.int64)
|
||||
# Only store if state index is valid (not PAD_SLOT_ID)
|
||||
if final_state_idx >= 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
|
||||
Reference in New Issue
Block a user