[Bugfix][ROCm] Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: vllmellm
Date: 2026-01-09 12:28:02 +01:00 (committed by GitHub)
Parent: c8ed39b9dd
Commit: 1a19e9cd87
5 changed files with 282 additions and 83 deletions
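
The change visible in this test file is small: `block_size` moves from a hardcoded local (32) to a defaulted keyword parameter on `test_contexted_kv_attention` and `test_contexted_kv_attention_alibi`, so the existing parametrized cases keep their behaviour while a new ROCm-only test reuses the same code path with the non-standard block size 544. A minimal sketch of that pattern follows; the helper below is illustrative, not the real test body, and every name other than `block_size` is an assumption.

```python
def run_contexted_kv_attention(block_size: int = 32) -> None:
    """Illustrative stand-in for the real test body (not the actual kernel test)."""
    cache_size = 640              # number of KV-cache blocks, as in the real test
    max_block_per_request = 64
    # A paged KV cache holds cache_size * block_size token slots in total; with the
    # non-standard block size 544 that is 640 * 544 = 348160 slots.
    total_token_slots = cache_size * block_size
    assert total_token_slots >= max_block_per_request * block_size


def test_default_block_size() -> None:
    run_contexted_kv_attention()                  # default block_size = 32


def test_qwen3_block_size() -> None:
    run_contexted_kv_attention(block_size=544)    # non-standard Qwen3-Next-80B size
```

The real tests additionally gate the 544 case behind `current_platform.is_rocm()`, as shown at the end of the diff.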

@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
op: Callable,
block_size: int = 32,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
op: Callable,
block_size: int = 32,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
)


@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_qwen3_nonstandard_block_size(
head_size: int,
dtype: torch.dtype,
device: str,
op: Callable,
) -> None:
"""
A separate test function specifically added
for Qwen3-Next-80B (Block Size 544).
"""
if not current_platform.is_rocm():
pytest.skip("544 block size optimization is only for ROCm.")
test_contexted_kv_attention(
num_heads=64,
num_queries_per_kv=1,
head_size=head_size,
block_size=544,
sliding_window=0,
dtype=dtype,
kv_cache_dtype="auto",
device=device,
op=op,
)
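
On a ROCm build the new case can be selected by name with pytest's `-k` filter, e.g. `pytest -k qwen3_nonstandard_block_size` against this test file (the file path is not shown in this excerpt); on non-ROCm platforms it skips via the `current_platform.is_rocm()` guard.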