Clean vLLM integration: use official paths, BF16 wo_a, proper weight mapper
- deepseek_v4.py: Fresh upstream copy with minimal NVFP4 changes - wo_a uses quant_config=None (BF16 in NVFP4 checkpoint, no scales) - Added _make_deepseek_v4_nvfp4_weights_mapper() using official WeightsMapper API - Handles: self_attn→attn, mlp→ffn, gate_proj→w1, compressor renames, etc. - Mapper selected by quant_config.get_name() == 'modelopt_fp4' - deepseek_v4_attention.py: Fresh upstream copy with minimal NVFP4 changes - Removed _wo_a_act_quant and custom CuTeDSL wo_a runner - Added _apply_inv_rope_bf16() helper (inverse RoPE in BF16) - Detects BF16 wo_a (no weight_scale_inv) and uses BF16 path - FP8 einsum path kept as fallback for SM90 checkpoints - BF16 path: inverse RoPE → wo_a() → wo_b() (standard linear methods)
This commit is contained in:
@@ -23,11 +23,14 @@ from vllm.model_executor.layers.deepseek_v4_attention import (
|
||||
DeepseekV4MLAModules,
|
||||
DeepseekV4MultiHeadLatentAttentionWrapper,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
fused_topk_bias,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import (
|
||||
NormGateLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -35,6 +38,12 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mhc import (
|
||||
HCHeadOp,
|
||||
MHCFusedPostPreOp,
|
||||
MHCPostOp,
|
||||
MHCPreOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
@@ -749,23 +758,23 @@ class DeepseekV4MoE(nn.Module):
|
||||
"deep_gemm_mega_moe for this checkpoint."
|
||||
)
|
||||
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
out_dtype=torch.float32,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.gate",
|
||||
# Fused RMSNorm + gate: owns both ffn_norm and the gate matmul.
|
||||
self.norm_gate = NormGateLinear(
|
||||
hidden_size=config.hidden_size,
|
||||
num_experts=config.n_routed_experts,
|
||||
rms_eps=config.rms_norm_eps,
|
||||
prefix=f"{prefix}.norm_gate",
|
||||
)
|
||||
self.gate.e_score_correction_bias = None
|
||||
self.gate.tid2eid = None
|
||||
# Routing-side tensors live on ``norm_gate`` directly (not on the
|
||||
# inner gate); they are initialized to None in NormGatedLinear and
|
||||
# populated below depending on the MoE variant.
|
||||
is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers
|
||||
self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32
|
||||
|
||||
if is_hash_moe:
|
||||
# hash MoE doesn't use e_score_correction_bias
|
||||
# Use randint instead of empty to avoid garbage values causing
|
||||
# invalid memory access in dummy mode (--load-format="dummy")
|
||||
self.gate.tid2eid = nn.Parameter(
|
||||
self.norm_gate.tid2eid = nn.Parameter(
|
||||
torch.randint(
|
||||
0,
|
||||
config.n_routed_experts,
|
||||
@@ -775,7 +784,7 @@ class DeepseekV4MoE(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
elif getattr(config, "topk_method", None) == "noaux_tc":
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
self.norm_gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -838,10 +847,9 @@ class DeepseekV4MoE(nn.Module):
|
||||
self.n_local_experts = config.n_routed_experts // self.tp_size
|
||||
self.experts_start_idx = self.tp_rank * self.n_local_experts
|
||||
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
||||
|
||||
# We don't pass `gate` into FusedMoE
|
||||
self.experts = FusedMoE(
|
||||
shared_experts=self.shared_experts,
|
||||
gate=self.gate,
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -851,8 +859,8 @@ class DeepseekV4MoE(nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
hash_indices_table=self.gate.tid2eid,
|
||||
e_score_correction_bias=self.norm_gate.e_score_correction_bias,
|
||||
hash_indices_table=self.norm_gate.tid2eid,
|
||||
swiglu_limit=self.swiglu_limit,
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
@@ -860,40 +868,40 @@ class DeepseekV4MoE(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
if self.gate.tid2eid is not None and input_ids is None:
|
||||
if self.norm_gate.tid2eid is not None and input_ids is None:
|
||||
raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")
|
||||
|
||||
if not self.use_mega_moe:
|
||||
return self._forward_fused_moe(hidden_states, input_ids)
|
||||
|
||||
org_shape = hidden_states.shape
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
normed_x, router_logits = self.norm_gate(hidden_states)
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states=normed_x,
|
||||
gating_output=router_logits,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias.data
|
||||
if self.gate.e_score_correction_bias is not None
|
||||
e_score_correction_bias=self.norm_gate.e_score_correction_bias.data
|
||||
if self.norm_gate.e_score_correction_bias is not None
|
||||
else None,
|
||||
topk=self.n_activated_experts,
|
||||
renormalize=self.renormalize,
|
||||
indices_type=self.hash_indices_dtype,
|
||||
input_tokens=input_ids,
|
||||
hash_indices_table=self.gate.tid2eid,
|
||||
hash_indices_table=self.norm_gate.tid2eid,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
activation_clamp = (
|
||||
float(self.swiglu_limit) if self.swiglu_limit is not None else None
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states,
|
||||
normed_x,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation_clamp=activation_clamp,
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = self.shared_experts(normed_x)
|
||||
final_hidden_states += shared_output
|
||||
|
||||
return final_hidden_states.view(org_shape)
|
||||
@@ -901,21 +909,14 @@ class DeepseekV4MoE(nn.Module):
|
||||
def _forward_fused_moe(
|
||||
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
assert not self.experts.is_internal_router
|
||||
org_shape = hidden_states.shape
|
||||
if self.experts.is_internal_router:
|
||||
# In this case, the gate/router runs inside the FusedMoE class
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=hidden_states,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
else:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
normed_x, router_logits = self.norm_gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=normed_x,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
return final_hidden_states.view(org_shape)
|
||||
|
||||
@@ -989,10 +990,8 @@ class DeepseekV4Attention(nn.Module):
|
||||
)
|
||||
|
||||
self.kv_norm = RMSNorm(self.head_dim, self.eps)
|
||||
# wo_a is NOT quantized in the NVFP4 checkpoint (modelopt left it as bfloat16),
|
||||
# but the attention forward pass expects FP8 (weight + weight_scale_inv).
|
||||
# Pass quant_config=None to load bfloat16, then process_weights_after_loading
|
||||
# will handle the FP8 quantization.
|
||||
# wo_a is BF16 in the NVFP4 checkpoint (no quantization scales).
|
||||
# Pass quant_config=None so it loads as a plain BF16 linear layer.
|
||||
self.wo_a = ColumnParallelLinear(
|
||||
self.n_heads * self.head_dim // self.n_groups,
|
||||
self.n_groups * self.o_lora_rank,
|
||||
@@ -1123,7 +1122,8 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
|
||||
|
||||
self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
||||
# ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the
|
||||
# router gate matmul); see ``NormGatedLinear``.
|
||||
self.hc_mult = config.hc_mult
|
||||
self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
|
||||
self.hc_eps = config.hc_eps
|
||||
@@ -1172,6 +1172,9 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.mhc_pre = MHCPreOp()
|
||||
self.mhc_post = MHCPostOp()
|
||||
self.mhc_fused_post_pre = MHCFusedPostPreOp()
|
||||
|
||||
def hc_pre(
|
||||
self,
|
||||
@@ -1180,7 +1183,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
):
|
||||
post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre(
|
||||
post_mix, res_mix, layer_input = self.mhc_pre(
|
||||
residual=x,
|
||||
fn=hc_fn,
|
||||
hc_scale=hc_scale,
|
||||
@@ -1200,17 +1203,17 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
post: torch.Tensor,
|
||||
comb: torch.Tensor,
|
||||
):
|
||||
return torch.ops.vllm.mhc_post(x, residual, post, comb)
|
||||
return self.mhc_post(x, residual, post, comb)
|
||||
|
||||
def forward(
|
||||
def _forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None,
|
||||
res_mix: torch.Tensor | None,
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
# Run standalone hc_pre on first layer
|
||||
residual = x
|
||||
@@ -1218,7 +1221,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
)
|
||||
else:
|
||||
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
||||
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
|
||||
x,
|
||||
residual,
|
||||
post_mix,
|
||||
@@ -1236,7 +1239,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x = self.attn_norm(x)
|
||||
x = self.attn(positions, x, None)
|
||||
|
||||
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
||||
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
|
||||
x,
|
||||
residual,
|
||||
post_mix,
|
||||
@@ -1250,11 +1253,58 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.hc_post_alpha,
|
||||
self.hc_sinkhorn_iters,
|
||||
)
|
||||
|
||||
x = self.ffn_norm(x)
|
||||
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
|
||||
# the pre-norm activation directly.
|
||||
x = self.ffn(x, input_ids)
|
||||
return x, residual, post_mix, res_mix
|
||||
|
||||
def _forward_rocm(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
|
||||
]:
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
)
|
||||
x = self.attn_norm(x)
|
||||
x = self.attn(positions, x, None)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(
|
||||
x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
|
||||
)
|
||||
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
|
||||
# the pre-norm activation directly.
|
||||
x = self.ffn(x, input_ids)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
return x, None, None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
|
||||
]:
|
||||
if current_platform.is_rocm():
|
||||
return self._forward_rocm(
|
||||
x, positions, input_ids, post_mix, res_mix, residual
|
||||
)
|
||||
|
||||
return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV4Model(nn.Module):
|
||||
@@ -1262,17 +1312,6 @@ class DeepseekV4Model(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# Select weight mapper based on quantization method.
|
||||
# NVFP4 (modelopt_fp4) checkpoints use different key naming
|
||||
# than the default MXFP4 format.
|
||||
quant_config = vllm_config.quant_config
|
||||
if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
|
||||
elif getattr(config, "expert_dtype", "fp4") != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8")
|
||||
else:
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.use_mega_moe = (
|
||||
@@ -1355,7 +1394,7 @@ class DeepseekV4Model(nn.Module):
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.hc_head_op = HCHeadOp()
|
||||
# Pre-hc_head residual stream buffer for the MTP draft. Stable
|
||||
# address (outside the cudagraph pool) so the copy_ in forward()
|
||||
# refreshes it correctly across captured shapes.
|
||||
@@ -1425,7 +1464,7 @@ class DeepseekV4Model(nn.Module):
|
||||
res_mix,
|
||||
residual,
|
||||
)
|
||||
else:
|
||||
if layer is not None and current_platform.is_cuda():
|
||||
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
@@ -1435,7 +1474,7 @@ class DeepseekV4Model(nn.Module):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))
|
||||
|
||||
hidden_states = hc_head(
|
||||
hidden_states = self.hc_head_op(
|
||||
hidden_states,
|
||||
self.hc_head_fn,
|
||||
self.hc_head_scale,
|
||||
@@ -1455,9 +1494,6 @@ class DeepseekV4Model(nn.Module):
|
||||
("attn.fused_wqa_wkv", "attn.wkv", 1),
|
||||
("compressor.fused_wkv_wgate", "compressor.wkv", 0),
|
||||
("compressor.fused_wkv_wgate", "compressor.wgate", 1),
|
||||
# Indexer's compressor (same stacking pattern)
|
||||
("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wkv", 0),
|
||||
("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wgate", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
@@ -1473,14 +1509,6 @@ class DeepseekV4Model(nn.Module):
|
||||
# Pre-compute expert mapping ONCE.
|
||||
expert_mapping = self.get_expert_mapping()
|
||||
|
||||
# NVFP4 compressor/indexer scale params need special handling:
|
||||
# wkv.input_scale (shape [1]) + wgate.input_scale (shape [1])
|
||||
# must be concatenated into fused_wkv_wgate.input_scale (shape [2]).
|
||||
# The default stacking path fails because PerTensorScaleParameter's
|
||||
# weight_loader asserts shape equality.
|
||||
# We buffer them and load once both shards are available.
|
||||
compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {}
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
@@ -1492,44 +1520,8 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
break
|
||||
if name not in params_dict:
|
||||
# The stacked param doesn't exist — skip
|
||||
# (e.g. indexer.compressor.fused_wkv_wgate on layers
|
||||
# that don't have the full indexer structure)
|
||||
break
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
|
||||
# NVFP4 scale params for stacked fused_wkv_wgate need
|
||||
# special handling: each shard (wkv, wgate) has scale
|
||||
# shape [1] or [head_dim, K], but the fused param has
|
||||
# shape [2] or [2*head_dim, K]. The default stacking
|
||||
# weight_loader can't handle this for PerTensorScale or
|
||||
# ModelWeight scale params. Buffer and concatenate.
|
||||
is_compressor_scale = (
|
||||
"fused_wkv_wgate" in name
|
||||
and name.endswith((
|
||||
"input_scale",
|
||||
"weight_scale",
|
||||
"weight_scale_2",
|
||||
))
|
||||
)
|
||||
if is_compressor_scale:
|
||||
# Verify the fused param exists before buffering
|
||||
if name not in params_dict:
|
||||
print(
|
||||
f"COMPRESSOR_SCALE_SKIP: {name} not in params_dict",
|
||||
flush=True,
|
||||
)
|
||||
break
|
||||
if is_compressor_scale:
|
||||
# Buffer the shard for later concatenation
|
||||
if name not in compressor_scale_buffer:
|
||||
compressor_scale_buffer[name] = {}
|
||||
compressor_scale_buffer[name][shard_id] = loaded_weight
|
||||
loaded_params.add(name)
|
||||
break
|
||||
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_params.add(name)
|
||||
break
|
||||
@@ -1582,51 +1574,13 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name not in params_dict:
|
||||
print(f"Skipping weight {name} (not in model params)",
|
||||
flush=True)
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
try:
|
||||
weight_loader(param, loaded_weight)
|
||||
except (AssertionError, RuntimeError) as e:
|
||||
print(
|
||||
f"WEIGHT_LOAD_FAIL: name={name} "
|
||||
f"param_shape={param.data.shape if hasattr(param, 'data') else '?'} "
|
||||
f"loaded_shape={loaded_weight.shape} "
|
||||
f"loaded_dtype={loaded_weight.dtype} "
|
||||
f"error={e}",
|
||||
flush=True,
|
||||
)
|
||||
raise
|
||||
|
||||
# Load buffered compressor/indexer scale params.
|
||||
# These are NVFP4 quantization scales that need concatenation
|
||||
# across shards (wkv=shard0, wgate=shard1) before loading.
|
||||
for name, shards in compressor_scale_buffer.items():
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if len(shards) == 2:
|
||||
# Concatenate shard 0 and shard 1 along dim 0.
|
||||
# Scales may be 0-dim scalars (input_scale, weight_scale_2)
|
||||
# or N-dim tensors (weight_scale); reshape scalars to 1-d.
|
||||
s0, s1 = shards[0], shards[1]
|
||||
if s0.ndim == 0:
|
||||
s0 = s0.reshape(1)
|
||||
if s1.ndim == 0:
|
||||
s1 = s1.reshape(1)
|
||||
stacked = torch.cat([s0, s1], dim=0)
|
||||
else:
|
||||
stacked = shards[0]
|
||||
assert param.data.shape == stacked.shape, (
|
||||
f"Scale shape mismatch for {name}: "
|
||||
f"param={param.data.shape} loaded={stacked.shape}"
|
||||
)
|
||||
param.data.copy_(stacked)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
|
||||
return loaded_params
|
||||
|
||||
@@ -1647,89 +1601,6 @@ class DeepseekV4Model(nn.Module):
|
||||
def finalize_mega_moe_weights(self) -> None:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
layer.ffn.finalize_mega_moe_weights()
|
||||
# Initialize wo_a NVFP4 runner instead of quantizing to FP8
|
||||
attn = layer.attn
|
||||
if hasattr(attn, 'wo_a') and attn.wo_a.weight.dtype == torch.bfloat16:
|
||||
self._init_wo_a_nvfp4(attn)
|
||||
|
||||
@staticmethod
|
||||
def _init_wo_a_nvfp4(attn) -> None:
|
||||
"""Initialize CuTeDSL NVFP4 runner for wo_a.
|
||||
|
||||
Replaces the old _quantize_wo_a_to_fp8 approach. Instead of
|
||||
quantizing to FP8 and using DeepGEMM fp8_einsum (which crashes
|
||||
on Blackwell), we quantize to NVFP4 and use our CuTeDSL kernel.
|
||||
|
||||
wo_a is a grouped matmul (bmm) with n_local_groups groups.
|
||||
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank)
|
||||
"""
|
||||
from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA
|
||||
|
||||
wo_a = attn.wo_a
|
||||
weight_bf16 = wo_a.weight.data # (out_features, in_features) = (n_groups * o_lora_rank, heads_per_group * head_dim)
|
||||
|
||||
n_local_groups = attn.n_local_groups
|
||||
heads_per_group = attn.n_local_heads // n_local_groups
|
||||
head_dim = attn.head_dim
|
||||
o_lora_rank = attn.o_lora_rank
|
||||
|
||||
runner = CuTeDSLNvfp4WoA(
|
||||
n_local_groups=n_local_groups,
|
||||
heads_per_group=heads_per_group,
|
||||
head_dim=head_dim,
|
||||
o_lora_rank=o_lora_rank,
|
||||
max_num_tokens=8192,
|
||||
device=weight_bf16.device,
|
||||
)
|
||||
|
||||
# The weight is (n_groups * o_lora_rank, heads_per_group * head_dim)
|
||||
# set_bf16_weight handles the 2D (dense) format
|
||||
runner.set_bf16_weight(weight_bf16)
|
||||
runner.finalize_weights()
|
||||
|
||||
# Warmup: compute activation global scale from sample data
|
||||
# This uses a representative random sample; the scale will be
|
||||
# recomputed on the first real forward pass with actual data.
|
||||
with torch.no_grad():
|
||||
sample = torch.randn(
|
||||
8, n_local_groups * heads_per_group, head_dim,
|
||||
dtype=torch.bfloat16, device=weight_bf16.device,
|
||||
) * 2.0
|
||||
runner._ensure_initialized()
|
||||
runner.compute_activation_global_scale(sample)
|
||||
|
||||
# Store the runner on the attention module
|
||||
attn._wo_a_nvfp4 = runner
|
||||
|
||||
|
||||
@torch.compile(backend=current_platform.simple_compile_backend)
|
||||
def hc_head(
|
||||
hidden_states: torch.Tensor,
|
||||
hc_fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_norm_eps: float,
|
||||
hc_eps: float,
|
||||
) -> torch.Tensor:
|
||||
hc_mult, hidden_size = hidden_states.shape[-2:]
|
||||
outer_shape = hidden_states.shape[:-2]
|
||||
hs_flat = hidden_states.view(-1, hc_mult, hidden_size)
|
||||
num_tokens = hs_flat.shape[0]
|
||||
out = torch.empty(
|
||||
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
|
||||
)
|
||||
torch.ops.vllm.hc_head_fused_kernel(
|
||||
hs_flat,
|
||||
hc_fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
out,
|
||||
hidden_size,
|
||||
rms_norm_eps,
|
||||
hc_eps,
|
||||
hc_mult,
|
||||
)
|
||||
return out.view(*outer_shape, hidden_size)
|
||||
|
||||
|
||||
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
@@ -1761,7 +1632,13 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
orig_to_new_suffix={
|
||||
"head.weight": "lm_head.weight",
|
||||
"embed.weight": "embed_tokens.weight",
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
# Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate``
|
||||
# (see NormGatedLinear).
|
||||
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
|
||||
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
|
||||
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
|
||||
# Hash MoE table also moved off the inner gate.
|
||||
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
@@ -1770,18 +1647,16 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
"""Weight mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints.
|
||||
|
||||
NVFP4 checkpoints use different key naming than the upstream MXFP4 format:
|
||||
- ``self_attn`` prefix instead of ``attn``
|
||||
- ``mlp`` prefix instead of ``ffn``
|
||||
- Expert weights: gate_proj/up_proj/down_proj (not w1/w3/w2)
|
||||
- Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes
|
||||
- Shared expert uses down_proj (not w2)
|
||||
- Self-attention uses .self_attn. prefix (same as checkpoint, renamed to .attn.)
|
||||
|
||||
This is the mapper that should be used when quantization is modelopt_fp4.
|
||||
- Compressor uses kv_proj/gate_proj (not wkv/wgate)
|
||||
- o_a_proj is BF16 (no quantization scales)
|
||||
"""
|
||||
expert_rename_regex = {
|
||||
re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.",
|
||||
@@ -1789,82 +1664,83 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.",
|
||||
}
|
||||
|
||||
suffix_renames = {
|
||||
# The NVFP4 checkpoint already uses lm_head / embed_tokens directly,
|
||||
# no suffix renames needed (unlike the MXFP4 upstream format).
|
||||
}
|
||||
|
||||
# NOTE: specific renames MUST come before general ones (applied in order)
|
||||
substr_renames = {
|
||||
# === Indexer params (MUST come before .self_attn.compressor.
|
||||
# so that indexer keys are captured before the compressor prefix
|
||||
# rewrite moves them under mla_attn.compressor) ===
|
||||
# The checkpoint puts indexer under self_attn.compressor.indexer.*
|
||||
# but the model has indexer at attn.indexer.* (sibling of compressor,
|
||||
# NOT nested under it).
|
||||
".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
|
||||
".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
|
||||
".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
|
||||
".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
|
||||
".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
|
||||
".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
|
||||
# === Compressor (non-indexer) NVFP4 renames ===
|
||||
# Checkpoint uses kv_proj/gate_proj, model uses wkv/wgate
|
||||
# (for stacking into fused_wkv_wgate).
|
||||
"compressor.kv_proj.": "compressor.wkv.",
|
||||
"compressor.gate_proj.": "compressor.wgate.",
|
||||
"compressor.kv_norm.": "compressor.norm.",
|
||||
"compressor.position_bias": "compressor.ape",
|
||||
# === Attention compressor (MUST come after indexer renames
|
||||
# so that remaining .self_attn.compressor. (non-indexer) keys
|
||||
# become .attn.mla_attn.compressor.) ===
|
||||
".self_attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
# === Attention projections (specific before .self_attn. → .attn.) ===
|
||||
".self_attn.q_a_proj.": ".attn.wq_a.",
|
||||
".self_attn.kv_proj.": ".attn.wkv.",
|
||||
".self_attn.q_b_proj.": ".attn.wq_b.",
|
||||
".self_attn.o_a_proj.": ".attn.wo_a.",
|
||||
".self_attn.o_b_proj.": ".attn.wo_b.",
|
||||
".self_attn.q_a_norm.": ".attn.q_norm.",
|
||||
".self_attn.kv_norm.": ".attn.kv_norm.",
|
||||
".self_attn.sinks": ".attn.attn_sink",
|
||||
# Shared expert projections (specific before .mlp. → .ffn.)
|
||||
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
|
||||
".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
|
||||
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
|
||||
# General renames
|
||||
".mlp.": ".ffn.",
|
||||
".self_attn.": ".attn.",
|
||||
# Layer norms (checkpoint uses input_layernorm / post_attention_layernorm,
|
||||
# model uses attn_norm / ffn_norm)
|
||||
"input_layernorm.": "attn_norm.",
|
||||
"post_attention_layernorm.": "ffn_norm.",
|
||||
# Per-layer HC params (checkpoint uses attn_hc / ffn_hc with dot,
|
||||
# model uses hc_attn / hc_ffn with underscore)
|
||||
".attn_hc.fn": ".hc_attn_fn",
|
||||
".attn_hc.base": ".hc_attn_base",
|
||||
".attn_hc.scale": ".hc_attn_scale",
|
||||
".ffn_hc.fn": ".hc_ffn_fn",
|
||||
".ffn_hc.base": ".hc_ffn_base",
|
||||
".ffn_hc.scale": ".hc_ffn_scale",
|
||||
# Top-level hc_head params (checkpoint uses hc_head.fn etc,
|
||||
# model uses hc_head_fn etc)
|
||||
"hc_head.fn": "hc_head_fn",
|
||||
"hc_head.base": "hc_head_base",
|
||||
"hc_head.scale": "hc_head_scale",
|
||||
}
|
||||
|
||||
return WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"layers.": "model.layers.",
|
||||
"embed_tokens.": "model.embed_tokens.",
|
||||
"embed.": "model.embed.",
|
||||
"norm.": "model.norm.",
|
||||
"hc_head": "model.hc_head",
|
||||
"mtp.": "model.mtp.",
|
||||
},
|
||||
orig_to_new_regex=expert_rename_regex,
|
||||
orig_to_new_suffix=suffix_renames,
|
||||
orig_to_new_substr=substr_renames,
|
||||
# No suffix renames needed — NVFP4 checkpoint uses
|
||||
# .weight_scale / .weight_scale_2 / .input_scale directly.
|
||||
orig_to_new_suffix={
|
||||
"head.weight": "lm_head.weight",
|
||||
"embed.weight": "embed_tokens.weight",
|
||||
# Pre-MoE norm + gate are now owned by DeepseekV4MoE.norm_gate
|
||||
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
|
||||
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
|
||||
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
|
||||
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
|
||||
},
|
||||
# Specific renames MUST come before general ones (applied in order).
|
||||
orig_to_new_substr={
|
||||
# Indexer params (MUST come before .self_attn.compressor.
|
||||
# so indexer keys are captured before the compressor prefix
|
||||
# rewrite moves them under mla_attn.compressor).
|
||||
".self_attn.compressor.indexer.q_b_proj.":
|
||||
".attn.indexer.wq_b.",
|
||||
".self_attn.compressor.indexer.weights_proj.":
|
||||
".attn.indexer.weights_proj.",
|
||||
".self_attn.compressor.indexer.kv_norm.":
|
||||
".attn.indexer.k_norm.",
|
||||
".self_attn.compressor.indexer.kv_proj.":
|
||||
".attn.indexer.compressor.wkv.",
|
||||
".self_attn.compressor.indexer.gate_proj.":
|
||||
".attn.indexer.compressor.wgate.",
|
||||
".self_attn.compressor.indexer.position_bias":
|
||||
".attn.indexer.compressor.ape",
|
||||
# Compressor (non-indexer) renames
|
||||
"compressor.kv_proj.": "compressor.wkv.",
|
||||
"compressor.gate_proj.": "compressor.wgate.",
|
||||
"compressor.kv_norm.": "compressor.norm.",
|
||||
"compressor.position_bias": "compressor.ape",
|
||||
# Attention compressor (after indexer renames)
|
||||
".self_attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
# Attention projections (specific before .self_attn. → .attn.)
|
||||
".self_attn.q_a_proj.": ".attn.wq_a.",
|
||||
".self_attn.kv_proj.": ".attn.wkv.",
|
||||
".self_attn.q_b_proj.": ".attn.wq_b.",
|
||||
".self_attn.o_a_proj.": ".attn.wo_a.",
|
||||
".self_attn.o_b_proj.": ".attn.wo_b.",
|
||||
".self_attn.q_a_norm.": ".attn.q_norm.",
|
||||
".self_attn.kv_norm.": ".attn.kv_norm.",
|
||||
".self_attn.sinks": ".attn.attn_sink",
|
||||
# Shared expert projections (specific before .mlp. → .ffn.)
|
||||
".mlp.shared_experts.gate_proj.":
|
||||
".ffn.shared_experts.w1.",
|
||||
".mlp.shared_experts.up_proj.":
|
||||
".ffn.shared_experts.w3.",
|
||||
".mlp.shared_experts.down_proj.":
|
||||
".ffn.shared_experts.down_proj.",
|
||||
# General renames
|
||||
".mlp.": ".ffn.",
|
||||
".self_attn.": ".attn.",
|
||||
# Layer norms
|
||||
"input_layernorm.": "attn_norm.",
|
||||
"post_attention_layernorm.": "ffn_norm.",
|
||||
# HC params
|
||||
".attn_hc.fn": ".hc_attn_fn",
|
||||
".attn_hc.base": ".hc_attn_base",
|
||||
".attn_hc.scale": ".hc_attn_scale",
|
||||
".ffn_hc.fn": ".hc_ffn_fn",
|
||||
".ffn_hc.base": ".hc_ffn_base",
|
||||
".ffn_hc.scale": ".hc_ffn_scale",
|
||||
"hc_head.fn": "hc_head_fn",
|
||||
"hc_head.base": "hc_head_base",
|
||||
"hc_head.scale": "hc_head_scale",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1879,21 +1755,18 @@ class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
expert_dtype = getattr(config, "expert_dtype", "fp4")
|
||||
quant_config = vllm_config.quant_config
|
||||
# Select weight mapper based on quantization method.
|
||||
# NVFP4 (modelopt_fp4) checkpoints use different key naming
|
||||
# than the default MXFP4 format.
|
||||
quant_config = vllm_config.quant_config
|
||||
if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4":
|
||||
if (quant_config is not None
|
||||
and quant_config.get_name() == "modelopt_fp4"):
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
|
||||
elif getattr(config, "expert_dtype", "fp4") != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8")
|
||||
else:
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
|
||||
self.config = config
|
||||
expert_dtype = getattr(config, "expert_dtype", "fp4")
|
||||
if expert_dtype != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)
|
||||
elif expert_dtype != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(
|
||||
expert_dtype)
|
||||
|
||||
self.model = self.model_cls(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
|
||||
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
)
|
||||
@@ -46,14 +47,9 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.multi_stream_utils import (
|
||||
execute_in_parallel,
|
||||
@@ -186,19 +182,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
self.kv_norm = mla_modules.kv_norm
|
||||
self.wo_a = mla_modules.wo_a
|
||||
# NVFP4 runner for wo_a — replaces DeepGEMM fp8_einsum.
|
||||
# Initialized in DeepseekV4Model.finalize_mega_moe_weights()
|
||||
# after wo_a BF16 weights are loaded.
|
||||
self._wo_a_nvfp4 = None
|
||||
|
||||
self._wo_a_act_quant = QuantFP8(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, 128),
|
||||
use_ue8m0=True,
|
||||
)
|
||||
# Bypass packed-for-deepgemm path — we need FP32 scales (not packed
|
||||
# INT32) so fp8_einsum can handle layout transform internally.
|
||||
self._wo_a_act_quant.use_deep_gemm_supported = False
|
||||
self.wo_b = mla_modules.wo_b
|
||||
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
@@ -321,25 +304,29 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
# O projection: inverse RoPE + NVFP4 grouped GEMM + wo_b
|
||||
# Using our CuTeDSL NVFP4 kernel instead of DeepGEMM fp8_einsum
|
||||
if self._wo_a_nvfp4 is not None:
|
||||
from cutedsl.inverse_rope import inverse_rope_bf16
|
||||
o_inv = inverse_rope_bf16(
|
||||
# Detect if wo_a has FP8 weights (weight_scale_inv attribute).
|
||||
# NVFP4 checkpoints leave wo_a as BF16 (no quantization scales),
|
||||
# so we use inverse RoPE in BF16 + regular matmul instead of
|
||||
# the FP8 einsum path (which crashes on Blackwell SM100).
|
||||
has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv')
|
||||
|
||||
if not has_fp8_weights:
|
||||
# BF16 wo_a path: inverse RoPE in BF16, then regular matmul
|
||||
o_inv = _apply_inv_rope_bf16(
|
||||
o, positions,
|
||||
self.rotary_emb.cos_sin_cache.to(torch.float32),
|
||||
nope_dim=self.nope_head_dim,
|
||||
rope_dim=self.rope_head_dim,
|
||||
)
|
||||
# Activation global scale is computed during init (finalize_mega_moe_weights)
|
||||
z = self._wo_a_nvfp4(o_inv)
|
||||
z, _ = self.wo_a(o_inv.reshape(num_tokens, -1))
|
||||
z = z.view(num_tokens, self.n_local_groups, self.o_lora_rank)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
# Fallback: original DeepGEMM path (for non-Blackwell or before init)
|
||||
# FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum (SM90 only)
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
positions,
|
||||
self.rotary_emb.cos_sin_cache.to(torch.float32),
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
n_groups=self.n_local_groups,
|
||||
heads_per_group=self.n_local_heads // self.n_local_groups,
|
||||
nope_dim=self.nope_head_dim,
|
||||
@@ -384,14 +371,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
compressor = self.compressor
|
||||
|
||||
def compressor_kv_score() -> torch.Tensor:
|
||||
# Use forward() for quantized layers (NVFP4, FP8, etc.)
|
||||
# — raw torch.mm doesn't work with packed/dequantized weights.
|
||||
# MergedColumnParallelLinear with return_bias=False returns
|
||||
# a tensor directly.
|
||||
result = compressor.fused_wkv_wgate(hidden_states)
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
compressor.fused_wkv_wgate.weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
|
||||
aux_fns[0] = compressor_kv_score
|
||||
|
||||
@@ -404,10 +388,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
return weights
|
||||
|
||||
def indexer_compressor_kv_score() -> torch.Tensor:
|
||||
result = indexer.compressor.fused_wkv_wgate(hidden_states)
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
return result.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
indexer.compressor.fused_wkv_wgate.weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
|
||||
aux_fns[1] = indexer_weights_proj
|
||||
aux_fns[2] = indexer_compressor_kv_score
|
||||
@@ -567,12 +552,48 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
swa_kv_cache_2d,
|
||||
swa_metadata.slot_mapping,
|
||||
positions.to(torch.int64),
|
||||
self.rotary_emb.cos_sin_cache.to(torch.float32),
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
self.eps,
|
||||
swa_metadata.block_size,
|
||||
)
|
||||
|
||||
|
||||
def _apply_inv_rope_bf16(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
) -> torch.Tensor:
|
||||
"""Apply inverse RoPE to attention output in BF16.
|
||||
|
||||
Inverse RoPE is just RoPE with cos → cos, sin → -sin.
|
||||
Uses GPT-J style (interleaved) rotary embedding.
|
||||
"""
|
||||
if rope_dim == 0 or o.numel() == 0:
|
||||
return o
|
||||
half_rot = rope_dim // 2
|
||||
o_f32 = o.to(torch.float32)
|
||||
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
|
||||
cos = cache[:, :half_rot].to(torch.float32)
|
||||
sin = cache[:, half_rot : 2 * half_rot].to(torch.float32)
|
||||
view_shape = (positions.shape[0], 1, half_rot)
|
||||
cos = cos.view(view_shape)
|
||||
sin = sin.view(view_shape)
|
||||
rope = o_f32[..., nope_dim:]
|
||||
y_even = rope[..., 0::2]
|
||||
y_odd = rope[..., 1::2]
|
||||
# Inverse: sin → -sin
|
||||
rope_out = torch.stack(
|
||||
(y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
|
||||
dim=-1,
|
||||
).flatten(-2)
|
||||
o_f32 = o_f32.clone()
|
||||
o_f32[..., nope_dim:] = rope_out
|
||||
return o_f32.to(o.dtype)
|
||||
|
||||
|
||||
@eager_break_during_capture
|
||||
def deepseek_v4_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -1126,10 +1147,9 @@ class DeepseekV4Indexer(nn.Module):
|
||||
hidden_size,
|
||||
self.n_head,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
self.scale_fmt = "ue8m0"
|
||||
|
||||
Reference in New Issue
Block a user