[Revert] Remove CUDA torch fallbacks for fp8_mqa_logits/fp8_paged_mqa_logits_torch function (#37968)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import (
     get_paged_mqa_logits_metadata,
+    has_deep_gemm,
     is_deep_gemm_supported,
 )
 from vllm.utils.math_utils import cdiv
@@ -449,7 +450,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         batch_size = num_decodes

         # DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and is_deep_gemm_supported():
+        if current_platform.is_cuda() and has_deep_gemm():
             self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                 seq_lens,
                 self.kv_cache_spec.block_size,
||||
Reference in New Issue
Block a user