tweax n shit
This commit is contained in:
@@ -8,18 +8,18 @@ RUN apt-get update && apt-get install -y git screen cmake && rm -rf /var/lib/apt
|
||||
|
||||
# Clone and build DeepGEMM with NVFP4 mega_moe kernel
|
||||
# CACHE_BUSTER: increment to force fresh clone
|
||||
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && PATCH_CACHE_BUSTER=69
|
||||
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && PATCH_CACHE_BUSTER=70
|
||||
|
||||
# Build DeepGEMM with proper CUDA/NVRTC paths
|
||||
ENV CPATH="/usr/local/lib/python3.12/dist-packages/flashinfer/data/cutlass/include:/usr/local/lib/python3.12/dist-packages/nvidia/cu13/include:/usr/local/cuda-13.0/include:${CPATH}"
|
||||
ENV PYTHONPATH="/root/DeepGEMM:${PYTHONPATH}"
|
||||
# NVRTC lives in the pip nvidia/cu13 package, but the linker expects it in cuda/lib64
|
||||
# Create a symlink so -lnvrtc resolves
|
||||
RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.13 /usr/local/cuda/lib64/libnvrtc.so && PATCH_CACHE_BUSTER=69
|
||||
RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.13 /usr/local/cuda/lib64/libnvrtc.so && PATCH_CACHE_BUSTER=70
|
||||
RUN cd /root/DeepGEMM && python3 setup.py build_ext --inplace && PATCH_CACHE_BUSTER=69
|
||||
|
||||
# Bust cache for patch changes — ARG before COPY ensures layer invalidation
|
||||
ARG PATCH_CACHE_BUSTER=69
|
||||
ARG PATCH_CACHE_BUSTER=70
|
||||
# Copy our DeepSeek V4 patch over vLLM's model file
|
||||
COPY patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py
|
||||
# Copy the NVFP4 staging kernel (BF16→E2M1+UE4M3 quantization for activations)
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
command:
|
||||
- /model
|
||||
- --trust-remote-code
|
||||
- --kv-cache-dtype=fp8
|
||||
#- --kv-cache-dtype=fp8 # maybe we just let it figure its own shit out
|
||||
#- --block-size=256
|
||||
- --enable-expert-parallel
|
||||
- --tensor-parallel-size=8
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ==============================================================================
|
||||
# DeepSeek V4 NVFP4 Patch — Version Banner (printed at import time)
|
||||
# ==============================================================================
|
||||
import datetime as _dt
|
||||
import os as _os
|
||||
_git_commit = _os.popen("git -C /root/nvidia-meeting/deepseek-v4-quant rev-parse --short HEAD 2>/dev/null || echo 'unknown'").read().strip()
|
||||
print(f"""
|
||||
{'='*70}
|
||||
DeepSeek V4 NVFP4 Patch (commit {_git_commit})
|
||||
Loaded: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}
|
||||
{'='*70}
|
||||
""")
|
||||
# ==============================================================================
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
@@ -165,8 +151,9 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
try:
|
||||
hf_config = get_current_vllm_config().model_config.hf_config
|
||||
except Exception:
|
||||
# vllm_config not yet set; defer the decision until a
|
||||
# later call lands inside set_current_vllm_config.
|
||||
# vllm_config not yet set; return safe default but do NOT
|
||||
# cache — a later call inside set_current_vllm_config may
|
||||
# resolve differently.
|
||||
return "fp4"
|
||||
expert_dtype = getattr(hf_config, "expert_dtype", "fp4")
|
||||
if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
|
||||
@@ -175,11 +162,6 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
|
||||
)
|
||||
self._resolved_expert_dtype = expert_dtype
|
||||
from vllm.logger import init_logger
|
||||
|
||||
init_logger(__name__).info_once(
|
||||
"DeepSeek V4 expert_dtype resolved to %r", expert_dtype
|
||||
)
|
||||
return self._resolved_expert_dtype
|
||||
|
||||
@property
|
||||
@@ -614,26 +596,25 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
) -> bool | None:
|
||||
) -> bool:
|
||||
local_expert_id = self._map_global_expert_id(expert_id)
|
||||
if local_expert_id == -1:
|
||||
return False if return_success else None
|
||||
return False
|
||||
|
||||
# Scalar params (weight_scale_2, input_scale): 1D per-expert
|
||||
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
|
||||
param.data[local_expert_id].copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
return True
|
||||
|
||||
expert_data = param.data[local_expert_id]
|
||||
if shard_id in ("w1", "w3"):
|
||||
if "w13_" not in weight_name:
|
||||
return False if return_success else None
|
||||
return False
|
||||
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
|
||||
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
|
||||
elif shard_id == "w2":
|
||||
if "w2_" not in weight_name:
|
||||
return False if return_success else None
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert shard id: {shard_id}")
|
||||
|
||||
@@ -644,7 +625,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
f"vs checkpoint {tuple(loaded_weight.shape)}"
|
||||
)
|
||||
expert_data.copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
return True
|
||||
|
||||
def _check_runtime_supported(self) -> None:
|
||||
if not torch.cuda.is_available():
|
||||
@@ -654,7 +635,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
|
||||
)
|
||||
if torch.cuda.get_device_capability(device)[0] != 10:
|
||||
if torch.cuda.get_device_capability(device)[0] < 10:
|
||||
raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
|
||||
if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
|
||||
raise ValueError(
|
||||
@@ -707,98 +688,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
"""
|
||||
return sf.to(torch.float32)
|
||||
|
||||
def _nvfp4_to_bf16(
|
||||
self,
|
||||
w_uint8: torch.Tensor, # [E, M, K//2]
|
||||
w_scale_f8: torch.Tensor, # [E, M, K//16] (float8_e4m3fn, UE4M3 standard NVFP4)
|
||||
w_scale_2: torch.Tensor, # [E] float32 global scale
|
||||
w_input_scale: torch.Tensor, # [E] float32 activation scale
|
||||
) -> torch.Tensor:
|
||||
"""Dequantize NVFP4 expert weights to BF16.
|
||||
|
||||
Formula: weight_bf16 = e2m1_value * block_scale_ue4m3 * global_scale
|
||||
(input_scale is for activations, not weights)
|
||||
"""
|
||||
device = w_uint8.device
|
||||
E, M, K2 = w_uint8.shape
|
||||
K = K2 * 2 # unpacked dim
|
||||
|
||||
# Unpack E2M1 FP4 → BF16
|
||||
e2m1_lut = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
|
||||
dtype=torch.bfloat16, device=device,
|
||||
)
|
||||
even_raw = (w_uint8 & 0x0F).int()
|
||||
odd_raw = ((w_uint8 >> 4) & 0x0F).int()
|
||||
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
||||
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
||||
even_vals = even_sign * e2m1_lut[even_raw & 0x07]
|
||||
odd_vals = odd_sign * e2m1_lut[odd_raw & 0x07]
|
||||
w_bf16 = torch.stack([even_vals, odd_vals], dim=-1).reshape(E, M, K)
|
||||
|
||||
# Dequantize: e2m1 * block_scale * global_scale
|
||||
block_scale = self._ue8m0_to_float32(w_scale_f8) # [E, M, K//16]
|
||||
GROUP_SIZE = 16
|
||||
# Expand block scale to match weight elements
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
-1, -1, -1, GROUP_SIZE
|
||||
).reshape(E, M, K)
|
||||
|
||||
# Global scale: [E] → [E, 1, 1]
|
||||
global_scale = w_scale_2.view(E, 1, 1)
|
||||
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
|
||||
return w_dequant.to(torch.bfloat16)
|
||||
|
||||
def _bf16_to_mxfp4(
|
||||
self,
|
||||
w_bf16: torch.Tensor, # [E, M, K]
|
||||
group_size: int = 32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Re-quantize BF16 → MXFP4 (E2M1 packed uint8 + UE8M0 uint8 scales).
|
||||
|
||||
Returns (w_packed_uint8, w_scale_uint8) where:
|
||||
w_packed_uint8: [E, M, K//2]
|
||||
w_scale_uint8: [E, M, K//group_size] as uint8 UE8M0 bytes
|
||||
"""
|
||||
device = w_bf16.device
|
||||
E, M, K = w_bf16.shape
|
||||
|
||||
# Block quantization
|
||||
n_groups = K // group_size
|
||||
w_groups = w_bf16.reshape(E, M, n_groups, group_size)
|
||||
|
||||
# Compute block amax
|
||||
amax = w_groups.abs().amax(dim=-1) # [E, M, n_groups]
|
||||
|
||||
# UE8M0 scale: floor to nearest power of 2
|
||||
# value = 2^(exp-127), so exp = floor(log2(amax)) + 127
|
||||
amax_clamped = amax.clamp(min=2**-126)
|
||||
scale_exp = torch.floor(torch.log2(amax_clamped)).to(torch.int32) + 127
|
||||
scale_exp = scale_exp.clamp(1, 254).to(torch.uint8) # avoid 0 and 255
|
||||
|
||||
# Recover float32 scale for quantization
|
||||
scale_f32 = (scale_exp.to(torch.int32) << 23).view(torch.float32) # [E, M, n_groups]
|
||||
|
||||
# Quantize each element: value / scale → nearest E2M1
|
||||
E2M1_POS = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
|
||||
dtype=torch.float32, device=device,
|
||||
)
|
||||
scaled = w_groups.float() / scale_f32.unsqueeze(-1) # [E, M, n_groups, group_size]
|
||||
scaled_abs = scaled.abs()
|
||||
diff = (scaled_abs.unsqueeze(-1) - E2M1_POS).abs()
|
||||
fp4_idx = diff.argmin(dim=-1) # [E, M, n_groups, group_size]
|
||||
sign = (scaled < 0).int()
|
||||
fp4_val = (sign << 3) | fp4_idx.int()
|
||||
|
||||
# Pack 2 FP4 values per byte
|
||||
fp4_flat = fp4_val.reshape(E, M, K)
|
||||
even = fp4_flat[:, :, 0::2]
|
||||
odd = fp4_flat[:, :, 1::2]
|
||||
w_packed = ((odd << 4) | even).to(torch.uint8).view(torch.int8)
|
||||
|
||||
return w_packed, scale_exp
|
||||
|
||||
def get_symm_buffer(self):
|
||||
import deep_gemm
|
||||
@@ -893,8 +783,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
activation_clamp=activation_clamp,
|
||||
fast_math=fast_math,
|
||||
)
|
||||
# Sync to catch mega_moe CUDA errors immediately
|
||||
torch.cuda.synchronize()
|
||||
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||
@@ -1239,7 +1129,7 @@ class DeepseekV4Attention(nn.Module):
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
|
||||
rope_parameters = config.rope_parameters
|
||||
rope_parameters = dict(config.rope_parameters)
|
||||
rope_parameters["rope_theta"] = (
|
||||
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
|
||||
)
|
||||
@@ -1653,14 +1543,19 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
if ".experts." in name:
|
||||
# E8M0 scales are stored as float8_e8m0fnu in
|
||||
# checkpoints but the MoE param is uint8. copy_()
|
||||
# would do a numeric conversion (e.g. 2^-7 → 0),
|
||||
# destroying the raw exponent bytes.
|
||||
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
|
||||
# The uint8 view+copy path is only valid for MXFP4;
|
||||
# for NVFP4 it would paste raw E8M0 bytes into an
|
||||
# E4M3 buffer, producing garbage.
|
||||
if (
|
||||
"weight_scale" in name
|
||||
and loaded_weight.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
assert False, (
|
||||
f"E8M0 weight_scale encountered for NVFP4 experts "
|
||||
f"({name}) — this is only valid for MXFP4. "
|
||||
f"Check checkpoint dtype."
|
||||
)
|
||||
for mapping in expert_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
@@ -1681,7 +1576,6 @@ class DeepseekV4Model(nn.Module):
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
@@ -1851,7 +1745,10 @@ class DeepseekV4Model(nn.Module):
|
||||
fp8_from_bf16 = 0
|
||||
bf16_converted = 0
|
||||
compressor_converted = 0
|
||||
diag_printed = False
|
||||
|
||||
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
|
||||
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
attn = layer.attn
|
||||
|
||||
@@ -1864,20 +1761,11 @@ class DeepseekV4Model(nn.Module):
|
||||
continue
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
# NVFP4 -> dequant to bf16 -> requant to FP8
|
||||
if not diag_printed and layer_idx == 0:
|
||||
ws = getattr(mod, 'weight_scale', None)
|
||||
ws2 = getattr(mod, 'weight_scale_2', None)
|
||||
print(f"[DIAG-wo_a:0] dtype={mod.weight.dtype} shape={mod.weight.shape} "
|
||||
f"ws_dtype={ws.dtype if ws is not None else None} ws_shape={ws.shape if ws is not None else None} "
|
||||
f"ws2_val={ws2.data.item() if ws2 is not None else None}")
|
||||
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
|
||||
fp8_converted += 1
|
||||
elif mod.weight.dtype == torch.bfloat16:
|
||||
if not diag_printed and layer_idx == 0:
|
||||
print(f"[DIAG-wo_a:0] dtype=bf16 shape={mod.weight.shape} (direct bf16→fp8)")
|
||||
self._convert_bf16_to_fp8(mod, FP8_MAX)
|
||||
fp8_from_bf16 += 1
|
||||
diag_printed = True
|
||||
|
||||
# BF16 conversion: attention layers via .forward()
|
||||
for proj_name in bf16_proj_names:
|
||||
@@ -1886,10 +1774,6 @@ class DeepseekV4Model(nn.Module):
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
if not diag_printed and layer_idx == 0:
|
||||
ws = getattr(mod, 'weight_scale', None)
|
||||
print(f"[DIAG-bf16:0/{proj_name}] dtype={mod.weight.dtype} shape={mod.weight.shape} "
|
||||
f"ws_dtype={ws.dtype if ws is not None else None} ws_shape={ws.shape if ws is not None else None}")
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
|
||||
@@ -1906,14 +1790,14 @@ class DeepseekV4Model(nn.Module):
|
||||
compressor = getattr(mla_attn, "compressor", None)
|
||||
if compressor is not None and hasattr(compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT)
|
||||
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index)
|
||||
# Indexer compressor (C4A layers only)
|
||||
indexer = getattr(mla_attn, "indexer", None)
|
||||
if indexer is not None:
|
||||
idx_compressor = getattr(indexer, "compressor", None)
|
||||
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer")
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
|
||||
|
||||
# Shared experts
|
||||
ffn = layer.ffn
|
||||
@@ -1968,17 +1852,10 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
w_dequant = w_bf16
|
||||
|
||||
# Replace weight with bf16 version
|
||||
# Diagnostics: check for NaN/Inf from bad dequant
|
||||
if w_dequant.isnan().any() or w_dequant.isinf().any():
|
||||
nan_count = w_dequant.isnan().sum().item()
|
||||
inf_count = w_dequant.isinf().sum().item()
|
||||
print(f"[NVFP4-DEQUANT-WARN] {getattr(mod, 'prefix', 'unknown')}: "
|
||||
f"shape={w_dequant.shape}, dtype={w_dequant.dtype}, "
|
||||
f"NaN={nan_count}, Inf={inf_count}, "
|
||||
f"block_scale range=[{block_scale.min().item():.6f}, {block_scale.max().item():.6f}], "
|
||||
f"global_scale={global_scale:.6f}")
|
||||
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
|
||||
del w_uint8, w_bf16
|
||||
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
|
||||
del w_dequant
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
@@ -2061,25 +1938,38 @@ class DeepseekV4Model(nn.Module):
|
||||
bmm_batch_size=bmm_batch_size,
|
||||
)
|
||||
|
||||
# Free source tensors eagerly
|
||||
del w_uint8, w_bf16, w_dequant
|
||||
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
|
||||
del w_fp8
|
||||
# weight_scale_inv is what the attention runtime reads as b_scale
|
||||
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
|
||||
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
|
||||
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
|
||||
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
|
||||
# weight_scale is not used at runtime for BMM layers; remove it
|
||||
# to avoid confusing other code paths.
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
|
||||
if hasattr(mod, attr):
|
||||
delattr(mod, attr)
|
||||
del ws
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
|
||||
if hasattr(mod, attr):
|
||||
delattr(mod, attr)
|
||||
|
||||
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
|
||||
@staticmethod
|
||||
def _build_shard_index(ckpt_dir: str) -> dict[str, str]:
|
||||
"""Build key→shard_path index from safetensors metadata (no tensor I/O)."""
|
||||
import glob
|
||||
from safetensors import safe_open
|
||||
index = {}
|
||||
for shard_file in sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))):
|
||||
try:
|
||||
with safe_open(shard_file, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
index[key] = shard_file
|
||||
except Exception:
|
||||
continue
|
||||
return index
|
||||
|
||||
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None):
|
||||
"""Reconstruct compressor fused_wkv_wgate from checkpoint.
|
||||
|
||||
Compressor weights are SKIPPED during loading because NVFP4 uint8 data
|
||||
@@ -2087,8 +1977,7 @@ class DeepseekV4Model(nn.Module):
|
||||
We read the original uint8 data from the safetensors checkpoint, unpack
|
||||
E2M1, dequantize, and stack into the fused weight param.
|
||||
"""
|
||||
import glob
|
||||
from safetensors.torch import load_file
|
||||
from safetensors import safe_open
|
||||
|
||||
# Find the checkpoint directory
|
||||
# The model weights are mounted at /model in Docker
|
||||
@@ -2103,49 +1992,45 @@ class DeepseekV4Model(nn.Module):
|
||||
# We read from checkpoint (before mapper), so use original names
|
||||
layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}"
|
||||
|
||||
# Find which shard contains this layer's compressor weights
|
||||
wkv_key = f"{layer_prefix}.kv_proj.weight"
|
||||
wgate_key = f"{layer_prefix}.gate_proj.weight"
|
||||
wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale"
|
||||
wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale"
|
||||
wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2"
|
||||
wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2"
|
||||
wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale"
|
||||
wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale"
|
||||
# All keys we need from the checkpoint
|
||||
keys = {
|
||||
'wkv_uint8': f"{layer_prefix}.kv_proj.weight",
|
||||
'wgate_uint8': f"{layer_prefix}.gate_proj.weight",
|
||||
'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale",
|
||||
'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale",
|
||||
'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2",
|
||||
'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2",
|
||||
'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale",
|
||||
'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale",
|
||||
}
|
||||
|
||||
# Load from safetensors
|
||||
wkv_uint8 = None
|
||||
wgate_uint8 = None
|
||||
wkv_block_scale = None
|
||||
wgate_block_scale = None
|
||||
wkv_global_scale = None
|
||||
wgate_global_scale = None
|
||||
wkv_input_scale = None
|
||||
wgate_input_scale = None
|
||||
|
||||
shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors")))
|
||||
for shard_file in shard_files:
|
||||
# Read tensors using shard index for targeted access (no full-shard loads)
|
||||
tensors = {}
|
||||
for name, key in keys.items():
|
||||
shard_path = (_shard_index or {}).get(key)
|
||||
if shard_path is None:
|
||||
continue
|
||||
try:
|
||||
shard_data = load_file(shard_file)
|
||||
with safe_open(shard_path, framework="pt") as f:
|
||||
if key in f.keys():
|
||||
tensors[name] = f.get_tensor(key)
|
||||
except Exception:
|
||||
continue
|
||||
if wkv_key in shard_data:
|
||||
wkv_uint8 = shard_data[wkv_key]
|
||||
wkv_block_scale = shard_data.get(wkv_scale_key)
|
||||
wkv_global_scale = shard_data.get(wkv_scale2_key)
|
||||
wkv_input_scale = shard_data.get(wkv_iscale_key)
|
||||
if wgate_key in shard_data:
|
||||
wgate_uint8 = shard_data[wgate_key]
|
||||
wgate_block_scale = shard_data.get(wgate_scale_key)
|
||||
wgate_global_scale = shard_data.get(wgate_scale2_key)
|
||||
wgate_input_scale = shard_data.get(wgate_iscale_key)
|
||||
if wkv_uint8 is not None and wgate_uint8 is not None:
|
||||
break
|
||||
|
||||
wkv_uint8 = tensors.get('wkv_uint8')
|
||||
wgate_uint8 = tensors.get('wgate_uint8')
|
||||
|
||||
if wkv_uint8 is None or wgate_uint8 is None:
|
||||
# Layer might not have a compressor (compress_ratio=1 layers)
|
||||
return 0
|
||||
|
||||
wkv_block_scale = tensors.get('wkv_block_scale')
|
||||
wgate_block_scale = tensors.get('wgate_block_scale')
|
||||
wkv_global_scale = tensors.get('wkv_global_scale')
|
||||
wgate_global_scale = tensors.get('wgate_global_scale')
|
||||
wkv_input_scale = tensors.get('wkv_input_scale')
|
||||
wgate_input_scale = tensors.get('wgate_input_scale')
|
||||
|
||||
device = fused_mod.weight.device
|
||||
wkv_uint8 = wkv_uint8.to(device)
|
||||
wgate_uint8 = wgate_uint8.to(device)
|
||||
@@ -2469,12 +2354,10 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
|
||||
loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
self.model.finalize_mega_moe_weights()
|
||||
# Sync after mega_moe weight transform to catch CUDA errors early
|
||||
torch.cuda.synchronize()
|
||||
print("[NVFP4] mega_moe finalize_weights done, CUDA OK")
|
||||
self.model._convert_nvfp4_post_load()
|
||||
torch.cuda.synchronize()
|
||||
print("[NVFP4] post-load conversion done, CUDA OK")
|
||||
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
|
||||
torch.cuda.synchronize()
|
||||
print("[NVFP4] post-load conversion done, CUDA OK")
|
||||
return loaded_params
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
Reference in New Issue
Block a user