[Bugfix][ROCm] Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Authored by vllmellm on 2026-01-09 12:28:02 +01:00; committed by GitHub
parent c8ed39b9dd
commit 1a19e9cd87
5 changed files with 282 additions and 83 deletions


@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash,
+)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -321,16 +324,38 @@ class RocmAttentionImpl(AttentionImpl):
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            PagedAttention.write_to_paged_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+            # Get the actual block_size from value_cache.
+            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
+            block_size = value_cache.shape[3]
+            # Check whether block_size is a power of two.
+            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+            if is_pow2:
+                # Standard power-of-2 sizes (16, 32, 64, ...):
+                # use vLLM's native HIP C++ kernel.
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                # Non-standard block sizes (e.g., 544 for Qwen3-Next):
+                # fall back to the modified Triton kernel.
+                triton_reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
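
The dispatch relies on the standard bit trick that a positive integer is a power of two exactly when n & (n - 1) == 0, since a power of two has a single set bit. A minimal standalone sketch of the check (illustration only, not part of the diff; the helper name is hypothetical), showing why Qwen3-Next's block size of 544 lands on the Triton fallback:

def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one set bit, so n & (n - 1) clears it to 0.
    return n > 0 and (n & (n - 1)) == 0

assert is_power_of_two(64)          # 0b1000000    -> native HIP C++ path
assert not is_power_of_two(544)     # 0b1000100000 (512 + 32) -> Triton path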