Use nightly's deepseek_v4.py + attention as base, add only NVFP4 mapper
The upstream deepseek_v4.py has imports that don't exist in the nightly Docker image (norm_gate_linear, breakable_cudagraph, etc.). Use the nightly's own files as the base and add only the minimal NVFP4 changes: - Add _make_deepseek_v4_nvfp4_weights_mapper() for checkpoint key mapping - Select NVFP4 mapper when quant_config is modelopt_fp4 - cos_sin_cache float32 fix in attention - Remove utils.py patch (not needed)
This commit is contained in:
@@ -23,14 +23,11 @@ from vllm.model_executor.layers.deepseek_v4_attention import (
|
||||
DeepseekV4MLAModules,
|
||||
DeepseekV4MultiHeadLatentAttentionWrapper,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
fused_topk_bias,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import (
|
||||
NormGateLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -38,12 +35,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mhc import (
|
||||
HCHeadOp,
|
||||
MHCFusedPostPreOp,
|
||||
MHCPostOp,
|
||||
MHCPreOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
@@ -758,23 +749,23 @@ class DeepseekV4MoE(nn.Module):
|
||||
"deep_gemm_mega_moe for this checkpoint."
|
||||
)
|
||||
|
||||
# Fused RMSNorm + gate: owns both ffn_norm and the gate matmul.
|
||||
self.norm_gate = NormGateLinear(
|
||||
hidden_size=config.hidden_size,
|
||||
num_experts=config.n_routed_experts,
|
||||
rms_eps=config.rms_norm_eps,
|
||||
prefix=f"{prefix}.norm_gate",
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
out_dtype=torch.float32,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
# Routing-side tensors live on ``norm_gate`` directly (not on the
|
||||
# inner gate); they are initialized to None in NormGatedLinear and
|
||||
# populated below depending on the MoE variant.
|
||||
self.gate.e_score_correction_bias = None
|
||||
self.gate.tid2eid = None
|
||||
is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers
|
||||
self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32
|
||||
|
||||
if is_hash_moe:
|
||||
# hash MoE doesn't use e_score_correction_bias
|
||||
# Use randint instead of empty to avoid garbage values causing
|
||||
# invalid memory access in dummy mode (--load-format="dummy")
|
||||
self.norm_gate.tid2eid = nn.Parameter(
|
||||
self.gate.tid2eid = nn.Parameter(
|
||||
torch.randint(
|
||||
0,
|
||||
config.n_routed_experts,
|
||||
@@ -784,7 +775,7 @@ class DeepseekV4MoE(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
elif getattr(config, "topk_method", None) == "noaux_tc":
|
||||
self.norm_gate.e_score_correction_bias = nn.Parameter(
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -847,9 +838,10 @@ class DeepseekV4MoE(nn.Module):
|
||||
self.n_local_experts = config.n_routed_experts // self.tp_size
|
||||
self.experts_start_idx = self.tp_rank * self.n_local_experts
|
||||
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
||||
# We don't pass `gate` into FusedMoE
|
||||
|
||||
self.experts = FusedMoE(
|
||||
shared_experts=self.shared_experts,
|
||||
gate=self.gate,
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -859,8 +851,8 @@ class DeepseekV4MoE(nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.norm_gate.e_score_correction_bias,
|
||||
hash_indices_table=self.norm_gate.tid2eid,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
hash_indices_table=self.gate.tid2eid,
|
||||
swiglu_limit=self.swiglu_limit,
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
@@ -868,40 +860,40 @@ class DeepseekV4MoE(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
if self.norm_gate.tid2eid is not None and input_ids is None:
|
||||
if self.gate.tid2eid is not None and input_ids is None:
|
||||
raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")
|
||||
|
||||
if not self.use_mega_moe:
|
||||
return self._forward_fused_moe(hidden_states, input_ids)
|
||||
|
||||
org_shape = hidden_states.shape
|
||||
normed_x, router_logits = self.norm_gate(hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=normed_x,
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.norm_gate.e_score_correction_bias.data
|
||||
if self.norm_gate.e_score_correction_bias is not None
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias.data
|
||||
if self.gate.e_score_correction_bias is not None
|
||||
else None,
|
||||
topk=self.n_activated_experts,
|
||||
renormalize=self.renormalize,
|
||||
indices_type=self.hash_indices_dtype,
|
||||
input_tokens=input_ids,
|
||||
hash_indices_table=self.norm_gate.tid2eid,
|
||||
hash_indices_table=self.gate.tid2eid,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
activation_clamp = (
|
||||
float(self.swiglu_limit) if self.swiglu_limit is not None else None
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
normed_x,
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation_clamp=activation_clamp,
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(normed_x)
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
final_hidden_states += shared_output
|
||||
|
||||
return final_hidden_states.view(org_shape)
|
||||
@@ -909,14 +901,21 @@ class DeepseekV4MoE(nn.Module):
|
||||
def _forward_fused_moe(
|
||||
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
assert not self.experts.is_internal_router
|
||||
org_shape = hidden_states.shape
|
||||
normed_x, router_logits = self.norm_gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=normed_x,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
if self.experts.is_internal_router:
|
||||
# In this case, the gate/router runs inside the FusedMoE class
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=hidden_states,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
else:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
return final_hidden_states.view(org_shape)
|
||||
|
||||
@@ -1120,8 +1119,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
|
||||
|
||||
self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
||||
# ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the
|
||||
# router gate matmul); see ``NormGatedLinear``.
|
||||
self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
||||
self.hc_mult = config.hc_mult
|
||||
self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
|
||||
self.hc_eps = config.hc_eps
|
||||
@@ -1170,9 +1168,6 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.mhc_pre = MHCPreOp()
|
||||
self.mhc_post = MHCPostOp()
|
||||
self.mhc_fused_post_pre = MHCFusedPostPreOp()
|
||||
|
||||
def hc_pre(
|
||||
self,
|
||||
@@ -1181,7 +1176,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
):
|
||||
post_mix, res_mix, layer_input = self.mhc_pre(
|
||||
post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre(
|
||||
residual=x,
|
||||
fn=hc_fn,
|
||||
hc_scale=hc_scale,
|
||||
@@ -1201,17 +1196,17 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
post: torch.Tensor,
|
||||
comb: torch.Tensor,
|
||||
):
|
||||
return self.mhc_post(x, residual, post, comb)
|
||||
return torch.ops.vllm.mhc_post(x, residual, post, comb)
|
||||
|
||||
def _forward_cuda(
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
post_mix: torch.Tensor | None,
|
||||
res_mix: torch.Tensor | None,
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
if residual is None:
|
||||
# Run standalone hc_pre on first layer
|
||||
residual = x
|
||||
@@ -1219,7 +1214,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
)
|
||||
else:
|
||||
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
|
||||
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
||||
x,
|
||||
residual,
|
||||
post_mix,
|
||||
@@ -1237,7 +1232,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x = self.attn_norm(x)
|
||||
x = self.attn(positions, x, None)
|
||||
|
||||
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
|
||||
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
||||
x,
|
||||
residual,
|
||||
post_mix,
|
||||
@@ -1251,65 +1246,29 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.hc_post_alpha,
|
||||
self.hc_sinkhorn_iters,
|
||||
)
|
||||
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
|
||||
# the pre-norm activation directly.
|
||||
|
||||
x = self.ffn_norm(x)
|
||||
x = self.ffn(x, input_ids)
|
||||
return x, residual, post_mix, res_mix
|
||||
|
||||
def _forward_rocm(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
|
||||
]:
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
)
|
||||
x = self.attn_norm(x)
|
||||
x = self.attn(positions, x, None)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(
|
||||
x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
|
||||
)
|
||||
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
|
||||
# the pre-norm activation directly.
|
||||
x = self.ffn(x, input_ids)
|
||||
x = self.hc_post(x, residual, post, comb)
|
||||
return x, None, None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
post_mix: torch.Tensor | None = None,
|
||||
res_mix: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
|
||||
]:
|
||||
if current_platform.is_rocm():
|
||||
return self._forward_rocm(
|
||||
x, positions, input_ids, post_mix, res_mix, residual
|
||||
)
|
||||
|
||||
return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV4Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: Vllm_config, prefix: str = ""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# Select weight mapper based on quantization method.
|
||||
# NVFP4 (modelopt_fp4) checkpoints use different key naming
|
||||
# than the default MXFP4 format.
|
||||
quant_config = vllm_config.quant_config
|
||||
if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
|
||||
elif getattr(config, "expert_dtype", "fp4") != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8")
|
||||
else:
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.use_mega_moe = (
|
||||
@@ -1392,7 +1351,7 @@ class DeepseekV4Model(nn.Module):
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.hc_head_op = HCHeadOp()
|
||||
|
||||
# Pre-hc_head residual stream buffer for the MTP draft. Stable
|
||||
# address (outside the cudagraph pool) so the copy_ in forward()
|
||||
# refreshes it correctly across captured shapes.
|
||||
@@ -1462,7 +1421,7 @@ class DeepseekV4Model(nn.Module):
|
||||
res_mix,
|
||||
residual,
|
||||
)
|
||||
if layer is not None and current_platform.is_cuda():
|
||||
else:
|
||||
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
@@ -1472,7 +1431,7 @@ class DeepseekV4Model(nn.Module):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))
|
||||
|
||||
hidden_states = self.hc_head_op(
|
||||
hidden_states = hc_head(
|
||||
hidden_states,
|
||||
self.hc_head_fn,
|
||||
self.hc_head_scale,
|
||||
@@ -1601,6 +1560,36 @@ class DeepseekV4Model(nn.Module):
|
||||
layer.ffn.finalize_mega_moe_weights()
|
||||
|
||||
|
||||
@torch.compile(backend=current_platform.simple_compile_backend)
|
||||
def hc_head(
|
||||
hidden_states: torch.Tensor,
|
||||
hc_fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_norm_eps: float,
|
||||
hc_eps: float,
|
||||
) -> torch.Tensor:
|
||||
hc_mult, hidden_size = hidden_states.shape[-2:]
|
||||
outer_shape = hidden_states.shape[:-2]
|
||||
hs_flat = hidden_states.view(-1, hc_mult, hidden_size)
|
||||
num_tokens = hs_flat.shape[0]
|
||||
out = torch.empty(
|
||||
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
|
||||
)
|
||||
torch.ops.vllm.hc_head_fused_kernel(
|
||||
hs_flat,
|
||||
hc_fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
out,
|
||||
hidden_size,
|
||||
rms_norm_eps,
|
||||
hc_eps,
|
||||
hc_mult,
|
||||
)
|
||||
return out.view(*outer_shape, hidden_size)
|
||||
|
||||
|
||||
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
if expert_dtype == "fp4":
|
||||
# MXFP4 experts use Mxfp4MoEMethod, which registers scales as
|
||||
@@ -1630,13 +1619,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
orig_to_new_suffix={
|
||||
"head.weight": "lm_head.weight",
|
||||
"embed.weight": "embed_tokens.weight",
|
||||
# Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate``
|
||||
# (see NormGatedLinear).
|
||||
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
|
||||
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
|
||||
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
|
||||
# Hash MoE table also moved off the inner gate.
|
||||
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
|
||||
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
@@ -1655,21 +1638,15 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
- Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes
|
||||
- Shared expert uses down_proj (not w2)
|
||||
- Self-attention uses .self_attn. prefix (same as checkpoint, renamed to .attn.)
|
||||
- Hadamard coding uses .attn_hc. and .ffn_hc. prefixes
|
||||
|
||||
This is the mapper that should be used when quantization is modelopt_fp4.
|
||||
"""
|
||||
# Expert weight renames: gate_proj→w1, up_proj→w3, down_proj→w2
|
||||
# Must match BEFORE the general suffix renames
|
||||
expert_rename_regex = {
|
||||
re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.",
|
||||
re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.",
|
||||
re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.",
|
||||
}
|
||||
|
||||
# Suffix renames for non-expert keys
|
||||
# NVFP4 checkpoints already use .weight_scale (not .scale), so no scale→weight_scale mapping needed
|
||||
# But .self_attn. → .attn. and .mlp. → .ffn. renames are needed
|
||||
suffix_renames = {
|
||||
"head.weight": "lm_head.weight",
|
||||
"embed.weight": "embed_tokens.weight",
|
||||
@@ -1679,7 +1656,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
|
||||
}
|
||||
|
||||
# Substr renames
|
||||
substr_renames = {
|
||||
".attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
|
||||
@@ -1687,8 +1663,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
|
||||
".mlp.": ".ffn.",
|
||||
".self_attn.": ".attn.",
|
||||
".attn_hc.": ".attn.hc_op.",
|
||||
".ffn_hc.": ".ffn.hc_op.",
|
||||
}
|
||||
|
||||
return WeightsMapper(
|
||||
@@ -1696,8 +1670,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
"layers.": "model.layers.",
|
||||
"embed.": "model.embed.",
|
||||
"norm.": "model.norm.",
|
||||
"hc_head": "model.hc_head",
|
||||
"mtp.": "model.mtp.",
|
||||
},
|
||||
orig_to_new_regex=expert_rename_regex,
|
||||
orig_to_new_suffix=suffix_renames,
|
||||
|
||||
@@ -46,7 +46,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
@@ -1109,6 +1109,7 @@ class DeepseekV4Indexer(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
self.scale_fmt = "ue8m0"
|
||||
|
||||
@@ -1,289 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading models."""
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import (
|
||||
Attention,
|
||||
MLAAttention,
|
||||
MMEncoderAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.model_loader.reload import (
|
||||
record_metadata_for_reloading,
|
||||
set_torchao_reload_attrs,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.tracing import instrument
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@instrument(span_name="Initialize model")
|
||||
def initialize_model(
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
model_class: type[nn.Module] | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
if model_class is None:
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
|
||||
if vllm_config.quant_config is not None:
|
||||
configure_quant_config(vllm_config.quant_config, model_class)
|
||||
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
|
||||
model = model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
record_metadata_for_reloading(model)
|
||||
return model
|
||||
|
||||
msg = (
|
||||
"vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
|
||||
"for the design and update the model class accordingly."
|
||||
)
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
logger.warning(
|
||||
"Trying to guess the arguments for old-style model class %s",
|
||||
model_class,
|
||||
)
|
||||
# try to be compatible with old-style model class
|
||||
kwargs: dict[str, Any] = {}
|
||||
if "prefix" in all_params:
|
||||
kwargs["prefix"] = prefix
|
||||
if "config" in all_params:
|
||||
kwargs["config"] = model_config.hf_config
|
||||
if "cache_config" in all_params:
|
||||
kwargs["cache_config"] = vllm_config.cache_config
|
||||
if "quant_config" in all_params:
|
||||
kwargs["quant_config"] = vllm_config.quant_config
|
||||
if "lora_config" in all_params:
|
||||
kwargs["lora_config"] = vllm_config.lora_config
|
||||
if "scheduler_config" in all_params:
|
||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
|
||||
model = model_class(**kwargs)
|
||||
record_metadata_for_reloading(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def process_weights_after_loading(
|
||||
model: nn.Module, model_config: ModelConfig, target_device: torch.device
|
||||
) -> None:
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
|
||||
# Initialize post-load attention weights for Attention, MLA, and MM encoder.
|
||||
# NOTE: Happens after other modules so we can easily decompress weights.
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(
|
||||
module, (Attention, MLAAttention, MMEncoderAttention)
|
||||
) and hasattr(module, "process_weights_after_loading"):
|
||||
# TODO(lucas): see if there is a way to unify the signatures
|
||||
# of process_weights_after_loading
|
||||
with device_loading_context(module, target_device):
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
if model_config.quantization == "torchao":
|
||||
set_torchao_reload_attrs(model, model_config)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
|
||||
if target_device.type == "cpu":
|
||||
# If target is CPU, no need to move anything
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: dict[str, torch.device] = {}
|
||||
uva_offloaded_parameters: list[str] = []
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
if p.device.type == "cpu":
|
||||
original_device_states[name] = p.device
|
||||
p.data = p.data.to(target_device)
|
||||
if getattr(p, "_vllm_is_uva_offloaded", False):
|
||||
uva_offloaded_parameters.append(name)
|
||||
# Parameters already on target device are not touched
|
||||
|
||||
try:
|
||||
yield module
|
||||
|
||||
finally:
|
||||
use_pin_memory = (
|
||||
is_pin_memory_available()
|
||||
and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
|
||||
)
|
||||
# Restore parameters to their original devices, ignoring new parameters
|
||||
for name, p in module.named_parameters():
|
||||
if name in original_device_states:
|
||||
original_device: torch.device = original_device_states[name]
|
||||
p.data = p.data.to(original_device)
|
||||
|
||||
# parameter is UVA offloaded, but was replaced with a new device tensor
|
||||
# re-offload it to CPU using UVA
|
||||
if name in uva_offloaded_parameters and not getattr(
|
||||
p, "_vllm_is_uva_offloaded", False
|
||||
):
|
||||
cpu_data = p.data.to(device="cpu")
|
||||
if use_pin_memory:
|
||||
cpu_data = cpu_data.pin_memory()
|
||||
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
|
||||
p._vllm_is_uva_offloaded = True
|
||||
|
||||
|
||||
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
|
||||
"""Caches the outputs of `_get_model_architecture`."""
|
||||
|
||||
|
||||
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
|
||||
|
||||
architectures = getattr(model_config.hf_config, "architectures", None) or []
|
||||
|
||||
model_cls, arch = model_config.registry.resolve_model_cls(
|
||||
architectures,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if arch == model_config._get_transformers_backend_cls():
|
||||
assert model_config.model_impl != "vllm"
|
||||
if model_config.model_impl == "auto":
|
||||
logger.warning_once(
|
||||
"%s has no vLLM implementation, falling back to Transformers "
|
||||
"implementation. Some features may not be supported and "
|
||||
"performance may not be optimal.",
|
||||
arch,
|
||||
)
|
||||
|
||||
convert_type = model_config.convert_type
|
||||
if convert_type == "none":
|
||||
pass
|
||||
elif convert_type == "embed":
|
||||
logger.debug_once("Converting to embedding model.")
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
elif convert_type == "classify":
|
||||
logger.debug_once("Converting to sequence classification model.")
|
||||
model_cls = as_seq_cls_model(model_cls)
|
||||
else:
|
||||
assert_never(convert_type)
|
||||
|
||||
return model_cls, arch
|
||||
|
||||
|
||||
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
key = hash(
|
||||
(
|
||||
model_config.model,
|
||||
model_config.convert_type,
|
||||
model_config.runner_type,
|
||||
model_config.trust_remote_code,
|
||||
model_config.model_impl,
|
||||
tuple(getattr(model_config.hf_config, "architectures", None) or []),
|
||||
)
|
||||
)
|
||||
if key in _MODEL_ARCH_BY_HASH:
|
||||
return _MODEL_ARCH_BY_HASH[key]
|
||||
|
||||
model_arch = _get_model_architecture(model_config)
|
||||
_MODEL_ARCH_BY_HASH[key] = model_arch
|
||||
return model_arch
|
||||
|
||||
|
||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
||||
return get_model_architecture(model_config)[0]
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
|
||||
def configure_quant_config(
|
||||
quant_config: QuantizationConfig, model_class: type[nn.Module]
|
||||
):
|
||||
"""
|
||||
Pass packed_modules_mapping by reference to quant_config so that
|
||||
quant_config can properly match fused modules
|
||||
|
||||
Note that model attributes are passed by reference to quant_config,
|
||||
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
|
||||
|
||||
Once the `SupportsQuant` mixin has been added to all models, this
|
||||
function can be removed
|
||||
"""
|
||||
if not issubclass(model_class, SupportsQuant):
|
||||
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
|
||||
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
|
||||
|
||||
# pass mappings by reference to quant_config
|
||||
if hf_to_vllm_mapper is not None:
|
||||
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
|
||||
if packed_mapping is not None:
|
||||
quant_config.packed_modules_mapping = packed_mapping
|
||||
Reference in New Issue
Block a user