[Bugfix] Add multiple-of-16 block_size to the Triton fallback in ROCm attention to support qwen3_5 (#35923)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: akaratza <akaratza@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
@@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first).
 | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
 | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
 | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
-| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
+| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
 | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
 | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
@@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-        # ROCM paged attention native C++ kernel only supports block sizes 16 and 32
+        # ROCM paged attention kernel only supports block sizes 16 and 32
         # due to shared memory (LDS) constraints on AMD GPUs.
         # See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.
-        # However, The limitations in [16, 32] are reasonable for a native C++ kernel,
-        # but vLLM should allow support for non-standard sizes via the Triton path,
-        # as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380,
-        # where the Triton kernel under rocm_atten does not support inference
-        # for a non-standard qwen3-next model with a block_size of 544.
-        # We have fixed the Triton kernel so that the standard model uses the original
-        # bit-addressing logic, while the non-standard model
-        # uses our optimized kernel logic.
-        return [16, 32, 544]
-
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        if block_size is None:
-            return True
-        return block_size in (16, 32, 544)
+        # However, vLLM allows support for any multiple of 16 via the Triton path.
+        # As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
+        # non-standard models (like qwen3-next with block_size 544, or qwen3_5
+        # with 784 and 1056) are dynamically routed to our optimized Triton kernel
+        # in `do_kv_cache_update`.
+        return [MultipleOf(16)]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
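For readers outside the vLLM codebase, here is a quick sketch of what this hunk changes. The helpers below are stand-ins, not vLLM's actual `MultipleOf` class or `supports_block_size` hook: the old code advertised only the fixed set 16, 32 and 544, while the new `MultipleOf(16)` constraint also admits the qwen3_5 block sizes 784 and 1056 and leaves them to the Triton fallback described in the new comments.

```python
# Stand-in helpers, not vLLM's MultipleOf class; they only mirror the old and
# new block-size acceptance rules shown in this hunk.

def old_supports_block_size(block_size: int | None) -> bool:
    # Pre-PR behaviour: a fixed whitelist added for qwen3-next in vllm#31380.
    if block_size is None:
        return True
    return block_size in (16, 32, 544)


def new_supports_block_size(block_size: int | None) -> bool:
    # Post-PR behaviour: MultipleOf(16) accepts any positive multiple of 16.
    if block_size is None:
        return True
    return block_size > 0 and block_size % 16 == 0


if __name__ == "__main__":
    for bs in (16, 32, 64, 544, 784, 1056):
        print(f"block_size={bs:4d}  old={old_supports_block_size(bs)!s:5}  "
              f"new={new_supports_block_size(bs)!s:5}")
```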
@@ -463,11 +453,9 @@ class RocmAttentionImpl(AttentionImpl):
             # Get the actual block_size from value_cache
             # value_cache shape: [num_blocks, num_heads, head_size, block_size]
             block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
 
-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+            if block_size in (16, 32):
+                # Normal 16, 32, use vLLM native HIP C++ logic
                 PagedAttention.write_to_paged_cache(
                     key,
                     value,
@@ -479,7 +467,7 @@ class RocmAttentionImpl(AttentionImpl):
                     layer._v_scale,
                 )
             else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+                # Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5),
                 # force using our modified Triton logic
                 triton_reshape_and_cache_flash(
                     key,
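The cache-write dispatch in these two hunks can be summarized with a small self-contained sketch. It only mirrors the if/else shown above; the real writers, `PagedAttention.write_to_paged_cache` and `triton_reshape_and_cache_flash`, take the key/value tensors, the caches and the scales, and their full argument lists are not visible in this diff.

```python
# A self-contained sketch of the routing above: recover the block size from the
# value-cache layout and pick the cache-write path, as the diff does.

def select_cache_write_path(value_cache_shape: tuple[int, int, int, int]) -> str:
    """Pick the KV-cache write path from the cache layout alone."""
    # value_cache shape: [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache_shape[3]
    if block_size in (16, 32):
        # Standard sizes stay on the native HIP C++ kernel.
        return "PagedAttention.write_to_paged_cache"
    # Any other multiple of 16 (64, 128, 544, 784, 1056, ...) is handled by the
    # modified Triton reshape-and-cache kernel.
    return "triton_reshape_and_cache_flash"


if __name__ == "__main__":
    for block_size in (16, 32, 64, 544, 784, 1056):
        # num_blocks, num_heads, head_size, block_size
        shape = (8, 4, 128, block_size)
        print(f"block_size={block_size:4d} -> {select_cache_write_path(shape)}")
```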