diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index a26fd8fbc..41805e99b 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -8,7 +8,7 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -342,7 +342,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): offsets = None seq_lens = common_attn_metadata.seq_lens[:num_decodes] - if is_deep_gemm_supported(): + + # DeepGEMM is required for the paged MQA logits on CUDA devices + if current_platform.is_cuda() and has_deep_gemm(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.block_size, self.num_sms )