[ROCm][CI] Fix test_max_len.py for ROCm (#29916)
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
This commit is contained in:
@@ -339,7 +339,7 @@ def test_load_model(
|
||||
"multi-token eagle spec decode on current platform"
|
||||
)
|
||||
|
||||
-if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
|
||||
+if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Setup draft model mock
|
||||
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
"because it requires special input mocking."
|
||||
)
|
||||
|
||||
-if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
|
||||
+if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Use GPU device
|
||||
@@ -541,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.TREE_ATTN
|
||||
)
|
||||
+elif attn_backend == "ROCM_AITER_FA":
|
||||
+attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
+AttentionBackendEnum.ROCM_AITER_FA
|
||||
+)
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention backend: {attn_backend}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user