[ROCm][CI] Fix tool use test stability - disable skinny GEMM, prefix caching, eliminate batch variance (#35553)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -171,7 +171,7 @@ Priority is **1 = highest** (tried first).
|
||||
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
||||
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
|
||||
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
|
||||
@@ -108,3 +108,5 @@ bitsandbytes==0.49.2
|
||||
tensorizer==2.10.1
|
||||
# Multi-modal models test (`allendou/FireRedASR2-LLM-vllm`)
|
||||
kaldi-native-fbank==1.22.3
|
||||
# Pinning numpy version
|
||||
numpy==2.2.6
|
||||
|
||||
@@ -9,14 +9,13 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# downloading lora to test lora requests
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -142,19 +141,11 @@ def server():
|
||||
"--gpu-memory-utilization",
|
||||
"0.4",
|
||||
"--enforce-eager",
|
||||
]
|
||||
] + ROCM_EXTRA_ARGS
|
||||
|
||||
rocm_args = {
|
||||
"--max-num-seqs": "1",
|
||||
"--no-enable-prefix-caching": None,
|
||||
}
|
||||
if current_platform.is_rocm():
|
||||
for k, v in rocm_args.items():
|
||||
args.append(k)
|
||||
if v is not None:
|
||||
args.append(v)
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, args, env_dict=ROCM_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -239,12 +230,13 @@ def k2_server():
|
||||
"qwen3",
|
||||
"--gpu-memory-utilization",
|
||||
"0.4",
|
||||
]
|
||||
] + ROCM_EXTRA_ARGS
|
||||
# hack to test kimi_k2 tool use tool_id format.
|
||||
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
args,
|
||||
env_dict=ROCM_ENV_OVERRIDES,
|
||||
override_hf_configs={"model_type": "kimi_k2", "kv_lora_rank": None},
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@@ -109,6 +109,20 @@ else:
|
||||
VLLM_PATH = Path(__file__).parent.parent
|
||||
"""Path to root of the vLLM repository."""
|
||||
|
||||
# ROCm: disable skinny GEMM to avoid non-deterministic results from
|
||||
# atomic reductions in wvSplitKrc kernel.
|
||||
# See: https://github.com/vllm-project/vllm/pull/33493#issuecomment-3906083975
|
||||
ROCM_ENV_OVERRIDES = (
|
||||
{"VLLM_ROCM_USE_SKINNY_GEMM": "0"} if current_platform.is_rocm() else {}
|
||||
)
|
||||
# ROCm: disable prefix caching and eliminate batch variance to reduce
|
||||
# test flakiness.
|
||||
ROCM_EXTRA_ARGS = (
|
||||
["--no-enable-prefix-caching", "--max-num-seqs", "1"]
|
||||
if current_platform.is_rocm()
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
class RemoteVLLMServer:
|
||||
"""Base class for launching vLLM server subprocesses for testing.
|
||||
|
||||
@@ -741,6 +741,14 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""ROCM AITER FA supports decoder and encoder-decoder (cross) attention."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [16, 32]
|
||||
|
||||
Reference in New Issue
Block a user