[Revert] Remove CUDA torch fallbacks for fp8_mqa_logits/fp8_paged_mqa_logits_torch function (#37968)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -9,13 +9,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_mqa_logits,
|
||||
fp8_mqa_logits_torch,
|
||||
fp8_paged_mqa_logits,
|
||||
fp8_paged_mqa_logits_torch,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerMetadata,
|
||||
@@ -114,23 +108,14 @@ def sparse_attn_indexer(
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
)
|
||||
if is_deep_gemm_supported():
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
clean_logits=False,
|
||||
)
|
||||
else:
|
||||
logits = fp8_mqa_logits_torch(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
)
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
clean_logits=False,
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
|
||||
topk_indices = topk_indices_buffer[
|
||||
@@ -194,26 +179,16 @@ def sparse_attn_indexer(
|
||||
next_n = padded_q_fp8_decode_tokens.shape[1]
|
||||
assert batch_size == decode_metadata.seq_lens.shape[0]
|
||||
num_padded_tokens = batch_size * next_n
|
||||
if is_deep_gemm_supported():
|
||||
logits = fp8_paged_mqa_logits(
|
||||
padded_q_fp8_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
clean_logits=False,
|
||||
)
|
||||
else:
|
||||
logits = fp8_paged_mqa_logits_torch(
|
||||
padded_q_fp8_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
logits = fp8_paged_mqa_logits(
|
||||
padded_q_fp8_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
clean_logits=False,
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
||||
|
||||
@@ -333,12 +308,9 @@ class SparseAttnIndexer(CustomOp):
|
||||
self.max_model_len = max_model_len
|
||||
self.max_total_seq_len = max_total_seq_len
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
if current_platform.is_cuda() and not is_deep_gemm_supported():
|
||||
logger.warning_once(
|
||||
"DeepGEMM is not supported or available. SparseAttnIndexer will use a "
|
||||
"less efficient PyTorch implementation. "
|
||||
"Please make sure you have the required hardware and software setup "
|
||||
"for DeepGEMM to achieve optimal performance."
|
||||
if current_platform.is_cuda() and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
|
||||
Reference in New Issue
Block a user