diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
index 574f6f251..83bf33079 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -129,22 +129,42 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     # main recurrence
     for i_t in range(NT):
         p_h1 = tl.make_block_ptr(
-            h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
+            h + i_t.to(tl.int64) * stride_h,
+            (V, K),
+            (K, 1),
+            (i_v * BV, 0),
+            (BV, 64),
+            (1, 0),
         )
         tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
         if K > 64:
             p_h2 = tl.make_block_ptr(
-                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
+                h + i_t.to(tl.int64) * stride_h,
+                (V, K),
+                (K, 1),
+                (i_v * BV, 64),
+                (BV, 64),
+                (1, 0),
             )
             tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
         if K > 128:
             p_h3 = tl.make_block_ptr(
-                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
+                h + i_t.to(tl.int64) * stride_h,
+                (V, K),
+                (K, 1),
+                (i_v * BV, 128),
+                (BV, 64),
+                (1, 0),
             )
             tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
         if K > 192:
             p_h4 = tl.make_block_ptr(
-                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
+                h + i_t.to(tl.int64) * stride_h,
+                (V, K),
+                (K, 1),
+                (i_v * BV, 192),
+                (BV, 64),
+                (1, 0),
             )
             tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
 
@@ -182,9 +202,9 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
             )
             tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
 
-        last_idx = min((i_t + 1) * BT, T) - 1
+        last_idx = min((i_t.to(tl.int64) + 1) * BT, T) - 1
         if USE_G:
-            m_t = (i_t * BT + tl.arange(0, BT)) < T
+            m_t = (i_t.to(tl.int64) * BT + tl.arange(0, BT)) < T
             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
             p_g = tl.make_block_ptr(
                 g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)