[Bugfix] Add Multiple of 16 block_size to triton fallback on rocm Attention to support qwen3_5 (#35923)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
JartX
2026-03-11 08:45:57 +01:00
committed by GitHub
parent eac2dc2b41
commit a40ee486f2
2 changed files with 11 additions and 23 deletions

View File

@@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# ROCM paged attention kernel only supports block sizes 16 and 32
# ROCM paged attention native C++ kernel only supports block sizes 16 and 32
# due to shared memory (LDS) constraints on AMD GPUs.
# See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
# However, The limitations in [16, 32] are reasonable for a native C++ kernel,
# but vLLM should allow support for non-standard sizes via the Triton path,
# as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
# where the Triton kernel under rocm_atten does not support inference
# for a non-standard qwen3-next model with a block_size of 544.
# We have fixed the Triton kernel so that the standard model uses the original
# bit-addressing logic, while the non-standard model
# uses our optimized kernel logic.
return [16, 32, 544]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size in (16, 32, 544)
# However, vLLM allows support for any multiple of 16 via the Triton path.
# As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
# non-standard models (like qwen3-next with block_size 544, or qwen3_5
# with 784 and 1056) are dynamically routed to our optimized Triton kernel
# in `do_kv_cache_update`.
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -463,11 +453,9 @@ class RocmAttentionImpl(AttentionImpl):
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
if block_size in (16, 32):
# Normal 16, 32, use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache(
key,
value,
@@ -479,7 +467,7 @@ class RocmAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5 ),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,