[CI][AMD][BugFix][Kernel] Cast induction variable to int64 on MI350 for chunk_gated_delta_rule_fwd_kernel_h_blockdim64 to avoid illegal memory access (#39087)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith
2026-04-08 03:57:18 -05:00
committed by GitHub
parent 2488d1dca2
commit 78434b923c

View File

@@ -129,22 +129,42 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
# main recurrence
for i_t in range(NT):
p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 0),
(BV, 64),
(1, 0),
)
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 64),
(BV, 64),
(1, 0),
)
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 128),
(BV, 64),
(1, 0),
)
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 192),
(BV, 64),
(1, 0),
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
@@ -182,9 +202,9 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
last_idx = min((i_t.to(tl.int64) + 1) * BT, T) - 1
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
m_t = (i_t.to(tl.int64) * BT + tl.arange(0, BT)) < T
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)