- Create CuTeDSLNvFp4LinearKernel, extending the NvFp4LinearKernel base class. - Register it via the init_nvfp4_linear_kernel() selection mechanism (inserted at the top of _POSSIBLE_NVFP4_KERNELS, before FlashInfer). - process_weights_after_loading: convert uint8→FP4, permute, and create the CuTeDSL runner. - apply_weights: route through the CuTeDSL GEMM. - Update the Dockerfile: copy the kernel and the registration script. - Fix attention: always use forward() for quantized compressor/indexer layers (the dtype check was fragile after the kernel swaps weights to dummy BF16).
1190 lines
44 KiB
Python
1190 lines
44 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
DeepseekV4 MLA Attention Layer
|
|
"""
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, cast
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import DeepseekV2Config, DeepseekV3Config
|
|
|
|
import vllm.envs as envs
|
|
from vllm.model_executor.layers.linear import (
|
|
ReplicatedLinear,
|
|
)
|
|
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
|
|
from vllm.utils.deep_gemm import fp8_einsum
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
from vllm.v1.attention.ops.deepseek_v4_ops import (
|
|
combine_topk_swa_indices,
|
|
compute_global_topk_indices_and_lens,
|
|
dequantize_and_gather_k_cache,
|
|
fused_indexer_q_rope_quant,
|
|
fused_inv_rope_fp8_quant,
|
|
fused_q_kv_rmsnorm,
|
|
)
|
|
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.attention.backends.mla.sparse_swa import (
|
|
DeepseekSparseSWAMetadata,
|
|
)
|
|
|
|
from vllm.config import (
|
|
CacheConfig,
|
|
VllmConfig,
|
|
get_current_vllm_config,
|
|
)
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.custom_op import PluggableLayer
|
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
|
|
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
|
QuantFP8,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
GroupShape,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.multi_stream_utils import (
|
|
execute_in_parallel,
|
|
maybe_execute_in_parallel,
|
|
)
|
|
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
|
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
|
DeepseekV4FlashMLASparseBackend,
|
|
FlashMLASparseBackend,
|
|
FlashMLASparseMetadata,
|
|
)
|
|
from vllm.v1.attention.backends.mla.indexer import (
|
|
DeepseekV4IndexerBackend,
|
|
get_max_prefill_buffer_size,
|
|
)
|
|
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
|
|
from vllm.v1.attention.ops.flashmla import (
|
|
flash_mla_sparse_fwd,
|
|
flash_mla_with_kvcache,
|
|
)
|
|
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
|
|
from vllm.v1.worker.workspace import current_workspace_manager
|
|
|
|
logger = init_logger(__name__)

# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
# workspace allocated at _forward_prefill (and the matching profile-time
# reservation in attention_impl's dummy-run branch).
# The dummy-run branch reserves a (PREFILL_CHUNK_SIZE, M, head_dim) bf16
# buffer, where M bounds the gathered KV length of a single sequence. So this
# is presumably the number of prefill sequences gathered per chunk.
# NOTE(review): confirm against _forward_prefill.
PREFILL_CHUNK_SIZE = 4
|
|
|
|
|
|
@dataclass
class DeepseekV4MLAModules:
    """Modules used in DeepseekV4 MLA.

    Bundles the projections, norms, rotary embeddings and auxiliary state that
    DeepseekV4MultiHeadLatentAttentionWrapper reads in its constructor.
    """

    # Source of hf_config (rms_norm_eps, qk_rope_head_dim, o_groups,
    # o_lora_rank) and of the compilation config used for layer registration.
    vllm_config: VllmConfig
    # Fused q_a / kv input projection; the wrapper splits its output into
    # [q_lora_rank, head_dim] along the last dimension.
    fused_wqa_wkv: torch.nn.Module
    # RMSNorm on the q low-rank activations; only its .weight is read.
    q_norm: torch.nn.Module
    # Projects normalized q low-rank activations to per-head queries.
    wq_b: torch.nn.Module
    # RMSNorm on kv; only its .weight is read.
    kv_norm: torch.nn.Module
    # Grouped low-rank output projection. Its .weight / .weight_scale_inv are
    # fed to an FP8 einsum on non-ROCm platforms.
    wo_a: torch.nn.Module
    # Final output projection applied to the flattened wo_a result.
    wo_b: torch.nn.Module
    # Attention sink forwarded to DeepseekV4MLAAttention (the wrapper notes it
    # is "already padded with -inf").
    # NOTE(review): annotated as a Module, but consumed as a torch.Tensor
    # there — confirm the intended type.
    attn_sink: torch.nn.Module
    # Main RoPE; its cos_sin_cache is read directly.
    rotary_emb: torch.nn.Module
    # Sparse-attention indexer; None for layers without one.
    indexer: torch.nn.Module | None
    # RoPE passed to the indexer call.
    indexer_rotary_emb: torch.nn.Module
    # Shared top-k index buffer read by C4A layers in DeepseekV4MLAAttention.
    topk_indices_buffer: torch.Tensor | None
    # Auxiliary CUDA streams for GEMM overlap. The wrapper requires at least 3
    # when present; None on ROCm for now.
    aux_stream_list: list[torch.cuda.Stream] | None = None
|
|
|
|
|
|
# --8<-- [start:multi_head_latent_attention]
|
|
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
    """Pluggable MLA layer which allows OOT backends to add
    custom implementations of the outer MLA layer (including rope & o_proj).
    Note that currently oot platforms can still use CustomOp.register_oot to
    replace MLA layer entirely, although we use PluggableLayer to register
    this layer now.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    # --8<-- [end:multi_head_latent_attention]

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        o_lora_rank: int | None,
        mla_modules: DeepseekV4MLAModules,
        window_size: int,
        compress_ratio: int | None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Wire up projections, caches, the attention core and the compressor.

        Args:
            hidden_size: Model hidden size (forwarded to the compressor).
            num_heads: Local (already TP-adjusted) number of attention heads.
            head_dim: Per-head dimension (NoPE + RoPE parts).
            scale: Softmax scale forwarded to DeepseekV4MLAAttention.
            qk_nope_head_dim, qk_rope_head_dim, v_head_dim, o_lora_rank:
                Not read here. NoPE/RoPE dims and o_lora_rank are taken from
                hf_config instead.
            q_lora_rank: Width of the q low-rank slice of fused_wqa_wkv.
            kv_lora_rank: Forwarded to DeepseekV4MLAAttention.
            mla_modules: Pre-built projections, norms and auxiliary state.
            window_size: Sliding-window size for the SWA cache.
            compress_ratio: KV compression ratio; None is treated as 1
                (SWA-only, no compressor).
            cache_config: Required; used for the SWA cache and KV dtype.
            quant_config: Forwarded to DeepseekV4MLAAttention.
            prefix: Layer name prefix used for cache and op registration.

        Raises:
            ValueError: if num_heads > 128 or the layer name is already
                registered in the static forward context.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale

        # FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the
        # next supported size. Must match DeepseekV4MLAAttention.padded_heads.
        if num_heads <= 64:
            self.padded_heads = 64
        elif num_heads <= 128:
            self.padded_heads = 128
        else:
            raise ValueError(
                f"DeepseekV4 attention does not support {num_heads} heads "
                "(must be <= 128)."
            )

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.window_size = window_size
        self.compress_ratio = compress_ratio if compress_ratio is not None else 1
        self.prefix = prefix

        # Extract config from vllm_config
        config = mla_modules.vllm_config.model_config.hf_config
        tp_size = get_tensor_model_parallel_world_size()

        # DeepseekV4-specific attributes (num_heads is already TP-adjusted)
        self.eps = config.rms_norm_eps
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = head_dim - self.rope_head_dim
        self.n_local_groups = config.o_groups // tp_size
        self.o_lora_rank = config.o_lora_rank

        # Store projection modules
        self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
        self.q_norm = mla_modules.q_norm
        self.wq_b = mla_modules.wq_b

        self.kv_norm = mla_modules.kv_norm
        self.wo_a = mla_modules.wo_a

        self._wo_a_act_quant = QuantFP8(
            static=False,
            group_shape=GroupShape(1, 128),
            use_ue8m0=True,
        )
        # Bypass packed-for-deepgemm path — we need FP32 scales (not packed
        # INT32) so fp8_einsum can handle layout transform internally.
        self._wo_a_act_quant.use_deep_gemm_supported = False
        self.wo_b = mla_modules.wo_b

        # Pick fp8_einsum recipe based on GPU arch:
        #   SM90:  FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
        #   SM100: INT32 packed scales become [g, r, ...]   → sfb_gran_mn=1
        cap = current_platform.get_device_capability()
        assert cap is not None, "DeepseekV4 attention requires a CUDA device"
        self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
        self._tma_aligned_scales = cap.major >= 10

        self.rotary_emb = mla_modules.rotary_emb
        self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
        self.topk_indices_buffer = mla_modules.topk_indices_buffer

        self.indexer = mla_modules.indexer

        # Per-head RMS normalization for Q (no learnable weights)
        self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)

        # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
        head_bytes = (
            self.nope_head_dim  # 448 fp8 NoPE
            + self.rope_head_dim * 2  # 64 bf16 RoPE
            + self.nope_head_dim // 64  # 7B scale factors
            + 1  # 1B pad
        )

        # Will be None on ROCm for now.
        self.aux_stream_list = mla_modules.aux_stream_list
        # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
        # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
        # before post-GEMM starts.
        self.ln_events = [torch.cuda.Event() for _ in range(4)]

        assert cache_config is not None, "DeepseekV4 attention requires cache_config"
        self.swa_cache_layer = DeepseekV4SWACache(
            head_dim=self.head_dim,
            window_size=self.window_size,
            dtype=torch.uint8,
            prefix=f"{prefix}.swa_cache",
            cache_config=cache_config,
        )

        self.mla_attn = DeepseekV4MLAAttention(
            num_heads=self.n_local_heads,
            head_dim=self.head_dim,
            scale=self.scale,
            qk_nope_head_dim=self.nope_head_dim,
            qk_rope_head_dim=self.rope_head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            compress_ratio=self.compress_ratio,
            window_size=self.window_size,
            head_bytes=head_bytes,
            swa_cache_layer=self.swa_cache_layer,
            attn_sink=mla_modules.attn_sink,  # already padded with -inf
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            indexer=self.indexer,
            topk_indices_buffer=self.topk_indices_buffer,
        )
        # Register this layer in the compilation config's static forward context
        # This allows the custom op to retrieve the layer during execution
        compilation_config = mla_modules.vllm_config.compilation_config
        # HACK
        self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
        if self.layer_name in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {self.layer_name}")
        compilation_config.static_forward_context[self.layer_name] = self

        # Create the compressor for layers with compress_ratio > 1; after
        # creating the DeepseekV4MLAAttention layer to get its cache.
        self.compressor = None
        if self.compress_ratio > 1:
            self.compressor = DeepseekCompressor(
                vllm_config=mla_modules.vllm_config,
                compress_ratio=self.compress_ratio,
                hidden_size=self.hidden_size,
                head_dim=self.head_dim,
                rotate=True,
                prefix=f"{prefix}.compressor",
                k_cache_prefix=self.mla_attn.prefix,
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention plus output projection for ``hidden_states``.

        The attention core runs inside the ``deepseek_v4_attention`` custom op
        and writes into a head-padded buffer. The padding heads are sliced off
        before the output projection. ``llama_4_scaling`` is not used here.

        Returns:
            ``wo_b`` applied to the flattened grouped low-rank projection.
        """
        # Pre-allocate attention output with FlashMLA-padded head count.
        # The op writes into `o_padded`; we slice to n_local_heads after.
        num_tokens = hidden_states.shape[0]
        o_padded = torch.empty(
            (num_tokens, self.padded_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Attention (inside custom op for torch.compile boundary)
        torch.ops.vllm.deepseek_v4_attention(
            hidden_states,
            positions,
            o_padded,
            self.layer_name,
        )
        o = o_padded[:, : self.n_local_heads, :]

        # Keep ROCm on the BF16 reference wo_a path until the kernel is ready.
        if current_platform.is_rocm():
            z = rocm_inv_rope_einsum(
                self.rotary_emb,
                o,
                positions,
                self.rope_head_dim,
                self.n_local_groups,
                self.o_lora_rank,
                self.wo_a,
            )
            return self.wo_b(z.flatten(1))

        # O projection: inverse RoPE + FP8 quant + einsum + wo_b
        o_fp8, o_scale = fused_inv_rope_fp8_quant(
            o,
            positions,
            self.rotary_emb.cos_sin_cache.to(torch.float32),
            n_groups=self.n_local_groups,
            heads_per_group=self.n_local_heads // self.n_local_groups,
            nope_dim=self.nope_head_dim,
            rope_dim=self.rope_head_dim,
            tma_aligned_scales=self._tma_aligned_scales,
        )

        wo_a_fp8 = self.wo_a.weight
        wo_a_scale = self.wo_a.weight_scale_inv

        z = torch.empty(
            (num_tokens, self.n_local_groups, self.o_lora_rank),
            device=o.device,
            dtype=torch.bfloat16,
        )
        torch.ops.vllm.deepseek_v4_fp8_einsum(
            o_fp8,
            o_scale,
            wo_a_fp8,
            wo_a_scale,
            z,
            "bhr,hdr->bhd",
            list(self._einsum_recipe),
        )

        return self.wo_b(z.flatten(1))

    @staticmethod
    def _linear_output(result: Any) -> torch.Tensor:
        """Return the output tensor from a linear-layer call.

        vLLM linear layers return either the output tensor
        (``return_bias=False``) or an ``(output, bias)`` tuple. Unwrapping via
        ``isinstance`` keeps every call site correct for both forms. By
        contrast, ``out, _ = layer(x)`` on a bare tensor iterates dim 0: it
        raises for most token counts, but silently mis-splits the batch when
        there are exactly two tokens.
        """
        if isinstance(result, tuple):
            return result[0]
        return result

    def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
        """Run the input-side GEMMs, overlapping the lighter ones on aux streams.

        Returns:
            ``(qr_kv, kv_score, indexer_kv_score, indexer_weights)``. The aux
            results come from whatever ``execute_in_parallel`` yields for a
            slot whose function is None (presumably None) when the owning
            module (compressor / indexer) is absent.
        """
        aux_streams = self.aux_stream_list
        if aux_streams is not None:
            assert len(aux_streams) >= 3
            aux_streams = aux_streams[:3]

        # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
        # on aux streams 0..2 when their owning module exists. ln_events[0]
        # is the fan-out start event; ln_events[1..3] are per-aux done events.
        # On ROCm, aux_streams is None and execute_in_parallel runs serially.
        aux_fns: list[Callable[[], Any] | None] = [None, None, None]

        if self.compressor is not None:
            # Local ref so the closure keeps a non-None type for mypy.
            compressor = self.compressor

            def compressor_kv_score() -> torch.Tensor:
                # Use forward() for quantized layers (NVFP4, FP8, etc.)
                # — raw torch.mm doesn't work with packed/dequantized weights.
                kv_gate = self._linear_output(
                    compressor.fused_wkv_wgate(hidden_states)
                )
                return kv_gate.to(torch.float32)

            aux_fns[0] = compressor_kv_score

        if self.indexer is not None:
            indexer = self.indexer

            def indexer_weights_proj() -> torch.Tensor:
                return self._linear_output(indexer.weights_proj(hidden_states))

            def indexer_compressor_kv_score() -> torch.Tensor:
                kv_gate = self._linear_output(
                    indexer.compressor.fused_wkv_wgate(hidden_states)
                )
                return kv_gate.to(torch.float32)

            aux_fns[1] = indexer_weights_proj
            aux_fns[2] = indexer_compressor_kv_score

        def fused_wqa_wkv() -> torch.Tensor:
            return self._linear_output(self.fused_wqa_wkv(hidden_states))

        qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel(
            fused_wqa_wkv,
            aux_fns,
            self.ln_events[0],
            self.ln_events[1:4],
            aux_streams,
            enable=hidden_states.shape[0]
            <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD,
        )

        return qr_kv, kv_score, indexer_kv_score, indexer_weights

    def attention_impl(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        out: torch.Tensor,  # [num_tokens, padded_heads, head_dim], written in place
    ) -> None:
        """Body of the ``deepseek_v4_attention`` custom op.

        Runs the input GEMMs, q/kv RMSNorm, wq_b + fused RoPE/KV-cache insert,
        the compressor/indexer (overlapped on an aux stream when available),
        and finally the sparse MLA attention into ``out``. On a dummy run (no
        dict metadata) it reserves the prefill gather workspace and zeroes
        ``out`` instead of running attention.
        """
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        qr_kv, kv_score, indexer_kv_score, indexer_weights = (
            self.attn_gemm_parallel_execute(hidden_states)
        )

        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
        qr, kv = fused_q_kv_rmsnorm(
            qr,
            kv,
            self.q_norm.weight.data,
            self.kv_norm.weight.data,
            self.eps,
        )

        # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
        # on the default stream so q stays on its consumer stream (mla_attn
        # downstream reads q on default). Indexer/compressor go on aux for
        # overlap with default's GEMM + cache write.
        if self.indexer is not None:
            aux_stream = (
                self.aux_stream_list[0] if self.aux_stream_list is not None else None
            )
            indexer = self.indexer
            # Local ref so the closure keeps a non-None type for mypy.
            assert self.compressor is not None
            compressor = self.compressor

            def wq_b_kv_insert_and_compress() -> torch.Tensor:
                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
                self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
                compressor(kv_score, positions, self.rotary_emb)
                return q

            q, _ = maybe_execute_in_parallel(
                wq_b_kv_insert_and_compress,
                lambda: indexer(
                    hidden_states,
                    qr,
                    indexer_kv_score,
                    indexer_weights,
                    positions,
                    self.indexer_rotary_emb,
                ),
                self.ln_events[0],
                self.ln_events[1],
                aux_stream,
            )
        elif self.compressor is not None:
            # wq_b + kv_insert on default, compressor on aux.
            aux_stream = (
                self.aux_stream_list[0] if self.aux_stream_list is not None else None
            )
            compressor = self.compressor

            def wq_b_kv_insert() -> torch.Tensor:
                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
                self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
                return q

            q, _ = maybe_execute_in_parallel(
                wq_b_kv_insert,
                lambda: compressor(kv_score, positions, self.rotary_emb),
                self.ln_events[0],
                self.ln_events[1],
                aux_stream,
            )
        else:
            # SWA-only layer: no compressor, no overlap.
            q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
            self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)

        # Handle dummy run (no metadata).
        if not isinstance(attn_metadata, dict):
            # Reserve _forward_prefill's bf16-gather workspace; the dummy
            # run returns before mla_attn runs, so without this the shared
            # workspace locks below the real prefill size.
            sub = self.mla_attn
            swa_only = sub.compress_ratio <= 1
            N = (
                0
                if swa_only
                else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio
            )
            M = N + sub.window_size + sub.max_num_batched_tokens
            current_workspace_manager().get_simultaneous(
                ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
            )
            out.zero_()
            return

        # Pad q to FlashMLA-required head count (64 or 128)
        if self.n_local_heads < self.padded_heads:
            pad_size = self.padded_heads - self.n_local_heads
            q = F.pad(q, (0, 0, 0, pad_size), value=0.0)

        # MLA attention writes into the pre-allocated `out` buffer
        # ([num_tokens, padded_heads, head_dim]).
        self.mla_attn(q, kv, positions, output=out)

    def _fused_qnorm_rope_kv_insert(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: (
            dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
        ),
    ) -> None:
        """Apply q head-norm + RoPE and quantize/insert kv into the SWA cache.

        A no-op when ``attn_metadata`` is not a dict (dummy run). ``q`` is
        passed to the fused kernel that performs the Q-side ops, presumably
        in place — confirm against the kernel.
        """
        if not isinstance(attn_metadata, dict):
            return

        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(self.swa_cache_layer.prefix),
        )
        assert swa_metadata is not None

        swa_kv_cache = self.swa_cache_layer.kv_cache
        swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)

        # Horizontally fused:
        #   Q side:  q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE
        #   KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert
        # kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
        torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
            q,
            kv,
            swa_kv_cache_2d,
            swa_metadata.slot_mapping,
            positions.to(torch.int64),
            self.rotary_emb.cos_sin_cache.to(torch.float32),
            self.eps,
            swa_metadata.block_size,
        )
|
|
|
|
|
|
def deepseek_v4_attention(
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
) -> None:
    """Real implementation of ``torch.ops.vllm.deepseek_v4_attention``.

    Resolves the wrapper layer registered under ``layer_name`` from the active
    forward context's ``no_compile_layers``. Its ``attention_impl`` then
    fills ``out`` in place; nothing is returned.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer.attention_impl(hidden_states, positions, out)
|
|
|
|
|
|
def deepseek_v4_attention_fake(
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake (tracing) implementation of the ``deepseek_v4_attention`` op.

    The real op only writes into the caller-allocated ``out`` buffer and
    returns nothing, so there is no output metadata to produce here.
    """
    del hidden_states, positions, out, layer_name  # intentionally unused
|
|
|
|
|
|
# Expose deepseek_v4_attention as torch.ops.vllm.deepseek_v4_attention (the
# wrapper's forward calls it by that name). The op returns nothing and writes
# into `out`, which is declared via mutates_args. The fake impl is therefore a
# no-op.
direct_register_custom_op(
    op_name="deepseek_v4_attention",
    op_func=deepseek_v4_attention,
    mutates_args=["out"],
    fake_impl=deepseek_v4_attention_fake,
)
|
|
|
|
|
|
def deepseek_v4_fp8_einsum(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    """Real implementation of ``torch.ops.vllm.deepseek_v4_fp8_einsum``.

    Pairs each FP8 operand with its scale and delegates to ``fp8_einsum``,
    which writes the result into ``out``. ``recipe`` arrives as a list
    (custom-op friendly) and is converted to the tuple ``fp8_einsum`` expects.
    """
    lhs = (a, a_scale)
    rhs = (b, b_scale)
    fp8_einsum(equation, lhs, rhs, out, recipe=tuple(recipe))
|
|
|
|
|
|
def deepseek_v4_fp8_einsum_fake(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    """Fake (tracing) implementation of the ``deepseek_v4_fp8_einsum`` op.

    The real op mutates ``out`` and returns nothing, so tracing needs no work.
    """
    del a, a_scale, b, b_scale, out, equation, recipe  # intentionally unused
|
|
|
|
|
|
# Expose deepseek_v4_fp8_einsum as torch.ops.vllm.deepseek_v4_fp8_einsum (the
# wrapper's forward calls it by that name). The result is written into `out`,
# declared via mutates_args. The fake impl is therefore a no-op.
direct_register_custom_op(
    op_name="deepseek_v4_fp8_einsum",
    op_func=deepseek_v4_fp8_einsum,
    mutates_args=["out"],
    fake_impl=deepseek_v4_fp8_einsum_fake,
)
|
|
|
|
|
|
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
    """Sparse MLA attention core for one DeepseekV4 layer.

    Owns the compressed KV cache (a KV-cache spec is only requested when
    compress_ratio > 1) and reads the separately allocated SWA cache. It
    splits tokens into decode and prefill parts for the FlashMLA sparse
    kernels, or delegates to the AITER implementation on ROCm.
    """

    # FlashMLA FP8 sparse only supports 64 or 128 heads
    SUPPORTED_HEAD_COUNTS = (64, 128)
|
|
|
|
def __init__(
    self,
    num_heads: int,
    head_dim: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    compress_ratio: int,
    window_size: int,
    head_bytes: int,
    swa_cache_layer: DeepseekV4SWACache,
    attn_sink: torch.Tensor,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    # Sparse MLA Args
    indexer: object | None = None,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream: torch.cuda.Stream | None = None,
    **extra_impl_args,
) -> None:
    """Configure the sparse MLA attention core and register it for lookup.

    Args:
        num_heads: Local number of query heads. Padded up to 64 or 128 for
            FlashMLA (stored as ``padded_heads``).
        head_dim: Per-head dimension.
        scale: Softmax scale.
        qk_nope_head_dim, qk_rope_head_dim: NoPE / RoPE parts of head_dim.
        q_lora_rank, kv_lora_rank: Low-rank widths (stored only).
        compress_ratio: KV compression ratio; <= 1 means SWA-only.
        window_size: Sliding-window size.
        head_bytes: Bytes per cached head entry (stored only).
        swa_cache_layer: Shared SWA cache layer (required).
        attn_sink: Attention-sink tensor (required).
        cache_config: Cache config; its dtype must be fp8. When given, its
            ``cache_dtype`` is rewritten to "fp8_ds_mla". When None, "fp8"
            is assumed and converted the same way.
        quant_config: Not read by this constructor.
        prefix: Layer name; used as the metadata key and registered in the
            static forward context when non-empty.
        indexer, topk_indices_buffer, aux_stream: Sparse-attention helpers,
            stored on the instance.
        **extra_impl_args: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: if num_heads > 128 (and not a supported count) or
            ``prefix`` is already registered.
    """
    super().__init__()
    self.num_heads = num_heads
    self.num_kv_heads = 1
    self.head_dim = head_dim
    self.scale = scale
    self.window_size = window_size
    self.head_bytes = head_bytes
    self.compress_ratio = compress_ratio
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.nope_head_dim = qk_nope_head_dim
    self.rope_head_dim = qk_rope_head_dim
    self.indexer = indexer
    self.topk_indices_buffer = topk_indices_buffer

    self.prefix = prefix  # Alias for compatibility with compressor

    self.aux_stream = aux_stream
    self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    # Determine padded head count for FlashMLA (next supported size); this
    # must agree with the outer wrapper's padded_heads.
    if num_heads not in self.SUPPORTED_HEAD_COUNTS:
        if num_heads < 64:
            self.padded_heads = 64
        elif num_heads < 128:
            self.padded_heads = 128
        else:
            raise ValueError(
                f"DeepseekV4MLAAttention does not support {num_heads} heads. "
                "Supported: <= 128 (will be padded to 64 or 128)"
            )
    else:
        self.padded_heads = num_heads

    # Store attention sink
    assert attn_sink is not None
    self.attn_sink: torch.Tensor = attn_sink
    # Store SWA cache
    assert swa_cache_layer is not None
    self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer

    # Get vllm config for cache setup
    vllm_config = get_current_vllm_config()
    self.max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens
    )
    self.max_model_len = vllm_config.model_config.max_model_len
    # DeepseekV4 only supports fp8 kv-cache format for now.
    kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"

    assert kv_cache_dtype.startswith("fp8"), (
        "DeepseekV4 only supports fp8 kv-cache format for now, "
        f"got {kv_cache_dtype}"
    )
    attn_backend = self.get_attn_backend()
    assert issubclass(attn_backend, FlashMLASparseBackend), (
        "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
    )
    # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
    # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
    if (
        issubclass(attn_backend, FlashMLASparseBackend)
        and kv_cache_dtype.startswith("fp8")
        and kv_cache_dtype != "fp8_ds_mla"
    ):
        # The cache_config=None default ("fp8") always reaches this branch.
        # Only write back when a shared config exists; asserting here made
        # the documented None default unusable.
        if cache_config is not None:
            cache_config.cache_dtype = "fp8_ds_mla"
        kv_cache_dtype = "fp8_ds_mla"
        logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")

    self.kv_cache_dtype = kv_cache_dtype

    # Register with compilation context for metadata lookup
    compilation_config = vllm_config.compilation_config
    if prefix and prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    if prefix:
        compilation_config.static_forward_context[prefix] = self

    # Empty placeholder; presumably rebound to the allocated cache tensor
    # once KV caches are bound (forward reads self.kv_cache for
    # compressed layers).
    self.kv_cache = torch.tensor([])
|
|
|
|
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class for the current platform.

    FlashMLA sparse everywhere except ROCm, which uses the AITER backend
    (imported lazily so non-ROCm builds never touch it).
    """
    if not current_platform.is_rocm():
        return DeepseekV4FlashMLASparseBackend

    from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
        DeepseekV4ROCMAiterMLASparseBackend,
    )

    return DeepseekV4ROCMAiterMLASparseBackend
|
|
|
|
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
    """Describe the compressed KV cache this layer needs, if any.

    Compressed layers (compress_ratio > 1) get a single-KV-head uint8 MLA
    cache spec. SWA-only layers return None because their cache is allocated
    separately as DeepseekV4SWACache.
    """
    if self.compress_ratio > 1:
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=torch.uint8,
            compress_ratio=self.compress_ratio,
            cache_dtype_str=self.kv_cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )
    # SWA part. Allocated separately as DeepseekV4SWACache.
    return None
|
|
|
|
def forward(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Sparse MLA attention over a batch laid out as [decodes | prefills].

    ``output`` must match ``q`` in shape and dtype and is filled in place. On
    ROCm the whole call is delegated to the AITER implementation. Otherwise
    the decode-token prefix and prefill-token suffix are handled separately.
    ``kv`` is only consumed on the ROCm path here.
    """
    assert output.shape == q.shape, (
        f"output buffer shape {output.shape} must match q shape {q.shape}"
    )
    assert output.dtype == q.dtype, (
        f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
    )

    if current_platform.is_rocm():
        from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
            DeepseekV4ROCMAiterMLASparseImpl,
        )

        DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
        return

    # Per-layer metadata is keyed by layer prefix in the forward context.
    metadata_by_layer = get_forward_context().attn_metadata
    assert isinstance(metadata_by_layer, dict)
    mla_meta = cast(
        "FlashMLASparseMetadata | None", metadata_by_layer.get(self.prefix)
    )
    swa_meta = cast(
        "DeepseekSparseSWAMetadata | None",
        metadata_by_layer.get(self.swa_cache_layer.prefix),
    )
    assert swa_meta is not None

    is_swa_only = self.compress_ratio <= 1
    # SWA-only layers have no cache of their own, so self.kv_cache may be
    # empty after profiling cleanup; hand the kernels None instead.
    compressed_cache = None if is_swa_only else self.kv_cache
    window_cache = self.swa_cache_layer.kv_cache

    n_decode_reqs, n_prefill_reqs, n_decode_tok = (
        swa_meta.num_decodes,
        swa_meta.num_prefills,
        swa_meta.num_decode_tokens,
    )

    if n_prefill_reqs > 0:
        self._forward_prefill(
            q=q[n_decode_tok:],
            positions=positions[n_decode_tok:],
            compressed_k_cache=compressed_cache,
            swa_k_cache=window_cache,
            output=output[n_decode_tok:],
            attn_metadata=mla_meta,
            swa_metadata=swa_meta,
        )
    if n_decode_reqs > 0:
        self._forward_decode(
            q=q[:n_decode_tok],
            kv_cache=compressed_cache,
            swa_metadata=swa_meta,
            attn_metadata=mla_meta,
            swa_only=is_swa_only,
            output=output[:n_decode_tok],
        )
|
|
|
|
def _forward_decode(
    self,
    q: torch.Tensor,
    kv_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
    swa_metadata: "DeepseekSparseSWAMetadata",
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
    output: torch.Tensor,
) -> None:
    """Run FlashMLA sparse decode for the leading decode tokens.

    Each decode query attends to its sliding-window entries in the SWA cache
    (``decode_swa_indices`` / ``decode_swa_lens``). For compressed layers it
    also attends to top-k entries of the compressed cache, passed as
    ``extra_k_cache``. Results go to ``output`` through the ``out=`` argument
    of ``flash_mla_with_kvcache``.

    Args:
        q: Decode queries, already padded to ``self.padded_heads``.
        kv_cache: Compressed KV cache; None for SWA-only layers.
        swa_metadata: SWA metadata (token counts, SWA indices, tile-scheduler
            metadata per layer type).
        attn_metadata: FlashMLA sparse metadata; required unless swa_only.
        swa_only: True for layers without a compressed cache.
        output: Destination; same shape/dtype as ``q`` (asserted in forward).

    Raises:
        ValueError: for a compress_ratio other than <= 1, 4 or 128.
    """
    num_decodes = swa_metadata.num_decodes
    num_decode_tokens = swa_metadata.num_decode_tokens

    # Extra (compressed-cache) attention targets; stay None for SWA-only.
    topk_indices = None
    topk_lens = None
    if not swa_only:
        assert attn_metadata is not None
        assert swa_metadata.is_valid_token is not None
        # Presumably the number of compressed entries per cache block
        # (block_size tokens / compress_ratio) — confirm.
        block_size = attn_metadata.block_size // self.compress_ratio
        is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
        if self.compress_ratio == 4:
            # C4A: local indices differ per layer (filled by Indexer).
            assert self.topk_indices_buffer is not None
            # Presumably maps the indexer's per-request top-k indices to
            # global cache slots via the block table, plus per-token
            # lengths (see compute_global_topk_indices_and_lens).
            global_indices, topk_lens = compute_global_topk_indices_and_lens(
                self.topk_indices_buffer[:num_decode_tokens],
                swa_metadata.token_to_req_indices,
                attn_metadata.block_table[:num_decodes],
                block_size,
                is_valid,
            )
            topk_indices = global_indices.view(num_decode_tokens, 1, -1)
        else:
            # C128A: pre-computed during metadata build.
            topk_indices = attn_metadata.c128a_global_decode_topk_indices
            topk_lens = attn_metadata.c128a_decode_topk_lens

    swa_indices = swa_metadata.decode_swa_indices
    swa_lens = swa_metadata.decode_swa_lens

    # We treat queries in the same seq as different queries
    # and later we only attend by generated indices.
    # q arrives pre-padded to self.padded_heads by the outer wrapper.
    q = q.unsqueeze(1)

    # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
    # Use unsqueeze to preserve strides (handles padded blocks correctly)
    swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
    # Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
    if kv_cache is not None:
        kv_cache = kv_cache.unsqueeze(-2)

    # One FlashMLASchedMeta per layer type, shared across all same-type
    # layers within this decode step. The first forward call per type
    # triggers the in-kernel planner (allocating tile_scheduler_metadata
    # and num_splits via PyTorch's graph-aware allocator so CUDA graph
    # capture reuses the same addresses on replay); subsequent same-type
    # layers see have_initialized=True and skip the planner.
    if self.compress_ratio <= 1:
        tile_metadata = swa_metadata.tile_sched_swaonly
    elif self.compress_ratio == 4:
        tile_metadata = swa_metadata.tile_sched_c4a
    elif self.compress_ratio == 128:
        tile_metadata = swa_metadata.tile_sched_c128a
    else:
        raise ValueError(
            f"Unsupported compress_ratio={self.compress_ratio}; "
            "expected 1, 4, or 128."
        )
    assert tile_metadata is not None, (
        "swa_metadata missing tile_sched entry for "
        f"compress_ratio={self.compress_ratio}; "
        "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
        "allocate one for this layer type."
    )

    # The returned tensor is discarded; the result is expected to land in
    # `output` through the out= view (presumably written in place — confirm
    # against flash_mla_with_kvcache).
    # NOTE(review): head_dim_v is hard-coded to 512; presumably equal to
    # self.head_dim (448 NoPE + 64 RoPE per the wrapper's head_bytes comments).
    out, _ = flash_mla_with_kvcache(
        q=q,
        k_cache=swa_cache,
        block_table=None,
        head_dim_v=512,
        tile_scheduler_metadata=tile_metadata,
        cache_seqlens=None,
        is_fp8_kvcache=True,
        indices=swa_indices,
        topk_length=swa_lens,
        softmax_scale=self.scale,
        attn_sink=self.attn_sink,
        extra_k_cache=kv_cache if not swa_only else None,
        extra_indices_in_kvcache=topk_indices,
        extra_topk_length=topk_lens,
        out=output.unsqueeze(1),
    )
|
|
|
|
def _forward_prefill(
    self,
    q: torch.Tensor,
    positions: torch.Tensor,
    compressed_k_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
    swa_k_cache: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata | None,
    swa_metadata: "DeepseekSparseSWAMetadata",
) -> None:
    """Run sparse attention for the prefill requests of the current batch.

    Prefill requests are processed in groups of ``PREFILL_CHUNK_SIZE``. For
    each group, the KV history of every request is dequantized into a bf16
    workspace of shape ``(PREFILL_CHUNK_SIZE, rows_per_req, head_dim)``:
    the compressed KV (if any) is written starting at row 0, and the
    sliding-window KV is written starting at row ``compressed_rows``.
    ``combine_topk_swa_indices`` then merges the top-k and SWA indices into
    workspace coordinates, and ``flash_mla_sparse_fwd`` writes the attention
    result into ``output``.

    ``attn_metadata is None`` marks an SWA-only layer: no compressed KV is
    gathered, ``compressed_rows`` is 0 and ``top_k`` is 0.

    ``q`` and ``output`` are sliced with prefill-local token offsets
    (``query_start_loc_cpu`` shifted by the first prefill token), so they are
    presumably the prefill portion of the batch — TODO confirm at the caller.
    ``positions`` is accepted but not read here.
    """
    swa_only = attn_metadata is None

    num_prefills = swa_metadata.num_prefills
    num_prefill_tokens = swa_metadata.num_prefill_tokens
    num_decodes = swa_metadata.num_decodes
    num_decode_tokens = swa_metadata.num_decode_tokens

    # Per-prefill-request lengths, precomputed by the metadata builder.
    seq_lens = swa_metadata.prefill_seq_lens
    gather_lens = swa_metadata.prefill_gather_lens
    assert seq_lens is not None
    assert gather_lens is not None

    # The CPU offsets drive the Python-side token slicing below; the other
    # copy (presumably on device) is handed to combine_topk_swa_indices.
    qsl_cpu = swa_metadata.query_start_loc_cpu
    qsl_dev = swa_metadata.query_start_loc
    assert qsl_cpu is not None
    assert qsl_dev is not None
    first_prefill_token = qsl_cpu[num_decodes]

    if swa_only:
        # NOTE(woosuk): topk_indices will not be used for SWA-only layers.
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[num_decode_tokens:]
        top_k = 0
        compressed_rows = 0
    else:
        if self.compress_ratio == 4:
            # C4A: the indexer wrote prefill top-k right after the decode rows.
            assert self.topk_indices_buffer is not None
            topk_indices = self.topk_indices_buffer[num_decode_tokens:][
                :num_prefill_tokens
            ]
        else:
            # C128A: top-k indices are precomputed during the metadata build.
            assert attn_metadata is not None
            topk_indices = attn_metadata.c128a_prefill_topk_indices
        top_k = topk_indices.shape[-1]
        # Size the compressed region for the whole compressed pool
        # (ceil(max_model_len / compress_ratio)), not just top_k: top_k only
        # bounds how many entries are selected, not the range they index.
        compressed_rows = (
            self.max_model_len + self.compress_ratio - 1
        ) // self.compress_ratio

    rows_per_req = compressed_rows + self.window_size + self.max_num_batched_tokens
    head_dim = q.shape[-1]

    kv_workspace = current_workspace_manager().get_simultaneous(
        ((PREFILL_CHUNK_SIZE, rows_per_req, head_dim), torch.bfloat16),
    )[0]

    for chunk_start in range(0, num_prefills, PREFILL_CHUNK_SIZE):
        chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
        n_reqs = chunk_end - chunk_start
        # Request indices of this chunk within the full (decode-first) batch.
        lo = num_decodes + chunk_start
        hi = num_decodes + chunk_end
        chunk_seq_lens = seq_lens[chunk_start:chunk_end]
        chunk_gather_lens = gather_lens[chunk_start:chunk_end]

        if not swa_only:
            assert attn_metadata is not None
            # Compressed KV, written starting at row 0 of each request.
            dequantize_and_gather_k_cache(
                kv_workspace[:n_reqs],
                compressed_k_cache,
                seq_lens=chunk_seq_lens // self.compress_ratio,
                gather_lens=None,
                block_table=attn_metadata.block_table[lo:hi],
                block_size=attn_metadata.block_size // self.compress_ratio,
                offset=0,
            )

        # Sliding-window KV, written after the compressed region.
        dequantize_and_gather_k_cache(
            kv_workspace[:n_reqs],
            swa_k_cache,
            seq_lens=chunk_seq_lens,
            gather_lens=chunk_gather_lens,
            block_table=swa_metadata.block_table[lo:hi],
            block_size=swa_metadata.block_size,
            offset=compressed_rows,
        )

        # Prefill-local token range covered by this chunk of requests.
        tok_lo = qsl_cpu[lo] - first_prefill_token
        tok_hi = qsl_cpu[hi] - first_prefill_token

        combined_indices, combined_lens = combine_topk_swa_indices(
            topk_indices[tok_lo:tok_hi],
            qsl_dev[lo : hi + 1],
            chunk_seq_lens,
            chunk_gather_lens,
            self.window_size,
            self.compress_ratio,
            top_k,
            rows_per_req,
            compressed_rows,
        )
        flash_mla_sparse_fwd(
            q=q[tok_lo:tok_hi],
            kv=kv_workspace.view(-1, 1, head_dim),
            indices=combined_indices.unsqueeze(1),
            sm_scale=self.scale,
            attn_sink=self.attn_sink,
            topk_length=combined_lens,
            out=output[tok_lo:tok_hi],
        )
|
|
|
|
|
|
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Attention-layer shim that owns the Lightning Indexer's K cache.

    On construction the instance registers itself in the current compilation
    config's ``static_forward_context`` under ``prefix`` (a duplicate prefix
    raises ``ValueError``). It describes its cache via ``get_kv_cache_spec``
    and reports ``DeepseekV4IndexerBackend`` as its backend. ``forward`` does
    nothing. ``kv_cache`` starts out as an empty tensor; presumably the real
    cache tensor is bound later by the engine — TODO confirm where.
    """

    def __init__(
        self,
        head_dim: int,
        dtype: torch.dtype,
        prefix: str,
        cache_config: CacheConfig,
        compress_ratio: int = 1,
    ):
        super().__init__()
        # Geometry and config consumed by get_kv_cache_spec().
        self.head_dim = head_dim
        self.dtype = dtype
        self.compress_ratio = compress_ratio
        self.cache_config = cache_config
        self.prefix = prefix
        # Empty placeholder until a real cache tensor is attached.
        self.kv_cache = torch.tensor([])

        compilation_config = get_current_vllm_config().compilation_config
        forward_ctx = compilation_config.static_forward_context
        if prefix in forward_ctx:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_ctx[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's single-head MLA-style cache.

        ``head_size`` is ``self.head_dim`` exactly as passed in; the creator
        is expected to have already added the fp8 scale padding to it.
        ``alignment=576`` (FlashMLA's 576-byte entry) is applied for every
        ``compress_ratio`` so indexer pages can pack with the indexer's
        compressor state cache. NOTE(review): older comments said V3.2
        (compress_ratio=1) keeps a legacy layout, but this method applies the
        576 alignment unconditionally. ``vllm_config`` is not read.
        """
        return MLAAttentionSpec(
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            block_size=self.cache_config.block_size,
            compress_ratio=self.compress_ratio,
            alignment=576,
        )

    def forward(self):
        """Intentionally a no-op; this module only carries cache state."""

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class that services this cache layer."""
        return DeepseekV4IndexerBackend
|
|
|
|
|
|
class DeepseekV4Indexer(nn.Module):
|
|
def __init__(
    self,
    vllm_config: VllmConfig,
    config: DeepseekV2Config | DeepseekV3Config,
    hidden_size: int,
    q_lora_rank: int,
    quant_config: QuantizationConfig | None,
    cache_config: CacheConfig | None,
    topk_indices_buffer: torch.Tensor | None,
    compress_ratio: int = 1,
    prefix: str = "",
):
    """Build the Lightning Indexer sub-modules for one attention layer.

    Args:
        vllm_config: Global config; supplies ``use_fp4_indexer_cache``,
            ``max_model_len`` and the prefill buffer size.
        config: HF model config providing ``index_topk``, ``index_n_heads``,
            ``index_head_dim`` and ``qk_rope_head_dim``.
        hidden_size: Input width of ``weights_proj`` and of the compressor.
        q_lora_rank: Input width of ``wq_b`` (the low-rank query).
        quant_config: Quantization config forwarded to both replicated linears.
        cache_config: Required (asserted non-None); used for the K cache.
        topk_indices_buffer: Buffer stored on ``self`` and handed to
            ``SparseAttnIndexer`` (presumably receives the top-k indices).
        compress_ratio: Divides ``max_model_len`` and the prefill buffer size
            to get the indexer's compressed-length limits.
        prefix: Module-name prefix for the sub-modules and the K cache.

    Raises:
        AssertionError: if ``cache_config`` is None.
    """
    super().__init__()
    self.vllm_config = vllm_config
    self.config = config
    self.quant_config = quant_config

    # Indexer geometry from the model config.
    self.topk_tokens = config.index_topk
    self.n_head = config.index_n_heads
    self.head_dim = config.index_head_dim
    self.rope_dim = config.qk_rope_head_dim
    self.q_lora_rank = q_lora_rank
    self.compress_ratio = compress_ratio

    self.use_fp4_kv = vllm_config.attention_config.use_fp4_indexer_cache
    logger.info_once(
        "Using %s indexer cache for Lightning Indexer.",
        "MXFP4" if self.use_fp4_kv else "FP8",
    )

    # Both projections are replicated across ranks (no tensor parallelism).
    # Construction order is kept stable so sub-module registration order
    # does not change.
    self.wq_b = ReplicatedLinear(
        q_lora_rank,
        self.head_dim * self.n_head,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.wq_b",
    )
    self.weights_proj = ReplicatedLinear(
        hidden_size,
        self.n_head,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.weights_proj",
    )
    self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
    self.softmax_scale = self.head_dim**-0.5

    self.scale_fmt = "ue8m0"
    self.quant_block_size = 128  # TODO: get from config
    self.topk_indices_buffer = topk_indices_buffer

    # Length limits for the indexer, expressed in compressed-token units.
    self.max_model_len = vllm_config.model_config.max_model_len // compress_ratio
    self.prefix = prefix
    prefill_buffer_size = get_max_prefill_buffer_size(vllm_config)
    self.max_total_seq_len = prefill_buffer_size // compress_ratio

    assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
    # NOTE(yifan): the FP8 indexer cache uses the same per-token layout as
    # V3.2: head_dim fp8 bytes followed by one 4-byte fp32 scale per quant
    # block (128 + 4 = 132 bytes for head_dim=128). The FP4 indexer cache
    # allocates the same amount of memory but only uses the first half.
    scale_bytes = (self.head_dim // self.quant_block_size) * 4
    self.k_cache = DeepseekV4IndexerCache(
        head_dim=self.head_dim + scale_bytes,
        dtype=torch.uint8,
        prefix=f"{prefix}.k_cache",
        cache_config=cache_config,
        compress_ratio=compress_ratio,
    )
    self.compressor = DeepseekCompressor(
        vllm_config=vllm_config,
        compress_ratio=compress_ratio,
        hidden_size=hidden_size,
        head_dim=self.head_dim,
        rotate=True,
        prefix=f"{prefix}.compressor",
        k_cache_prefix=self.k_cache.prefix,
        use_fp4_cache=self.use_fp4_kv,
    )

    # skip_k_cache_insert=True: the op is told not to insert into k_cache
    # itself (the compressor is given k_cache_prefix above).
    self.indexer_op = SparseAttnIndexer(
        self.k_cache,
        self.quant_block_size,
        self.scale_fmt,
        self.topk_tokens,
        self.head_dim,
        self.max_model_len,
        self.max_total_seq_len,
        self.topk_indices_buffer,
        skip_k_cache_insert=True,
        use_fp4_cache=self.use_fp4_kv,
    )
|
|
|
|
def forward(
    self,
    hidden_states: torch.Tensor,
    qr: torch.Tensor,
    compressed_kv_score: torch.Tensor,
    indexer_weights: torch.Tensor,
    positions: torch.Tensor,
    rotary_emb: nn.Module,
) -> torch.Tensor:
    """Run one Lightning Indexer step.

    1. Project the low-rank query ``qr`` with ``wq_b`` and split it into
       ``n_head`` heads of ``head_dim``.
    2. Produce K with the compressor from ``compressed_kv_score``,
       ``positions`` and ``rotary_emb``.
    3. Apply RoPE and quantization (FP4 or FP8 per ``use_fp4_kv``) to the
       query heads together with ``indexer_weights`` in one fused kernel.
    4. Hand everything to ``SparseAttnIndexer``.

    Returns:
        The result of ``self.indexer_op``; presumably the top-k selection —
        TODO confirm against ``SparseAttnIndexer``.
    """
    # wq_b was built with bias=False, so the second element is unused.
    projected, _unused_bias = self.wq_b(qr)
    q_heads = projected.view(-1, self.n_head, self.head_dim)
    k = self.compressor(compressed_kv_score, positions, rotary_emb)

    inv_sqrt_heads = self.n_head**-0.5
    q_quant, weights = fused_indexer_q_rope_quant(
        positions,
        q_heads,
        rotary_emb.cos_sin_cache,
        indexer_weights,
        self.softmax_scale,
        inv_sqrt_heads,
        use_fp4=self.use_fp4_kv,
    )
    return self.indexer_op(hidden_states, q_quant, k, weights)
|