[Bugfix] Add Multiple of 16 block_size to triton fallback on rocm Attention to support qwen3_5 (#35923)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
@@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend):
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-        # ROCM paged attention kernel only supports block sizes 16 and 32
+        # ROCM paged attention native C++ kernel only supports block sizes 16 and 32
         # due to shared memory (LDS) constraints on AMD GPUs.
         # See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
-
-        # However, The limitations in [16, 32] are reasonable for a native C++ kernel,
-        # but vLLM should allow support for non-standard sizes via the Triton path,
-        # as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
-        # where the Triton kernel under rocm_atten does not support inference
-        # for a non-standard qwen3-next model with a block_size of 544.
-        # We have fixed the Triton kernel so that the standard model uses the original
-        # bit-addressing logic, while the non-standard model
-        # uses our optimized kernel logic.
-        return [16, 32, 544]
-
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        if block_size is None:
-            return True
-        return block_size in (16, 32, 544)
+        # However, vLLM allows support for any multiple of 16 via the Triton path.
+        # As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
+        # non-standard models (like qwen3-next with block_size 544, or qwen3_5
+        # with 784 and 1056) are dynamically routed to our optimized Triton kernel
+        # in `do_kv_cache_update`.
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
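For context on what the new return value advertises: a MultipleOf(16) entry means any block size divisible by 16 is accepted, which covers the previously hard-coded 16, 32, and 544 as well as the qwen3_5 sizes 784 and 1056 mentioned in the comment. A minimal sketch of that contract, using a hypothetical is_supported_block_size helper rather than vLLM's actual MultipleOf class:

# Minimal sketch (not vLLM's implementation): what "[MultipleOf(16)]" advertises.
def is_supported_block_size(block_size: int) -> bool:
    # Hypothetical stand-in for the MultipleOf(16) constraint in the hunk above.
    return block_size > 0 and block_size % 16 == 0

if __name__ == "__main__":
    # 16/32 (native HIP kernel), 544 (qwen3-next), 784/1056 (qwen3_5) all pass; 24 does not.
    for bs in (16, 32, 544, 784, 1056, 24):
        print(bs, is_supported_block_size(bs))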
@@ -463,11 +453,9 @@ class RocmAttentionImpl(AttentionImpl):
             # Get the actual block_size from value_cache
             # value_cache shape: [num_blocks, num_heads, head_size, block_size]
             block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)

-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+            if block_size in (16, 32):
+                # Normal 16, 32, use vLLM native HIP C++ logic
                 PagedAttention.write_to_paged_cache(
                     key,
                     value,
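The replaced predicate is the core of the fix: is_pow2 sends every power-of-two block size (including 64 and 128) down the native HIP C++ path, even though the comments above state that the kernel only handles 16 and 32. A quick comparison of the two checks, written as standalone helpers for illustration only:

# Sketch comparing the old and new routing predicates from the hunk above.
def old_uses_native_path(block_size: int) -> bool:
    # Old check: any power of two was treated as native C++ kernel territory.
    return block_size > 0 and (block_size & (block_size - 1) == 0)

def new_uses_native_path(block_size: int) -> bool:
    # New check: only the block sizes the HIP kernel actually supports.
    return block_size in (16, 32)

for bs in (16, 32, 64, 128, 544):
    print(bs, old_uses_native_path(bs), new_uses_native_path(bs))
# 64 and 128 flip from the native path to the Triton fallback; 544 already used Triton.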
@@ -479,7 +467,7 @@ class RocmAttentionImpl(AttentionImpl):
                     layer._v_scale,
                 )
             else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+                # Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5),
                 # force using our modified Triton logic
                 triton_reshape_and_cache_flash(
                     key,
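Putting the hunks together, the KV-cache write now dispatches on membership in (16, 32) rather than on a power-of-two test. A hedged, self-contained sketch of that routing decision follows; the function name route_kv_cache_write and the string labels are illustrative only, not vLLM API, while the real code calls PagedAttention.write_to_paged_cache and triton_reshape_and_cache_flash as shown in the diff:

# Illustrative routing only; the real paths are the two calls shown in the diff above.
def route_kv_cache_write(block_size: int) -> str:
    if block_size <= 0 or block_size % 16 != 0:
        # Outside the advertised MultipleOf(16) contract.
        raise ValueError(f"unsupported block_size {block_size}")
    if block_size in (16, 32):
        return "native_hip"       # PagedAttention.write_to_paged_cache
    return "triton_fallback"      # triton_reshape_and_cache_flash

assert route_kv_cache_write(32) == "native_hip"
assert route_kv_cache_write(544) == "triton_fallback"   # qwen3-next
assert route_kv_cache_write(784) == "triton_fallback"   # qwen3_5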