[Feat] Add CUDA torch fallbacks for fp8_mqa_logits/fp8_paged_mqa_logits_torch function (#35271)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -8,7 +8,10 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
+from vllm.utils.deep_gemm import (
+    get_paged_mqa_logits_metadata,
+    is_deep_gemm_supported,
+)
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -344,7 +347,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         seq_lens = common_attn_metadata.seq_lens[:num_decodes]

         # DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and has_deep_gemm():
+        if current_platform.is_cuda() and is_deep_gemm_supported():
            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens, self.kv_cache_spec.block_size, self.num_sms
            )
Reference in New Issue
Block a user