[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
@@ -616,8 +617,15 @@ def sparse_attn_indexer(
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Reserve workspace for indexer during profiling run
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )

        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
@@ -651,17 +659,17 @@ def sparse_attn_indexer(
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (will allocate on first use)
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )

        for chunk in prefill_metadata.chunks:
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=fp8_dtype,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
@@ -777,15 +785,6 @@ def sparse_attn_indexer_fake(
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    fp8_dtype = current_platform.fp8_dtype()
    _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer
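The lines this hunk removes emulated the indexer's packed cache row during profiling: each row carries head_dim fp8 key bytes followed by a 4-byte float32 scale, recovered from one uint8 buffer through dtype views. That reservation now lives in sparse_attn_indexer itself (second hunk above). A minimal sketch of the dtype-view trick, assuming a PyTorch build with torch.float8_e4m3fn and an illustrative head_dim of 128:

import torch

total_seq_lens, head_dim = 8, 128
# One packed buffer: head_dim fp8 bytes + 4 scale bytes per row.
flattened_kv = torch.empty(total_seq_lens, head_dim + 4, dtype=torch.uint8)

# Reinterpret the first head_dim bytes of each row as fp8 values
# (same element size, so any strides are allowed).
k_fp8 = flattened_kv[..., :head_dim].view(torch.float8_e4m3fn)
# Reinterpret the trailing 4 bytes of each row as one float32 scale;
# this needs head_dim divisible by 4 so offsets and strides stay aligned.
k_scale = flattened_kv[..., head_dim:].view(torch.float32)

assert k_fp8.shape == (total_seq_lens, head_dim)
assert k_scale.shape == (total_seq_lens, 1)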