clean up and possible big fix

This commit is contained in:
2026-05-14 23:41:10 +00:00
parent 9f01307c5b
commit 756ea2192f
2 changed files with 45 additions and 258 deletions

View File

@@ -525,130 +525,28 @@ class DeepseekV4MegaMoEExperts(nn.Module):
fast_math: bool,
) -> None:
import os
import nvfp4_megamoe_kernel as deep_gemm
# One-time state dump: model params vs checkpoint keys for layer 0 attn
if int(os.environ.get('MEGA_MOE_DEBUG', '0')) and not getattr(self, '_state_dump_done', False):
self._state_dump_done = True
from vllm.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() == 0:
import gc
models = [o for o in gc.get_objects()
if type(o).__name__ in ('DeepseekV4Model', 'DeepseekV4ForCausalLM')]
if models:
m = models[0]
sd = dict(m.named_parameters())
l0 = sorted(k for k in sd if 'layers.0.' in k and ('attn' in k or 'self_attn' in k))
print("=== MODEL params (layer 0 attn) ===")
for k in l0:
t = sd[k]
nz = (t != 0).any().item()
print(f" {k}\n shape={tuple(t.shape)} dtype={t.dtype} any_nonzero={nz}")
from safetensors import safe_open
import glob, json
model_dir = '/model'
idx_path = os.path.join(model_dir, 'model.safetensors.index.json')
if os.path.exists(idx_path):
with open(idx_path) as f:
idx = json.load(f)
l0_ckpt = sorted(k for k in idx['weight_map']
if 'layers.0.' in k and ('attn' in k or 'self_attn' in k))
print(f"\n=== CHECKPOINT keys (layer 0 attn) ===")
shard_of = idx['weight_map']
shards_needed = sorted(set(shard_of[k] for k in l0_ckpt))
shapes = {}
for shard_name in shards_needed:
shard_path = os.path.join(model_dir, shard_name)
with safe_open(shard_path, framework='pt') as h:
for k in l0_ckpt:
if shard_of[k] == shard_name:
shapes[k] = (tuple(h.get_tensor(k).shape), h.get_tensor(k).dtype)
for k in l0_ckpt:
s, d = shapes.get(k, ('?', '?'))
print(f" {k}\n shape={s} dtype={d}")
else:
print(f"[state-dump] No index.json at {idx_path}")
from nvfp4_megamoe_kernel import stage_activation, nvfp4_mega_moe_full
symm_buffer = self.get_symm_buffer()
symm_buffer.experts_start_idx = self.experts_start_idx
num_tokens = hidden_states.shape[0]
# NaN-trace: check hidden_states BEFORE staging
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
hs_f32 = hidden_states.to(torch.float32)
nan_frac = torch.isnan(hs_f32).float().mean().item()
nan_per_tok = torch.isnan(hs_f32).any(dim=-1)
bad_tok_ids = nan_per_tok.nonzero(as_tuple=True)[0][:10].tolist()
finite_mask = ~torch.isnan(hs_f32) & ~torch.isinf(hs_f32)
finite_max = hs_f32[finite_mask].abs().max().item() if finite_mask.any() else float('nan')
print(f"[PRE-stage] hidden_states: "
f"nan_frac={nan_frac:.4f} "
f"inf_any={torch.isinf(hs_f32).any().item()} "
f"finite_max={finite_max:.4e} "
f"first_bad_token_idxs={bad_tok_ids} "
f"experts_start_idx={self.experts_start_idx}")
# Quantize activation using the kernel's PyTorch stage_activation
# (same code path the kernel uses for L1→L2 requantization).
# This replaces the broken Triton staging kernel — no more uint32
# pack/unpack, no more Triton tensor indexing issues.
from nvfp4_megamoe_kernel import stage_activation
x_fp4, x_sf = stage_activation(hidden_states)
symm_buffer.x[:num_tokens].copy_(x_fp4)
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
symm_buffer.topk_idx[:num_tokens].copy_(topk_ids)
symm_buffer.topk_weights[:num_tokens].copy_(topk_weights)
# Debug: check staging output
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
print(f"[MEGA_MOE_DEBUG] After staging: x dtype={symm_buffer.x.dtype} shape={symm_buffer.x.shape}")
print(f"[MEGA_MOE_DEBUG] x_sf dtype={symm_buffer.x_sf.dtype} shape={symm_buffer.x_sf.shape}")
print(f"[MEGA_MOE_DEBUG] topk_idx dtype={symm_buffer.topk_idx.dtype} shape={symm_buffer.topk_idx.shape}")
print(f"[MEGA_MOE_DEBUG] topk_weights dtype={symm_buffer.topk_weights.dtype} shape={symm_buffer.topk_weights.shape}")
# Check for NaN/Inf in the staging output
x_sample = symm_buffer.x[:num_tokens]
sf_sample = symm_buffer.x_sf[:num_tokens]
print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}")
if sf_sample.numel() > 0:
print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item()}")
topk_sample = symm_buffer.topk_idx[:num_tokens]
print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}")
torch.cuda.synchronize()
print("[MEGA_MOE_DEBUG] Staging CUDA sync OK")
# This method must have been already called during the weight loading phase.
# We call it again here to cover the dummy weight loading case.
self.finalize_weights()
assert self._transformed_l1_weights is not None
assert self._transformed_l2_weights is not None
from nvfp4_megamoe_kernel import nvfp4_mega_moe_full as fp8_nvfp4_mega_moe
# Debug: dump shapes before mega_moe
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
l1_w, l1_sf = self._transformed_l1_weights
l2_w, l2_sf = self._transformed_l2_weights
print(f"[MEGA_MOE_DEBUG] num_tokens={num_tokens}, hidden={hidden_states.shape[1]}")
print(f"[MEGA_MOE_DEBUG] l1_w: dtype={l1_w.dtype} shape={l1_w.shape} stride={l1_w.stride()}")
print(f"[MEGA_MOE_DEBUG] l1_sf: dtype={l1_sf.dtype} shape={l1_sf.shape} stride={l1_sf.stride()}")
print(f"[MEGA_MOE_DEBUG] l2_w: dtype={l2_w.dtype} shape={l2_w.shape} stride={l2_w.stride()}")
print(f"[MEGA_MOE_DEBUG] l2_sf: dtype={l2_sf.dtype} shape={l2_sf.shape} stride={l2_sf.stride()}")
print(f"[MEGA_MOE_DEBUG] symm_buffer nbytes={symm_buffer.buffer.nbytes} rank={symm_buffer.group.rank()}")
print(f"[MEGA_MOE_DEBUG] num_experts={self.num_experts} topk={topk_ids.shape[1]} max_tokens={self.max_num_tokens}")
print(f"[MEGA_MOE_DEBUG] y: dtype={y.dtype} shape={y.shape}")
# Force CUDA sync to catch any prior async errors
torch.cuda.synchronize()
print("[MEGA_MOE_DEBUG] CUDA sync OK — prior ops clean")
# MEGA_MOE_STATIC: skip the kernel entirely, return zeros
# Tests whether the crash is in the kernel launch or in prior data prep
if int(os.environ.get('MEGA_MOE_STATIC', '0')):
print(f"[MEGA_MOE_STATIC] Skipping fp8_nvfp4_mega_moe, returning zeros")
y.zero_()
return
fp8_nvfp4_mega_moe(
nvfp4_mega_moe_full(
y,
self._transformed_l1_weights,
self._transformed_l2_weights,
@@ -1358,6 +1256,33 @@ class DeepseekV4Model(nn.Module):
("compressor.fused_wkv_wgate", "compressor.wkv", 0),
("compressor.fused_wkv_wgate", "compressor.wgate", 1),
]
# Checkpoint key → model param name substitutions.
# Applied to each (name, weight) pair before matching against
# params_dict. Order matters: longer/more-specific patterns first.
CKPT_KEY_SUBST = {
# self_attn projection names → vLLM attn attribute names
".self_attn.q_a_proj.": ".attn.wq_a.",
".self_attn.q_b_proj.": ".attn.wq_b.",
".self_attn.q_a_norm.": ".attn.q_norm.",
".self_attn.o_a_proj.": ".attn.wo_a.",
".self_attn.o_b_proj.": ".attn.wo_b.",
".self_attn.sinks": ".attn.attn_sink",
".self_attn.kv_proj.": ".attn.wkv.",
".self_attn.kv_norm.": ".attn.kv_norm.",
# Compressor: self_attn.compressor → attn.mla_attn.compressor
".self_attn.compressor.kv_norm.": ".attn.kv_norm.",
".self_attn.compressor.": ".attn.mla_attn.compressor.",
# Compressor projections for stacking (fused_wkv_wgate)
".compressor.kv_proj.": ".compressor.wkv.",
".compressor.gate_proj.": ".compressor.gate.",
# Shared expert projections (stacking into gate_up_proj)
".shared_experts.gate_proj.": ".shared_experts.w1.",
".shared_experts.up_proj.": ".shared_experts.w3.",
# modelopt uses mlp, vllm uses ffn internally
".mlp.": ".ffn.",
}
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -1372,29 +1297,18 @@ class DeepseekV4Model(nn.Module):
# Pre-compute expert mapping ONCE.
expert_mapping = self.get_expert_mapping()
# Debug: dump incoming checkpoint names for layer 0 attention o_* keys
if os.environ.get('MEGA_MOE_DEBUG'):
_o_sample = []
_o_seen = set()
for _n, _ in weights:
if 'layers.0.self_attn' in _n and 'o_' in _n:
if _n not in _o_seen:
_o_seen.add(_n)
_o_sample.append(_n)
if len(_o_sample) > 20:
break
if _o_sample:
print(f"[LOAD-RAW] incoming layer 0 o_* checkpoint names:")
for _n in _o_sample:
print(f" {_n}")
for name, loaded_weight in weights:
# Debug: trace o_b_proj/wo_b/o_a_proj/wo_a through the loader
if os.environ.get('MEGA_MOE_DEBUG') and ('o_b_proj' in name or 'wo_b' in name or 'o_a_proj' in name or 'wo_a' in name):
print(f"[LOAD-TRACE] candidate name={name!r} "
f"in_params_dict={name in params_dict} "
f"loaded_dtype={loaded_weight.dtype} "
f"loaded_shape={tuple(loaded_weight.shape)}")
# Strip 'model.' prefix from checkpoint keys.
# vLLM's weight iteration yields keys like 'model.layers.0...'
# but named_parameters() on DeepseekV4Model returns 'layers.0...'
if name.startswith("model."):
name = name[len("model."):]
# Apply checkpoint → model name substitutions
for ckpt_pat, model_pat in CKPT_KEY_SUBST.items():
if ckpt_pat in name:
name = name.replace(ckpt_pat, model_pat)
break # first match wins (order matters)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -2191,18 +2105,16 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
class DeepseekV4ForCausalLM(nn.Module):
model_cls = DeepseekV4Model
# Default mapper assumes the original FP4-expert checkpoint layout.
# Overridden per-instance in __init__ when expert_dtype != "fp4".
hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
# NOTE: We do NOT set hf_to_vllm_mapper here because our custom
# load_weights handles all checkpoint→model name remapping inline.
# If hf_to_vllm_mapper is set, vLLM's AutoWeightsLoader may be invoked
# INSTEAD of our load_weights, silently dropping NVFP4 weight loading.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
expert_dtype = getattr(config, "expert_dtype", "fp4")
if expert_dtype != "fp4":
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
@@ -2268,48 +2180,6 @@ class DeepseekV4ForCausalLM(nn.Module):
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()
print("[NVFP4] post-load conversion done, CUDA OK")
# Post-load NaN scale scan — find any scale tensors that are NaN
# after weight loading + post-load conversion
nan_attrs = []
for name, module in self.named_modules():
for attr in ('weight_scale', 'weight_scale_inv', 'weight_scale_2',
'input_scale', 'act_scale'):
if hasattr(module, attr):
t = getattr(module, attr)
if torch.is_tensor(t) and torch.isnan(t.to(torch.float32)).any().item():
nan_attrs.append((name, attr, tuple(t.shape), str(t.dtype)))
if nan_attrs:
print(f"[POST-LOAD] {len(nan_attrs)} NaN scale tensors after loading:")
for n, a, s, d in nan_attrs[:20]:
print(f" {n}.{a} shape={s} dtype={d}")
else:
print("[POST-LOAD] No NaN scale tensors found — scales are clean")
# Dump layer 0 attn keys: model state_dict vs checkpoint
if int(os.environ.get('NVFP4_DEBUG', '0')):
sd = self.state_dict()
layer0_keys = sorted(k for k in sd if 'layers.0.attn' in k or 'layers.0.self_attn' in k)
print("=== MODEL state_dict (layer 0 attn): ===")
for k in layer0_keys:
t = sd[k]
nz = (t != 0).any().item() if torch.is_tensor(t) else '?'
print(f" {k} shape={tuple(t.shape)} dtype={t.dtype} any_nonzero={nz}")
from safetensors import safe_open
import glob
ckpt_files = sorted(glob.glob(os.path.join('/model', '*.safetensors')))
print(f"\n=== CHECKPOINT files: {len(ckpt_files)} shards ===")
seen = set()
for f in ckpt_files[:5]:
with safe_open(f, framework='pt') as h:
for k in h.keys():
if ('layers.0.' in k) and ('attn' in k or 'self_attn' in k):
if k not in seen:
seen.add(k)
t = h.get_tensor(k)
print(f" {k} shape={tuple(t.shape)} dtype={t.dtype}")
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

View File

@@ -285,19 +285,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
import os
layer_idx = getattr(self, 'layer_idx', '?')
_debug = int(os.environ.get('MEGA_MOE_DEBUG', '0'))
# NaN-trace: check attention inputs
if _debug:
hs_f32 = hidden_states.to(torch.float32)
nf = torch.isnan(hs_f32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} attn-in/hidden_states] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(hs_f32).any().item()} "
f"shape={tuple(hidden_states.shape)} dtype={hidden_states.dtype}")
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
@@ -316,15 +303,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
o = o_padded[:, : self.n_local_heads, :]
# NaN-trace: check attention output
if _debug:
o_f32 = o.to(torch.float32)
nf = torch.isnan(o_f32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} attn-out/o_sliced] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(o_f32).any().item()} "
f"shape={tuple(o.shape)} dtype={o.dtype}")
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
@@ -337,19 +315,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
tma_aligned_scales=self._tma_aligned_scales,
)
# NaN-trace: check rope+quant output
if _debug:
of32 = o_fp8.to(torch.float32)
nf = torch.isnan(of32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} rope-quant/o_fp8] nan_frac={nf:.4f} "
f"shape={tuple(o_fp8.shape)} dtype={o_fp8.dtype}")
sf32 = o_scale.to(torch.float32)
nf2 = torch.isnan(sf32).float().mean().item()
if nf2 > 0:
print(f"[NAN @ L{layer_idx} rope-quant/o_scale] nan_frac={nf2:.4f} "
f"shape={tuple(o_scale.shape)} dtype={o_scale.dtype}")
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
@@ -368,55 +333,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
list(self._einsum_recipe),
)
# NaN-trace: check wo_a einsum output
if _debug:
zf32 = z.to(torch.float32)
nf = torch.isnan(zf32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} wo_a-einsum/z] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(zf32).any().item()} "
f"shape={tuple(z.shape)} dtype={z.dtype}")
# wo_b inspection — dump all tensor attributes once
if _debug and not hasattr(self, '_wo_b_inspected'):
self._wo_b_inspected = True
layer_idx = getattr(self, 'layer_idx', None) or getattr(self, 'layer_name', '?')
print(f"[wo_b-inspect L{layer_idx}] type={type(self.wo_b).__name__}")
print(f"[wo_b-inspect L{layer_idx}] z (input) nan_frac="
f"{torch.isnan(z.to(torch.float32)).float().mean().item():.4f} "
f"abs_max={z.to(torch.float32).abs().max().item():.4e}")
for attr in dir(self.wo_b):
if attr.startswith('_'):
continue
try:
v = getattr(self.wo_b, attr)
except Exception:
continue
if torch.is_tensor(v):
vf = v.to(torch.float32) if v.dtype not in (torch.float32,) else v
nf = torch.isnan(vf).float().mean().item()
inf = torch.isinf(vf).any().item()
try:
vmin = vf.min().item()
vmax = vf.max().item()
except Exception:
vmin = vmax = float('nan')
print(f"[wo_b-inspect L{layer_idx}] {attr}: "
f"dtype={v.dtype} shape={tuple(v.shape)} "
f"nan_frac={nf:.4f} inf={inf} min={vmin:.4e} max={vmax:.4e}")
result = self.wo_b(z.flatten(1))
# NaN-trace: check final wo_b output
if _debug:
rf32 = result.to(torch.float32)
nf = torch.isnan(rf32).float().mean().item()
if nf > 0:
print(f"[NAN @ L{layer_idx} wo_b/result] nan_frac={nf:.4f} "
f"inf_any={torch.isinf(rf32).any().item()} "
f"shape={tuple(result.shape)} dtype={result.dtype}")
return result
return self.wo_b(z.flatten(1))
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
assert self.aux_stream_list is not None