[Bugfix] Add Multiple of 16 block_size to triton fallback on rocm Attention to support qwen3_5 (#35923)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
JartX
2026-03-11 08:45:57 +01:00
committed by GitHub
parent eac2dc2b41
commit a40ee486f2
2 changed files with 11 additions and 23 deletions

View File

@@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first).
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

View File

@@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# ROCM paged attention kernel only supports block sizes 16 and 32
# ROCM paged attention native C++ kernel only supports block sizes 16 and 32
# due to shared memory (LDS) constraints on AMD GPUs.
# See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
# However, The limitations in [16, 32] are reasonable for a native C++ kernel,
# but vLLM should allow support for non-standard sizes via the Triton path,
# as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
# where the Triton kernel under rocm_atten does not support inference
# for a non-standard qwen3-next model with a block_size of 544.
# We have fixed the Triton kernel so that the standard model uses the original
# bit-addressing logic, while the non-standard model
# uses our optimized kernel logic.
return [16, 32, 544]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size in (16, 32, 544)
# However, vLLM allows support for any multiple of 16 via the Triton path.
# As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
# non-standard models (like qwen3-next with block_size 544, or qwen3_5
# with 784 and 1056) are dynamically routed to our optimized Triton kernel
# in `do_kv_cache_update`.
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -463,11 +453,9 @@ class RocmAttentionImpl(AttentionImpl):
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
if block_size in (16, 32):
# Normal 16, 32, use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache(
key,
value,
@@ -479,7 +467,7 @@ class RocmAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5 ),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,