diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py
index b2c955b49..45779636e 100644
--- a/tests/kernels/attention/test_prefix_prefill.py
+++ b/tests/kernels/attention/test_prefix_prefill.py
@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
     op: Callable,
+    block_size: int = 32,
 ) -> None:
     if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
         pytest.skip(
@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
     MAX_CTX_LEN = 1024
     BS = 10
     cache_size = 640
-    block_size = 32
     max_block_per_request = 64
     query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
     # ensure one sequence in batch is a decode
@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
     op: Callable,
+    block_size: int = 32,
 ) -> None:
     if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
         pytest.skip(
@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
     MAX_CTX_LEN = 1024
     BS = 10
     cache_size = 640
-    block_size = 32
    max_block_per_request = 64
     query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
     test_contexted_kv_attention_alibi(
         num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
     )
+
+
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("op", OPS)
+@torch.inference_mode()
+def test_qwen3_nonstandard_block_size(
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+    op: Callable,
+) -> None:
+    """
+    Regression test for the non-power-of-two KV cache block size (544)
+    used by Qwen3-Next-80B.
+    """
+    if not current_platform.is_rocm():
+        pytest.skip("544 block size optimization is only for ROCm.")
+
+    test_contexted_kv_attention(
+        num_heads=64,
+        num_queries_per_kv=1,
+        head_size=head_size,
+        block_size=544,
+        sliding_window=0,
+        dtype=dtype,
+        kv_cache_dtype="auto",
+        device=device,
+        op=op,
+    )
diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 25ffe240f..c8b25d387 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -46,6 +46,7 @@ def kernel_paged_attention_2d(
     output_stride_0: tl.int64,  # int
     output_stride_1: tl.int64,  # int, should be equal to head_size
     BLOCK_SIZE: tl.constexpr,  # int
+    PHYSICAL_BLOCK_SIZE: tl.constexpr,  # int
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -104,14 +105,15 @@ def kernel_paged_attention_2d(
 
     if not USE_SINKS:
         M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+        L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
     else:
         M = tl.load(
             sink_ptr + query_head_idx,
             mask=head_mask,
             other=float("-inf"),
         ).to(dtype=tl.float32)
+        L = tl.where(float("-inf") < M, 1.0, 0.0)
 
-    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
 
     # sequence len for this particular sequence
@@ -125,30 +127,45 @@ def kernel_paged_attention_2d(
 
     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
 
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
     # iterate through tiles
     for j in range(0, num_blocks):
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
-
-        offs_n = tl.arange(0, BLOCK_SIZE)
-        offs_d = tl.arange(0, HEAD_SIZE_PADDED)
-
-        v_offset = (
-            physical_block_idx * stride_v_cache_0
-            + kv_head_idx * stride_v_cache_1
-            + offs_d[None, :] * stride_v_cache_2
-            + offs_n[:, None] * stride_v_cache_3
-        )
+        start_n = j * BLOCK_SIZE
+        # Map each token in this tile to its logical block; the physical block
+        # size may be non-standard, e.g. 544 for Qwen/Qwen3-Next-80B-A3B-Thinking.
+        # This also supports a non-contiguous mapping from logical blocks
+        # to physical blocks.
+        abs_token_idx = start_n + offs_n
+        l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
+        # Vectorized load of the physical block IDs.
+        p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
+        internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
 
+        # K addressing (5D cache layout).
         k_offset = (
-            physical_block_idx * stride_k_cache_0
+            p_block_idx[None, :] * stride_k_cache_0
             + kv_head_idx * stride_k_cache_1
             + (offs_d[:, None] // x) * stride_k_cache_2
-            + offs_n[None, :] * stride_k_cache_3
+            + internal_offsets[None, :] * stride_k_cache_3
             + (offs_d[:, None] % x) * stride_k_cache_4
         )
 
+        # V addressing (4D cache layout; the slot is the innermost dimension).
+        v_offset = (
+            p_block_idx[:, None] * stride_v_cache_0
+            + kv_head_idx * stride_v_cache_1
+            + offs_d[None, :] * stride_v_cache_2
+            + internal_offsets[:, None] * stride_v_cache_3
+        )
+
         # K : (HEAD_SIZE, BLOCK_SIZE)
-        K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)
+        K_load = tl.load(
+            key_cache_ptr + k_offset,
+            mask=dim_mask[:, None],
+            other=0.0,
+            eviction_policy="evict_last",
+        )
 
         if K_load.dtype.is_fp8():
             K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
@@ -156,7 +173,12 @@
             K = K_load
 
         # V : (BLOCK_SIZE, HEAD_SIZE)
-        V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)
+        V_load = tl.load(
+            value_cache_ptr + v_offset,
+            mask=dim_mask[None, :],
+            other=0.0,
+            eviction_policy="evict_last",
+        )
 
         if V_load.dtype.is_fp8():
             V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
@@ -167,9 +189,9 @@
 
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
         seq_mask = seq_offset[None, :] < boundary
 
-        # S : (num_queries_per_kv, BLOCK_SIZE,)
-        S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.dot(Q, K)
+        # Compute the dot product first, then apply the mask.
+        qk = scale * tl.dot(Q, K)
+        S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
 
         context_len = seq_len - 1
 
@@ -184,13 +206,15 @@
         m_j = tl.maximum(M, tl.max(S, axis=1))
 
         # P : (num_queries_per_kv, BLOCK_SIZE,)
-        P = tl.exp(S - m_j[:, None])
+        p = tl.exp(S - m_j[:, None])
+        p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
 
         # l_j : (num_queries_per_kv,)
-        l_j = tl.sum(P, axis=1)
+        l_j = tl.sum(p, axis=1)
 
         # alpha : (num_queries_per_kv, )
         alpha = tl.exp(M - m_j)
+        alpha = tl.where(float("-inf") == M, 0.0, alpha)
 
         # acc : (num_queries_per_kv, BLOCK_SIZE,)
         acc = acc * alpha[:, None]
@@ -200,10 +224,10 @@
         M = m_j
 
         # acc : (num_queries_per_kv, BLOCK_SIZE,)
-        acc += tl.dot(P.to(V.dtype), V)
+        acc += tl.dot(p.to(V.dtype), V)
 
     # epilogue
-    acc = acc / L[:, None]
+    acc = acc / (L[:, None] + 1e-10)
     if USE_FP8:
         acc = acc * tl.load(out_scale_inv)
         acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
@@ -241,9 +265,10 @@ def chunked_prefill_paged_decode(
     output_scale=None,
     # Optional tensor for sinks
     sinks=None,
+    is_block_table_ptr: bool = False,
 ):
     if sm_scale is None:
-        sm_scale = 1.0 / (query.shape[1] ** 0.5)
+        sm_scale = 1.0 / (query.shape[2] ** 0.5)
 
     use_alibi_slopes = alibi_slopes is not None
 
@@ -315,6 +340,16 @@ def chunked_prefill_paged_decode(
             alibi_slopes,
             sinks,
         )
+    # Force the Triton path only for non-standard block sizes
+    # (e.g. 544 for Qwen3-Next):
+    # 1. Check whether block_size is a power of two (16, 32, 64, ...).
+    # 2. If it is, keep vLLM's native use_custom decision.
+    # 3. If it is not a power of two (such as Qwen3-Next's 544),
+    #    force our Triton path instead.
+    is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+    if not is_pow2:
+        use_custom = False
+
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = (
@@ -356,6 +391,25 @@
             fp8_out_scale=output_scale,
         )
     else:
+        real_block_size = value_cache.shape[3]
+        # Standard power-of-two models keep the original block_size; non-standard
+        # sizes such as 544 use a 32-token tile so the integer-division indexing stays exact.
+        TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
+        if is_block_table_ptr:
+            # The block table holds physical base addresses of the cache tensors.
+            kv_element_size = key_cache.element_size()
+            block_byte_stride = key_cache.stride(0) * kv_element_size
+            # Starting physical address of the KV cache.
+            base_addr = key_cache.data_ptr()
+
+            # Normalize: convert each pointer into a block index
+            # relative to the base address.
+            processed_block_table = ((block_table - base_addr) // block_byte_stride).to(
+                torch.int32
+            )
+        else:
+            processed_block_table = block_table.to(torch.int32)
+
         kernel_paged_attention_2d[
             (
                 num_seqs,
@@ -367,7 +421,7 @@
             key_cache_ptr=key_cache,
             value_cache_ptr=value_cache,
             sink_ptr=sinks,
-            block_tables_ptr=block_table,
+            block_tables_ptr=processed_block_table,
             seq_lens_ptr=seq_lens,
             alibi_slopes_ptr=alibi_slopes,
             scale=sm_scale,
@@ -377,12 +431,13 @@
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
             num_queries_per_kv_padded=num_queries_per_kv_padded,
-            block_table_stride=block_table.stride(0),
+            block_table_stride=processed_block_table.stride(0),
             query_stride_0=query.stride(0),
             query_stride_1=query.stride(1),
             output_stride_0=output.stride(0),
             output_stride_1=output.stride(1),
-            BLOCK_SIZE=block_size,
+            BLOCK_SIZE=TRITON_BLOCK_SIZE,
+            PHYSICAL_BLOCK_SIZE=real_block_size,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 5a507a779..13c82f586 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -79,6 +79,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DMODEL_PADDED: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    PHYSICAL_BLOCK_SIZE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SLIDING_WINDOW: tl.constexpr,
     num_unroll_cache: tl.constexpr,
@@ -139,42 +140,52 @@ def _fwd_kernel(
     # initialize pointer to m and l
     if not USE_SINKS:
         m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     else:
         m_i = tl.load(
             sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
             mask=(offs_m < cur_batch_query_len),
             other=float("-inf"),
         ).to(dtype=tl.float32)
+        l_i = tl.where(m_i > float("-inf"), 1.0, 0.0)
 
-    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]
 
     # compute query against context (no causal mask here)
     for start_n in tl.range(
         0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
     ):
-        start_n = tl.multiple_of(start_n, BLOCK_SIZE)
-        # -- compute qk ----
+        # With a 544-token physical block (Qwen/Qwen3-Next-80B-A3B-Thinking),
+        # a new physical block is entered once every 17 32-token tiles.
+        # 1. Compute the logical block index of each of the 32 tokens in the
+        #    current tile (a tile may straddle a physical block boundary).
+        token_indices = start_n + offs_bs_n
+        bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
+
+        # 2. Vectorized load of the physical block IDs from B_Loc.
         bn = tl.load(
-            B_Loc
-            + cur_batch * stride_b_loc_b
-            + (start_n // BLOCK_SIZE) * stride_b_loc_s
+            B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s
         ).to(tl.int64)
-        # [D,BLOCK_SIZE]
+
+        # 3. Compute the exact offset of each token
+        #    within its physical block.
+        internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
+
+        # K addressing (5D cache layout).
         off_k = (
             bn[None, :] * stride_k_cache_bs
             + cur_kv_head * stride_k_cache_h
             + (offs_d[:, None] // x) * stride_k_cache_d
-            + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
+            + internal_offsets[None, :] * stride_k_cache_bl
             + (offs_d[:, None] % x) * stride_k_cache_x
         )
-        # [BLOCK_SIZE,D]
+        # V addressing (4D cache layout).
         off_v = (
             bn[:, None] * stride_v_cache_bs
             + cur_kv_head * stride_v_cache_h
             + offs_d[None, :] * stride_v_cache_d
-            + offs_bs_n[:, None] * stride_v_cache_bl
+            + internal_offsets[:, None] * stride_v_cache_bl
         )
 
         if (
@@ -195,12 +206,12 @@
         else:
             k = k_load
 
-        qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
-        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
+        # qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)  # [M,N]
+        qk = sm_scale * tl.dot(q, k, input_precision=IN_PRECISION)
         qk = tl.where(
             (start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
         )
-        qk *= sm_scale
+        # qk *= sm_scale
         if SLIDING_WINDOW > 0:
             # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
             # Q entries in sequence
@@ -217,14 +228,16 @@
                 (cur_batch_ctx_len + offs_m[:, None])
                 - (start_n + offs_bs_n[None, :])
                 < SLIDING_WINDOW,
                 qk,
-                -10000,
+                float("-inf"),
             )
 
         # compute running maximum
         m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
         p = tl.exp(qk - m_ij[:, None])
+        p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
         l_ij = tl.sum(p, axis=1)
         alpha = tl.exp(m_i - m_ij)
+        alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
         acc = acc * alpha[:, None]
 
         # update acc
@@ -293,14 +306,17 @@
             qk = tl.where(
                 offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                 qk,
-                -10000,
+                float("-inf"),
             )
 
         # compute running maximum
         m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
         p = tl.exp(qk - m_ij[:, None])
+        p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
         l_ij = tl.sum(p, axis=1)
         alpha = tl.exp(m_i - m_ij)
+        # Prevent NaN on the first iteration, when m_i is still -inf.
+        alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
         acc = acc * alpha[:, None]
 
         # update acc
@@ -317,7 +333,7 @@
         l_i = l_i * alpha + l_ij
         m_i = m_ij
 
-    acc = acc / l_i[:, None]
+    acc = acc / (l_i[:, None] + 1e-10)
 
     # initialize pointers to output
     off_o = (
@@ -637,6 +653,7 @@ def context_attention_fwd(
     skip_decode=False,
     fp8_out_scale=None,
     sinks=None,
+    is_block_table_ptr: bool = False,
 ):
     q_dtype_is_f32 = q.dtype is torch.float32
 
@@ -689,6 +706,19 @@
     if sliding_window is None or sliding_window <= 0:
         sliding_window = 0
 
+    if is_block_table_ptr:
+        kv_element_size = k_cache.element_size()
+        block_byte_stride = k_cache.stride(0) * kv_element_size
+        # Physical start address of the KV cache pool.
+        base_addr = k_cache.data_ptr()
+
+        mask = b_loc > 0
+        processed_b_loc = torch.where(
+            mask, (b_loc - base_addr) // block_byte_stride, b_loc
+        ).to(torch.int32)
+    else:
+        processed_b_loc = b_loc.to(torch.int32)
+
     if alibi_slopes is not None:
         assert sinks is None, "Sinks arg is not supported with alibi"
         assert fp8_out_scale is None, "FP8 output not supported with alibi"
@@ -752,7 +782,24 @@
     max_seq_len = 0 if max_seq_len is None else max_seq_len
     extra_kargs = {}
     if current_platform.is_rocm():
-        extra_kargs = {"kpack": 1, "waves_per_eu": 2}
+        extra_kargs = {}
+
+    real_block_size = v_cache.shape[3]
+    is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
+    # Standard power-of-two block sizes keep the original tile shape
+    # (BLOCK_M=128, BLOCK_N=64, as used for Llama); non-standard sizes
+    # (e.g. Qwen3-Next's block_size 544) drop to 32x32 tiles.
+    if is_pow2:
+        BLOCK_M = 128
+        BLOCK_N = 64
+    else:
+        BLOCK_M = 32
+        BLOCK_N = 32
+
+    # TRITON_BLOCK_SIZE stays at 32 so the kernel's alignment logic
+    # remains correct when it handles non-standard physical block
+    # sizes such as 544.
+    TRITON_BLOCK_SIZE = 32
 
     grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
     _fwd_kernel[grid_fn](
@@ -762,7 +809,7 @@
         k_cache,
         v_cache,
         sinks,
-        b_loc,
+        processed_b_loc,
         sm_scale,
         k_scale,
         v_scale,
@@ -771,8 +818,8 @@
         b_seq_len,
         k_cache.shape[4],
         o,
-        b_loc.stride(0),
-        b_loc.stride(1),
+        processed_b_loc.stride(0),
+        processed_b_loc.stride(1),
         q.stride(0),
         q.stride(1),
         q.stride(2),
@@ -785,16 +832,17 @@
         o.stride(0),
         o.stride(1),
         o.stride(2),
-        k_cache.stride(0),
-        k_cache.stride(1),
-        k_cache.stride(2),
-        k_cache.stride(3),
-        k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
-        v_cache.stride(0),
-        v_cache.stride(1),
-        v_cache.stride(2),
-        v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
-        BLOCK_SIZE=v_cache.shape[3],
+        stride_k_cache_bs=k_cache.stride(0),
+        stride_k_cache_h=k_cache.stride(1),
+        stride_k_cache_d=k_cache.stride(2),
+        stride_k_cache_bl=k_cache.stride(3),
+        stride_k_cache_x=k_cache.stride(4),
+        stride_v_cache_bs=v_cache.stride(0),
+        stride_v_cache_h=v_cache.stride(1),
+        stride_v_cache_d=v_cache.stride(2),
+        stride_v_cache_bl=v_cache.stride(3),
+        BLOCK_SIZE=TRITON_BLOCK_SIZE,
+        PHYSICAL_BLOCK_SIZE=real_block_size,
         num_queries_per_kv=num_queries_per_kv,
         IN_PRECISION=IN_PRECISION,
         BLOCK_DMODEL=Lk,
@@ -802,8 +850,8 @@
         SLIDING_WINDOW=sliding_window,
         SKIP_DECODE=skip_decode,
         USE_FP8=fp8_out_scale is not None,
-        BLOCK_M=128,
-        BLOCK_N=64,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
         num_unroll_cache=4,
         num_unroll_request=1,
         num_warps=4,
diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py
index a383de0ac..c5c9a9c96 100644
--- a/vllm/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -20,10 +20,15 @@ def reshape_and_cache_kernel_flash(
     key_stride: tl.int64,
     value_stride: tl.int64,
     block_stride: tl.int64,
+    head_stride: tl.int64,
+    dim_stride_k: tl.int64,
+    dim_stride_v: tl.int64,
     page_stride: tl.int64,
     num_heads: tl.constexpr,
     head_size: tl.constexpr,
     block_size: tl.constexpr,
+    x: tl.constexpr,
+    USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
     # FP8 flags
     FP8_KV_CACHE: tl.constexpr,
     # tune parameters
@@ -35,17 +40,38 @@
         # Padding token that should be ignored.
         return
 
-    tile_i = tl.program_id(axis=1)
-    tile_offs = tl.arange(0, TILE_SIZE)
-    tile_pos = tile_i * TILE_SIZE + tile_offs
-
     block_idx = slot_idx // block_size
     block_offset = slot_idx % block_size
+    tile_i = tl.program_id(axis=1)
+    tile_offs = tl.arange(0, TILE_SIZE)
+    tile_pos = tile_i * TILE_SIZE + tile_offs
 
     src_key_idx = token_idx * key_stride
     src_value_idx = token_idx * value_stride
 
-    tgt_idx = block_idx * block_stride + block_offset * page_stride
+    if USE_HEAD_MAJOR_LAYOUT:
+        # Decompose the tile index back into head and dim coordinates.
+        cur_head = tile_pos // head_size
+        cur_dim = tile_pos % head_size
+        # Value addressing (4D): [Block, Head, Dim, Slot]
+        tgt_idx_v = (
+            block_idx * block_stride
+            + cur_head * head_stride
+            + cur_dim * dim_stride_v
+            + block_offset * 1
+        )
+        # Key addressing (5D): [Block, Head, Dim // x, Slot, x]
+        tgt_idx_k = (
+            block_idx * block_stride
+            + cur_head * head_stride
+            + (cur_dim // x) * dim_stride_k
+            + block_offset * x
+            + (cur_dim % x)
+        )
+    else:
+        tgt_base = block_idx * block_stride + block_offset * page_stride
+        tgt_idx_k = tgt_base + tile_pos
+        tgt_idx_v = tgt_base + tile_pos
 
     # [TILE_SIZE]
     key_load = tl.load(
@@ -73,12 +99,12 @@
     value_tile = value_load
 
     tl.store(
-        key_cache_ptr + tgt_idx + tile_pos,
+        key_cache_ptr + tgt_idx_k,
         key_tile,
         mask=tile_pos < (num_heads * head_size),
     )
     tl.store(
-        value_cache_ptr + tgt_idx + tile_pos,
+        value_cache_ptr + tgt_idx_v,
         value_tile,
         mask=tile_pos < (num_heads * head_size),
     )
@@ -99,17 +125,26 @@
 ):
     num_heads = key.shape[1]
     head_size = key.shape[2]
-    block_size = key_cache.shape[1]
-    n = num_heads * head_size
+    use_head_major_layout = key_cache.ndim == 5
+    if use_head_major_layout:
+        block_size = key_cache.shape[3]
+        x = key_cache.shape[4]
+        head_stride = key_cache.stride(1)
+        dim_stride_k = key_cache.stride(2)
+        dim_stride_v = value_cache.stride(2)
+    else:
+        block_size = key_cache.shape[1]
+        x = 1
+        dim_stride_k = 0
+        dim_stride_v = 0
+        head_stride = key_cache.stride()[2]
+    n = num_heads * head_size
 
     key_stride = key.stride()[0]
     value_stride = value.stride()[0]
     block_stride = key_cache.stride()[0]
     page_stride = key_cache.stride()[1]
-    head_stride = key_cache.stride()[2]
-    assert head_stride == head_size, "only continous heads are supported"
-
     assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
         f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
     )
@@ -171,10 +206,15 @@
         key_stride=key_stride,
         value_stride=value_stride,
         block_stride=block_stride,
+        head_stride=head_stride,
+        dim_stride_k=dim_stride_k,
+        dim_stride_v=dim_stride_v,
         page_stride=page_stride,
         num_heads=num_heads,
         head_size=head_size,
         block_size=block_size,
+        x=x,
+        USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
         # FP8 flags
         FP8_KV_CACHE=FP8_KV_CACHE,
         # autotune parameters
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 44fa2962a..0b7a51434 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash,
+)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -321,16 +324,38 @@ class RocmAttentionImpl(AttentionImpl):
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            PagedAttention.write_to_paged_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+
+            # Read the actual block_size from value_cache; its shape is
+            # [num_blocks, num_heads, head_size, block_size].
+            block_size = value_cache.shape[3]
+            # Check whether it is a power of two.
+            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+
+            if is_pow2:
+                # Standard sizes (16, 32, 64, ...): use the native HIP C++ path.
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                # Non-standard block sizes (e.g. 544 for Qwen3-Next):
+                # fall back to the modified Triton reshape-and-cache path.
+                triton_reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)