[Bugfix] Enforce DeepGEMM when using sparse_attn_indexer on CUDA (#34374)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
+from vllm.utils.import_utils import has_deep_gemm
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerMetadata,
@@ -277,6 +278,10 @@ class SparseAttnIndexer(CustomOp):
         self.max_model_len = max_model_len
         self.max_total_seq_len = max_total_seq_len
         self.topk_indices_buffer = topk_indices_buffer
+        if current_platform.is_cuda() and not has_deep_gemm():
+            raise RuntimeError(
+                "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
+            )

     def forward_native(
         self,
Reference in New Issue
Block a user