Use nightly's deepseek_v4.py + attention as base, add only NVFP4 mapper

The upstream deepseek_v4.py has imports that don't exist in the nightly
Docker image (norm_gate_linear, breakable_cudagraph, etc.). Use the
nightly's own files as the base and add only the minimal NVFP4 changes:
- Add _make_deepseek_v4_nvfp4_weights_mapper() for checkpoint key mapping
- Select NVFP4 mapper when quant_config is modelopt_fp4
- cos_sin_cache float32 fix in attention
- Remove utils.py patch (not needed)
This commit is contained in:
2026-05-18 22:33:51 +00:00
parent a19ed4a18e
commit 7409204d71
3 changed files with 98 additions and 414 deletions

View File

@@ -23,14 +23,11 @@ from vllm.model_executor.layers.deepseek_v4_attention import (
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import (
NormGateLinear,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -38,12 +35,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mhc import (
HCHeadOp,
MHCFusedPostPreOp,
MHCPostOp,
MHCPreOp,
)
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
@@ -758,23 +749,23 @@ class DeepseekV4MoE(nn.Module):
"deep_gemm_mega_moe for this checkpoint."
)
# Fused RMSNorm + gate: owns both ffn_norm and the gate matmul.
self.norm_gate = NormGateLinear(
hidden_size=config.hidden_size,
num_experts=config.n_routed_experts,
rms_eps=config.rms_norm_eps,
prefix=f"{prefix}.norm_gate",
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
out_dtype=torch.float32,
bias=False,
prefix=f"{prefix}.gate",
)
# Routing-side tensors live on ``norm_gate`` directly (not on the
# inner gate); they are initialized to None in NormGatedLinear and
# populated below depending on the MoE variant.
self.gate.e_score_correction_bias = None
self.gate.tid2eid = None
is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers
self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32
if is_hash_moe:
# hash MoE doesn't use e_score_correction_bias
# Use randint instead of empty to avoid garbage values causing
# invalid memory access in dummy mode (--load-format="dummy")
self.norm_gate.tid2eid = nn.Parameter(
self.gate.tid2eid = nn.Parameter(
torch.randint(
0,
config.n_routed_experts,
@@ -784,7 +775,7 @@ class DeepseekV4MoE(nn.Module):
requires_grad=False,
)
elif getattr(config, "topk_method", None) == "noaux_tc":
self.norm_gate.e_score_correction_bias = nn.Parameter(
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32),
requires_grad=False,
)
@@ -847,9 +838,10 @@ class DeepseekV4MoE(nn.Module):
self.n_local_experts = config.n_routed_experts // self.tp_size
self.experts_start_idx = self.tp_rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
# We don't pass `gate` into FusedMoE
self.experts = FusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
@@ -859,8 +851,8 @@ class DeepseekV4MoE(nn.Module):
prefix=f"{prefix}.experts",
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.norm_gate.e_score_correction_bias,
hash_indices_table=self.norm_gate.tid2eid,
e_score_correction_bias=self.gate.e_score_correction_bias,
hash_indices_table=self.gate.tid2eid,
swiglu_limit=self.swiglu_limit,
router_logits_dtype=torch.float32,
)
@@ -868,40 +860,40 @@ class DeepseekV4MoE(nn.Module):
def forward(
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
) -> torch.Tensor:
if self.norm_gate.tid2eid is not None and input_ids is None:
if self.gate.tid2eid is not None and input_ids is None:
raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")
if not self.use_mega_moe:
return self._forward_fused_moe(hidden_states, input_ids)
org_shape = hidden_states.shape
normed_x, router_logits = self.norm_gate(hidden_states)
router_logits, _ = self.gate(hidden_states)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=normed_x,
hidden_states=hidden_states,
gating_output=router_logits,
scoring_func=self.scoring_func,
e_score_correction_bias=self.norm_gate.e_score_correction_bias.data
if self.norm_gate.e_score_correction_bias is not None
e_score_correction_bias=self.gate.e_score_correction_bias.data
if self.gate.e_score_correction_bias is not None
else None,
topk=self.n_activated_experts,
renormalize=self.renormalize,
indices_type=self.hash_indices_dtype,
input_tokens=input_ids,
hash_indices_table=self.norm_gate.tid2eid,
hash_indices_table=self.gate.tid2eid,
routed_scaling_factor=self.routed_scaling_factor,
)
activation_clamp = (
float(self.swiglu_limit) if self.swiglu_limit is not None else None
)
final_hidden_states = self.experts(
normed_x,
hidden_states,
topk_weights,
topk_ids,
activation_clamp=activation_clamp,
)
if self.shared_experts is not None:
shared_output = self.shared_experts(normed_x)
shared_output = self.shared_experts(hidden_states)
final_hidden_states += shared_output
return final_hidden_states.view(org_shape)
@@ -909,14 +901,21 @@ class DeepseekV4MoE(nn.Module):
def _forward_fused_moe(
self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
) -> torch.Tensor:
assert not self.experts.is_internal_router
org_shape = hidden_states.shape
normed_x, router_logits = self.norm_gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=normed_x,
router_logits=router_logits,
input_ids=input_ids,
)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=hidden_states,
input_ids=input_ids,
)
else:
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
input_ids=input_ids,
)
return final_hidden_states.view(org_shape)
@@ -1120,8 +1119,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
# ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the
# router gate matmul); see ``NormGatedLinear``.
self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
self.hc_mult = config.hc_mult
self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
self.hc_eps = config.hc_eps
@@ -1170,9 +1168,6 @@ class DeepseekV4DecoderLayer(nn.Module):
),
requires_grad=False,
)
self.mhc_pre = MHCPreOp()
self.mhc_post = MHCPostOp()
self.mhc_fused_post_pre = MHCFusedPostPreOp()
def hc_pre(
self,
@@ -1181,7 +1176,7 @@ class DeepseekV4DecoderLayer(nn.Module):
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
):
post_mix, res_mix, layer_input = self.mhc_pre(
post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre(
residual=x,
fn=hc_fn,
hc_scale=hc_scale,
@@ -1201,17 +1196,17 @@ class DeepseekV4DecoderLayer(nn.Module):
post: torch.Tensor,
comb: torch.Tensor,
):
return self.mhc_post(x, residual, post, comb)
return torch.ops.vllm.mhc_post(x, residual, post, comb)
def _forward_cuda(
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
input_ids: torch.Tensor | None,
post_mix: torch.Tensor | None = None,
res_mix: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
post_mix: torch.Tensor | None,
res_mix: torch.Tensor | None,
residual: torch.Tensor | None,
) -> torch.Tensor:
if residual is None:
# Run standalone hc_pre on first layer
residual = x
@@ -1219,7 +1214,7 @@ class DeepseekV4DecoderLayer(nn.Module):
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
)
else:
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
x,
residual,
post_mix,
@@ -1237,7 +1232,7 @@ class DeepseekV4DecoderLayer(nn.Module):
x = self.attn_norm(x)
x = self.attn(positions, x, None)
residual, post_mix, res_mix, x = self.mhc_fused_post_pre(
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
x,
residual,
post_mix,
@@ -1251,65 +1246,29 @@ class DeepseekV4DecoderLayer(nn.Module):
self.hc_post_alpha,
self.hc_sinkhorn_iters,
)
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
# the pre-norm activation directly.
x = self.ffn_norm(x)
x = self.ffn(x, input_ids)
return x, residual, post_mix, res_mix
def _forward_rocm(
self,
x: torch.Tensor,
positions: torch.Tensor,
input_ids: torch.Tensor | None,
post_mix: torch.Tensor | None = None,
res_mix: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
) -> tuple[
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
residual = x
x, post, comb = self.hc_pre(
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
)
x = self.attn_norm(x)
x = self.attn(positions, x, None)
x = self.hc_post(x, residual, post, comb)
residual = x
x, post, comb = self.hc_pre(
x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
)
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
# the pre-norm activation directly.
x = self.ffn(x, input_ids)
x = self.hc_post(x, residual, post, comb)
return x, None, None, None
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
input_ids: torch.Tensor | None,
post_mix: torch.Tensor | None = None,
res_mix: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
) -> tuple[
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
if current_platform.is_rocm():
return self._forward_rocm(
x, positions, input_ids, post_mix, res_mix, residual
)
return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual)
@support_torch_compile
class DeepseekV4Model(nn.Module):
def __init__(self, *, vllm_config: Vllm_config, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
# Select weight mapper based on quantization method.
# NVFP4 (modelopt_fp4) checkpoints use different key naming
# than the default MXFP4 format.
quant_config = vllm_config.quant_config
if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4":
self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
elif getattr(config, "expert_dtype", "fp4") != "fp4":
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8")
else:
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
quant_config = vllm_config.quant_config
self.config = config
self.use_mega_moe = (
@@ -1392,7 +1351,7 @@ class DeepseekV4Model(nn.Module):
torch.empty(1, dtype=torch.float32),
requires_grad=False,
)
self.hc_head_op = HCHeadOp()
# Pre-hc_head residual stream buffer for the MTP draft. Stable
# address (outside the cudagraph pool) so the copy_ in forward()
# refreshes it correctly across captured shapes.
@@ -1462,7 +1421,7 @@ class DeepseekV4Model(nn.Module):
res_mix,
residual,
)
if layer is not None and current_platform.is_cuda():
else:
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
if not get_pp_group().is_last_rank:
@@ -1472,7 +1431,7 @@ class DeepseekV4Model(nn.Module):
num_tokens = hidden_states.shape[0]
self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))
hidden_states = self.hc_head_op(
hidden_states = hc_head(
hidden_states,
self.hc_head_fn,
self.hc_head_scale,
@@ -1601,6 +1560,36 @@ class DeepseekV4Model(nn.Module):
layer.ffn.finalize_mega_moe_weights()
@torch.compile(backend=current_platform.simple_compile_backend)
def hc_head(
hidden_states: torch.Tensor,
hc_fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_norm_eps: float,
hc_eps: float,
) -> torch.Tensor:
hc_mult, hidden_size = hidden_states.shape[-2:]
outer_shape = hidden_states.shape[:-2]
hs_flat = hidden_states.view(-1, hc_mult, hidden_size)
num_tokens = hs_flat.shape[0]
out = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
)
torch.ops.vllm.hc_head_fused_kernel(
hs_flat,
hc_fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_norm_eps,
hc_eps,
hc_mult,
)
return out.view(*outer_shape, hidden_size)
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
if expert_dtype == "fp4":
# MXFP4 experts use Mxfp4MoEMethod, which registers scales as
@@ -1630,13 +1619,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
orig_to_new_suffix={
"head.weight": "lm_head.weight",
"embed.weight": "embed_tokens.weight",
# Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate``
# (see NormGatedLinear).
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
# Hash MoE table also moved off the inner gate.
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
".attn.compressor.": ".attn.mla_attn.compressor.",
@@ -1655,21 +1638,15 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
- Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes
- Shared expert uses down_proj (not w2)
- Self-attention uses .self_attn. prefix (same as checkpoint, renamed to .attn.)
- Hadamard coding uses .attn_hc. and .ffn_hc. prefixes
This is the mapper that should be used when quantization is modelopt_fp4.
"""
# Expert weight renames: gate_proj→w1, up_proj→w3, down_proj→w2
# Must match BEFORE the general suffix renames
expert_rename_regex = {
re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.",
re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.",
re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.",
}
# Suffix renames for non-expert keys
# NVFP4 checkpoints already use .weight_scale (not .scale), so no scale→weight_scale mapping needed
# But .self_attn. → .attn. and .mlp. → .ffn. renames are needed
suffix_renames = {
"head.weight": "lm_head.weight",
"embed.weight": "embed_tokens.weight",
@@ -1679,7 +1656,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
}
# Substr renames
substr_renames = {
".attn.compressor.": ".attn.mla_attn.compressor.",
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
@@ -1687,8 +1663,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
".mlp.": ".ffn.",
".self_attn.": ".attn.",
".attn_hc.": ".attn.hc_op.",
".ffn_hc.": ".ffn.hc_op.",
}
return WeightsMapper(
@@ -1696,8 +1670,6 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
"layers.": "model.layers.",
"embed.": "model.embed.",
"norm.": "model.norm.",
"hc_head": "model.hc_head",
"mtp.": "model.mtp.",
},
orig_to_new_regex=expert_rename_regex,
orig_to_new_suffix=suffix_renames,

View File

@@ -46,7 +46,7 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
@@ -1109,6 +1109,7 @@ class DeepseekV4Indexer(nn.Module):
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"

View File

@@ -1,289 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import nn
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import (
Attention,
MLAAttention,
MMEncoderAttention,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.model_loader.reload import (
record_metadata_for_reloading,
set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
logger = init_logger(__name__)
@instrument(span_name="Initialize model")
def initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
model_class: type[nn.Module] | None = None,
model_config: ModelConfig | None = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
if model_config is None:
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
model = model_class(vllm_config=vllm_config, prefix=prefix)
record_metadata_for_reloading(model)
return model
msg = (
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
)
# try to be compatible with old-style model class
kwargs: dict[str, Any] = {}
if "prefix" in all_params:
kwargs["prefix"] = prefix
if "config" in all_params:
kwargs["config"] = model_config.hf_config
if "cache_config" in all_params:
kwargs["cache_config"] = vllm_config.cache_config
if "quant_config" in all_params:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
model = model_class(**kwargs)
record_metadata_for_reloading(model)
return model
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Initialize post-load attention weights for Attention, MLA, and MM encoder.
# NOTE: Happens after other modules so we can easily decompress weights.
for _, module in model.named_modules():
if isinstance(
module, (Attention, MLAAttention, MMEncoderAttention)
) and hasattr(module, "process_weights_after_loading"):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
with device_loading_context(module, target_device):
module.process_weights_after_loading(model_config.dtype)
if model_config.quantization == "torchao":
set_torchao_reload_attrs(model, model_config)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
if target_device.type == "cpu":
# If target is CPU, no need to move anything
yield module
return
original_device_states: dict[str, torch.device] = {}
uva_offloaded_parameters: list[str] = []
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
if getattr(p, "_vllm_is_uva_offloaded", False):
uva_offloaded_parameters.append(name)
# Parameters already on target device are not touched
try:
yield module
finally:
use_pin_memory = (
is_pin_memory_available()
and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
# Restore parameters to their original devices, ignoring new parameters
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
p.data = p.data.to(original_device)
# parameter is UVA offloaded, but was replaced with a new device tensor
# re-offload it to CPU using UVA
if name in uva_offloaded_parameters and not getattr(
p, "_vllm_is_uva_offloaded", False
):
cpu_data = p.data.to(device="cpu")
if use_pin_memory:
cpu_data = cpu_data.pin_memory()
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
p._vllm_is_uva_offloaded = True
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
architectures = getattr(model_config.hf_config, "architectures", None) or []
model_cls, arch = model_config.registry.resolve_model_cls(
architectures,
model_config=model_config,
)
if arch == model_config._get_transformers_backend_cls():
assert model_config.model_impl != "vllm"
if model_config.model_impl == "auto":
logger.warning_once(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
convert_type = model_config.convert_type
if convert_type == "none":
pass
elif convert_type == "embed":
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
else:
assert_never(convert_type)
return model_cls, arch
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
key = hash(
(
model_config.model,
model_config.convert_type,
model_config.runner_type,
model_config.trust_remote_code,
model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", None) or []),
)
)
if key in _MODEL_ARCH_BY_HASH:
return _MODEL_ARCH_BY_HASH[key]
model_arch = _get_model_architecture(model_config)
_MODEL_ARCH_BY_HASH[key] = model_arch
return model_arch
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
return get_model_architecture(model_config)[0]
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: dict[str, list[str]]
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(
quant_config: QuantizationConfig, model_class: type[nn.Module]
):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
Once the `SupportsQuant` mixin has been added to all models, this
function can be removed
"""
if not issubclass(model_class, SupportsQuant):
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
# pass mappings by reference to quant_config
if hf_to_vllm_mapper is not None:
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if packed_mapping is not None:
quant_config.packed_modules_mapping = packed_mapping