[Misc] Clean up cruft from previous FlashMLA sparse implementation (#26125)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import torch
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
@@ -27,13 +27,15 @@ def cal_diff(
|
||||
|
||||
|
||||
FLASH_MLA_UNSUPPORTED_REASON = (
|
||||
is_flashmla_supported()[1]
|
||||
if not is_flashmla_supported()[0]
|
||||
is_flashmla_dense_supported()[1]
|
||||
if not is_flashmla_dense_supported()[0]
|
||||
else "FlashMLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.skipif(
|
||||
not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON
|
||||
)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1, 2])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
|
||||
Reference in New Issue
Block a user