[BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067)
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
This commit is contained in:
@@ -106,10 +106,8 @@ def test_env(
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
- if use_v1 and name != "TRITON_MLA":
|
||||
- assert backend.get_name() == f"{name}_VLLM_V1"
|
||||
- else:
|
||||
- assert backend.get_name() == name
|
||||
+ expected = f"{name}_VLLM_V1" if use_v1 else name
|
||||
+ assert backend.get_name() == expected
|
||||
else:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(16,
|
||||
|
||||
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
- assert backend.get_name() == "TRITON_MLA"
|
||||
+ assert (backend.get_name() == "TRITON_MLA"
|
||||
+ or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
- assert backend.get_name() == "TRITON_MLA"
|
||||
+ assert (backend.get_name() == "TRITON_MLA"
|
||||
+ or backend.get_name() == "TRITON_MLA_VLLM_V1")
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
|
||||
Reference in New Issue
Block a user