[CI][AMD][BugFix][Kernel] Cast induction variable to int64 on MI350 for chunk_gated_delta_rule_fwd_kernel_h_blockdim64 to avoid illegal memory access (#39087)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
@@ -129,22 +129,42 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
 # main recurrence
 for i_t in range(NT):
 p_h1 = tl.make_block_ptr(
-h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
+h + i_t.to(tl.int64) * stride_h,
+(V, K),
+(K, 1),
+(i_v * BV, 0),
+(BV, 64),
+(1, 0),
 )
 tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
 if K > 64:
 p_h2 = tl.make_block_ptr(
-h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
+h + i_t.to(tl.int64) * stride_h,
+(V, K),
+(K, 1),
+(i_v * BV, 64),
+(BV, 64),
+(1, 0),
 )
 tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
 if K > 128:
 p_h3 = tl.make_block_ptr(
-h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
+h + i_t.to(tl.int64) * stride_h,
+(V, K),
+(K, 1),
+(i_v * BV, 128),
+(BV, 64),
+(1, 0),
 )
 tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
 if K > 192:
 p_h4 = tl.make_block_ptr(
-h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
+h + i_t.to(tl.int64) * stride_h,
+(V, K),
+(K, 1),
+(i_v * BV, 192),
+(BV, 64),
+(1, 0),
 )
 tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
 
@@ -182,9 +202,9 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
 )
 tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
 
-last_idx = min((i_t + 1) * BT, T) - 1
+last_idx = min((i_t.to(tl.int64) + 1) * BT, T) - 1
 if USE_G:
-m_t = (i_t * BT + tl.arange(0, BT)) < T
+m_t = (i_t.to(tl.int64) * BT + tl.arange(0, BT)) < T
 b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
 p_g = tl.make_block_ptr(
 g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
Reference in New Issue
Block a user