[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
block_size: int = 32,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
# ensure one sequence in batch is a decode
|
||||
@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
block_size: int = 32,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
|
||||
test_contexted_kv_attention_alibi(
|
||||
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_qwen3_nonstandard_block_size(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
A separate test function specifically added
|
||||
for Qwen3-Next-80B (Block Size 544).
|
||||
"""
|
||||
if not current_platform.is_rocm():
|
||||
pytest.skip("544 block size optimization is only for ROCm.")
|
||||
|
||||
test_contexted_kv_attention(
|
||||
num_heads=64,
|
||||
num_queries_per_kv=1,
|
||||
head_size=head_size,
|
||||
block_size=544,
|
||||
sliding_window=0,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype="auto",
|
||||
device=device,
|
||||
op=op,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user