clean up and possible big fix
This commit is contained in:
@@ -525,130 +525,28 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
fast_math: bool,
|
||||
) -> None:
|
||||
import os
|
||||
import nvfp4_megamoe_kernel as deep_gemm
|
||||
|
||||
# One-time state dump: model params vs checkpoint keys for layer 0 attn
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')) and not getattr(self, '_state_dump_done', False):
|
||||
self._state_dump_done = True
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
import gc
|
||||
models = [o for o in gc.get_objects()
|
||||
if type(o).__name__ in ('DeepseekV4Model', 'DeepseekV4ForCausalLM')]
|
||||
if models:
|
||||
m = models[0]
|
||||
sd = dict(m.named_parameters())
|
||||
l0 = sorted(k for k in sd if 'layers.0.' in k and ('attn' in k or 'self_attn' in k))
|
||||
print("=== MODEL params (layer 0 attn) ===")
|
||||
for k in l0:
|
||||
t = sd[k]
|
||||
nz = (t != 0).any().item()
|
||||
print(f" {k}\n shape={tuple(t.shape)} dtype={t.dtype} any_nonzero={nz}")
|
||||
|
||||
from safetensors import safe_open
|
||||
import glob, json
|
||||
model_dir = '/model'
|
||||
idx_path = os.path.join(model_dir, 'model.safetensors.index.json')
|
||||
if os.path.exists(idx_path):
|
||||
with open(idx_path) as f:
|
||||
idx = json.load(f)
|
||||
l0_ckpt = sorted(k for k in idx['weight_map']
|
||||
if 'layers.0.' in k and ('attn' in k or 'self_attn' in k))
|
||||
print(f"\n=== CHECKPOINT keys (layer 0 attn) ===")
|
||||
shard_of = idx['weight_map']
|
||||
shards_needed = sorted(set(shard_of[k] for k in l0_ckpt))
|
||||
shapes = {}
|
||||
for shard_name in shards_needed:
|
||||
shard_path = os.path.join(model_dir, shard_name)
|
||||
with safe_open(shard_path, framework='pt') as h:
|
||||
for k in l0_ckpt:
|
||||
if shard_of[k] == shard_name:
|
||||
shapes[k] = (tuple(h.get_tensor(k).shape), h.get_tensor(k).dtype)
|
||||
for k in l0_ckpt:
|
||||
s, d = shapes.get(k, ('?', '?'))
|
||||
print(f" {k}\n shape={s} dtype={d}")
|
||||
else:
|
||||
print(f"[state-dump] No index.json at {idx_path}")
|
||||
from nvfp4_megamoe_kernel import stage_activation, nvfp4_mega_moe_full
|
||||
|
||||
symm_buffer = self.get_symm_buffer()
|
||||
symm_buffer.experts_start_idx = self.experts_start_idx
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# NaN-trace: check hidden_states BEFORE staging
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
|
||||
hs_f32 = hidden_states.to(torch.float32)
|
||||
nan_frac = torch.isnan(hs_f32).float().mean().item()
|
||||
nan_per_tok = torch.isnan(hs_f32).any(dim=-1)
|
||||
bad_tok_ids = nan_per_tok.nonzero(as_tuple=True)[0][:10].tolist()
|
||||
finite_mask = ~torch.isnan(hs_f32) & ~torch.isinf(hs_f32)
|
||||
finite_max = hs_f32[finite_mask].abs().max().item() if finite_mask.any() else float('nan')
|
||||
print(f"[PRE-stage] hidden_states: "
|
||||
f"nan_frac={nan_frac:.4f} "
|
||||
f"inf_any={torch.isinf(hs_f32).any().item()} "
|
||||
f"finite_max={finite_max:.4e} "
|
||||
f"first_bad_token_idxs={bad_tok_ids} "
|
||||
f"experts_start_idx={self.experts_start_idx}")
|
||||
|
||||
# Quantize activation using the kernel's PyTorch stage_activation
|
||||
# (same code path the kernel uses for L1→L2 requantization).
|
||||
# This replaces the broken Triton staging kernel — no more uint32
|
||||
# pack/unpack, no more Triton tensor indexing issues.
|
||||
from nvfp4_megamoe_kernel import stage_activation
|
||||
x_fp4, x_sf = stage_activation(hidden_states)
|
||||
symm_buffer.x[:num_tokens].copy_(x_fp4)
|
||||
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
|
||||
symm_buffer.topk_idx[:num_tokens].copy_(topk_ids)
|
||||
symm_buffer.topk_weights[:num_tokens].copy_(topk_weights)
|
||||
|
||||
# Debug: check staging output
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
|
||||
print(f"[MEGA_MOE_DEBUG] After staging: x dtype={symm_buffer.x.dtype} shape={symm_buffer.x.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] x_sf dtype={symm_buffer.x_sf.dtype} shape={symm_buffer.x_sf.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] topk_idx dtype={symm_buffer.topk_idx.dtype} shape={symm_buffer.topk_idx.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] topk_weights dtype={symm_buffer.topk_weights.dtype} shape={symm_buffer.topk_weights.shape}")
|
||||
# Check for NaN/Inf in the staging output
|
||||
x_sample = symm_buffer.x[:num_tokens]
|
||||
sf_sample = symm_buffer.x_sf[:num_tokens]
|
||||
print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}")
|
||||
if sf_sample.numel() > 0:
|
||||
print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item()}")
|
||||
topk_sample = symm_buffer.topk_idx[:num_tokens]
|
||||
print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}")
|
||||
torch.cuda.synchronize()
|
||||
print("[MEGA_MOE_DEBUG] Staging CUDA sync OK")
|
||||
|
||||
# This method must have been already called during the weight loading phase.
|
||||
# We call it again here to cover the dummy weight loading case.
|
||||
self.finalize_weights()
|
||||
|
||||
assert self._transformed_l1_weights is not None
|
||||
assert self._transformed_l2_weights is not None
|
||||
from nvfp4_megamoe_kernel import nvfp4_mega_moe_full as fp8_nvfp4_mega_moe
|
||||
|
||||
# Debug: dump shapes before mega_moe
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
|
||||
l1_w, l1_sf = self._transformed_l1_weights
|
||||
l2_w, l2_sf = self._transformed_l2_weights
|
||||
print(f"[MEGA_MOE_DEBUG] num_tokens={num_tokens}, hidden={hidden_states.shape[1]}")
|
||||
print(f"[MEGA_MOE_DEBUG] l1_w: dtype={l1_w.dtype} shape={l1_w.shape} stride={l1_w.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l1_sf: dtype={l1_sf.dtype} shape={l1_sf.shape} stride={l1_sf.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l2_w: dtype={l2_w.dtype} shape={l2_w.shape} stride={l2_w.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l2_sf: dtype={l2_sf.dtype} shape={l2_sf.shape} stride={l2_sf.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] symm_buffer nbytes={symm_buffer.buffer.nbytes} rank={symm_buffer.group.rank()}")
|
||||
print(f"[MEGA_MOE_DEBUG] num_experts={self.num_experts} topk={topk_ids.shape[1]} max_tokens={self.max_num_tokens}")
|
||||
print(f"[MEGA_MOE_DEBUG] y: dtype={y.dtype} shape={y.shape}")
|
||||
# Force CUDA sync to catch any prior async errors
|
||||
torch.cuda.synchronize()
|
||||
print("[MEGA_MOE_DEBUG] CUDA sync OK — prior ops clean")
|
||||
|
||||
# MEGA_MOE_STATIC: skip the kernel entirely, return zeros
|
||||
# Tests whether the crash is in the kernel launch or in prior data prep
|
||||
if int(os.environ.get('MEGA_MOE_STATIC', '0')):
|
||||
print(f"[MEGA_MOE_STATIC] Skipping fp8_nvfp4_mega_moe, returning zeros")
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
fp8_nvfp4_mega_moe(
|
||||
nvfp4_mega_moe_full(
|
||||
y,
|
||||
self._transformed_l1_weights,
|
||||
self._transformed_l2_weights,
|
||||
@@ -1358,6 +1256,33 @@ class DeepseekV4Model(nn.Module):
|
||||
("compressor.fused_wkv_wgate", "compressor.wkv", 0),
|
||||
("compressor.fused_wkv_wgate", "compressor.wgate", 1),
|
||||
]
|
||||
|
||||
# Checkpoint key → model param name substitutions.
|
||||
# Applied to each (name, weight) pair before matching against
|
||||
# params_dict. Order matters: longer/more-specific patterns first.
|
||||
CKPT_KEY_SUBST = {
|
||||
# self_attn projection names → vLLM attn attribute names
|
||||
".self_attn.q_a_proj.": ".attn.wq_a.",
|
||||
".self_attn.q_b_proj.": ".attn.wq_b.",
|
||||
".self_attn.q_a_norm.": ".attn.q_norm.",
|
||||
".self_attn.o_a_proj.": ".attn.wo_a.",
|
||||
".self_attn.o_b_proj.": ".attn.wo_b.",
|
||||
".self_attn.sinks": ".attn.attn_sink",
|
||||
".self_attn.kv_proj.": ".attn.wkv.",
|
||||
".self_attn.kv_norm.": ".attn.kv_norm.",
|
||||
# Compressor: self_attn.compressor → attn.mla_attn.compressor
|
||||
".self_attn.compressor.kv_norm.": ".attn.kv_norm.",
|
||||
".self_attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
# Compressor projections for stacking (fused_wkv_wgate)
|
||||
".compressor.kv_proj.": ".compressor.wkv.",
|
||||
".compressor.gate_proj.": ".compressor.gate.",
|
||||
# Shared expert projections (stacking into gate_up_proj)
|
||||
".shared_experts.gate_proj.": ".shared_experts.w1.",
|
||||
".shared_experts.up_proj.": ".shared_experts.w3.",
|
||||
# modelopt uses mlp, vllm uses ffn internally
|
||||
".mlp.": ".ffn.",
|
||||
}
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
@@ -1372,29 +1297,18 @@ class DeepseekV4Model(nn.Module):
|
||||
# Pre-compute expert mapping ONCE.
|
||||
expert_mapping = self.get_expert_mapping()
|
||||
|
||||
# Debug: dump incoming checkpoint names for layer 0 attention o_* keys
|
||||
if os.environ.get('MEGA_MOE_DEBUG'):
|
||||
_o_sample = []
|
||||
_o_seen = set()
|
||||
for _n, _ in weights:
|
||||
if 'layers.0.self_attn' in _n and 'o_' in _n:
|
||||
if _n not in _o_seen:
|
||||
_o_seen.add(_n)
|
||||
_o_sample.append(_n)
|
||||
if len(_o_sample) > 20:
|
||||
break
|
||||
if _o_sample:
|
||||
print(f"[LOAD-RAW] incoming layer 0 o_* checkpoint names:")
|
||||
for _n in _o_sample:
|
||||
print(f" {_n}")
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# Debug: trace o_b_proj/wo_b/o_a_proj/wo_a through the loader
|
||||
if os.environ.get('MEGA_MOE_DEBUG') and ('o_b_proj' in name or 'wo_b' in name or 'o_a_proj' in name or 'wo_a' in name):
|
||||
print(f"[LOAD-TRACE] candidate name={name!r} "
|
||||
f"in_params_dict={name in params_dict} "
|
||||
f"loaded_dtype={loaded_weight.dtype} "
|
||||
f"loaded_shape={tuple(loaded_weight.shape)}")
|
||||
# Strip 'model.' prefix from checkpoint keys.
|
||||
# vLLM's weight iteration yields keys like 'model.layers.0...'
|
||||
# but named_parameters() on DeepseekV4Model returns 'layers.0...'
|
||||
if name.startswith("model."):
|
||||
name = name[len("model."):]
|
||||
|
||||
# Apply checkpoint → model name substitutions
|
||||
for ckpt_pat, model_pat in CKPT_KEY_SUBST.items():
|
||||
if ckpt_pat in name:
|
||||
name = name.replace(ckpt_pat, model_pat)
|
||||
break # first match wins (order matters)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
@@ -2191,18 +2105,16 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
class DeepseekV4ForCausalLM(nn.Module):
|
||||
model_cls = DeepseekV4Model
|
||||
|
||||
# Default mapper assumes the original FP4-expert checkpoint layout.
|
||||
# Overridden per-instance in __init__ when expert_dtype != "fp4".
|
||||
hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
|
||||
# NOTE: We do NOT set hf_to_vllm_mapper here because our custom
|
||||
# load_weights handles all checkpoint→model name remapping inline.
|
||||
# If hf_to_vllm_mapper is set, vLLM's AutoWeightsLoader may be invoked
|
||||
# INSTEAD of our load_weights, silently dropping NVFP4 weight loading.
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
expert_dtype = getattr(config, "expert_dtype", "fp4")
|
||||
if expert_dtype != "fp4":
|
||||
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)
|
||||
|
||||
self.model = self.model_cls(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
@@ -2268,48 +2180,6 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
|
||||
torch.cuda.synchronize()
|
||||
print("[NVFP4] post-load conversion done, CUDA OK")
|
||||
|
||||
# Post-load NaN scale scan — find any scale tensors that are NaN
|
||||
# after weight loading + post-load conversion
|
||||
nan_attrs = []
|
||||
for name, module in self.named_modules():
|
||||
for attr in ('weight_scale', 'weight_scale_inv', 'weight_scale_2',
|
||||
'input_scale', 'act_scale'):
|
||||
if hasattr(module, attr):
|
||||
t = getattr(module, attr)
|
||||
if torch.is_tensor(t) and torch.isnan(t.to(torch.float32)).any().item():
|
||||
nan_attrs.append((name, attr, tuple(t.shape), str(t.dtype)))
|
||||
if nan_attrs:
|
||||
print(f"[POST-LOAD] {len(nan_attrs)} NaN scale tensors after loading:")
|
||||
for n, a, s, d in nan_attrs[:20]:
|
||||
print(f" {n}.{a} shape={s} dtype={d}")
|
||||
else:
|
||||
print("[POST-LOAD] No NaN scale tensors found — scales are clean")
|
||||
|
||||
# Dump layer 0 attn keys: model state_dict vs checkpoint
|
||||
if int(os.environ.get('NVFP4_DEBUG', '0')):
|
||||
sd = self.state_dict()
|
||||
layer0_keys = sorted(k for k in sd if 'layers.0.attn' in k or 'layers.0.self_attn' in k)
|
||||
print("=== MODEL state_dict (layer 0 attn): ===")
|
||||
for k in layer0_keys:
|
||||
t = sd[k]
|
||||
nz = (t != 0).any().item() if torch.is_tensor(t) else '?'
|
||||
print(f" {k} shape={tuple(t.shape)} dtype={t.dtype} any_nonzero={nz}")
|
||||
|
||||
from safetensors import safe_open
|
||||
import glob
|
||||
ckpt_files = sorted(glob.glob(os.path.join('/model', '*.safetensors')))
|
||||
print(f"\n=== CHECKPOINT files: {len(ckpt_files)} shards ===")
|
||||
seen = set()
|
||||
for f in ckpt_files[:5]:
|
||||
with safe_open(f, framework='pt') as h:
|
||||
for k in h.keys():
|
||||
if ('layers.0.' in k) and ('attn' in k or 'self_attn' in k):
|
||||
if k not in seen:
|
||||
seen.add(k)
|
||||
t = h.get_tensor(k)
|
||||
print(f" {k} shape={tuple(t.shape)} dtype={t.dtype}")
|
||||
|
||||
return loaded_params
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
@@ -285,19 +285,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
import os
|
||||
layer_idx = getattr(self, 'layer_idx', '?')
|
||||
_debug = int(os.environ.get('MEGA_MOE_DEBUG', '0'))
|
||||
|
||||
# NaN-trace: check attention inputs
|
||||
if _debug:
|
||||
hs_f32 = hidden_states.to(torch.float32)
|
||||
nf = torch.isnan(hs_f32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} attn-in/hidden_states] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(hs_f32).any().item()} "
|
||||
f"shape={tuple(hidden_states.shape)} dtype={hidden_states.dtype}")
|
||||
|
||||
# Pre-allocate attention output with FlashMLA-padded head count.
|
||||
# The op writes into `o_padded`; we slice to n_local_heads after.
|
||||
num_tokens = hidden_states.shape[0]
|
||||
@@ -316,15 +303,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# NaN-trace: check attention output
|
||||
if _debug:
|
||||
o_f32 = o.to(torch.float32)
|
||||
nf = torch.isnan(o_f32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} attn-out/o_sliced] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(o_f32).any().item()} "
|
||||
f"shape={tuple(o.shape)} dtype={o.dtype}")
|
||||
|
||||
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
@@ -337,19 +315,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
tma_aligned_scales=self._tma_aligned_scales,
|
||||
)
|
||||
|
||||
# NaN-trace: check rope+quant output
|
||||
if _debug:
|
||||
of32 = o_fp8.to(torch.float32)
|
||||
nf = torch.isnan(of32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} rope-quant/o_fp8] nan_frac={nf:.4f} "
|
||||
f"shape={tuple(o_fp8.shape)} dtype={o_fp8.dtype}")
|
||||
sf32 = o_scale.to(torch.float32)
|
||||
nf2 = torch.isnan(sf32).float().mean().item()
|
||||
if nf2 > 0:
|
||||
print(f"[NAN @ L{layer_idx} rope-quant/o_scale] nan_frac={nf2:.4f} "
|
||||
f"shape={tuple(o_scale.shape)} dtype={o_scale.dtype}")
|
||||
|
||||
wo_a_fp8 = self.wo_a.weight
|
||||
wo_a_scale = self.wo_a.weight_scale_inv
|
||||
|
||||
@@ -368,55 +333,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
list(self._einsum_recipe),
|
||||
)
|
||||
|
||||
# NaN-trace: check wo_a einsum output
|
||||
if _debug:
|
||||
zf32 = z.to(torch.float32)
|
||||
nf = torch.isnan(zf32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} wo_a-einsum/z] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(zf32).any().item()} "
|
||||
f"shape={tuple(z.shape)} dtype={z.dtype}")
|
||||
|
||||
# wo_b inspection — dump all tensor attributes once
|
||||
if _debug and not hasattr(self, '_wo_b_inspected'):
|
||||
self._wo_b_inspected = True
|
||||
layer_idx = getattr(self, 'layer_idx', None) or getattr(self, 'layer_name', '?')
|
||||
print(f"[wo_b-inspect L{layer_idx}] type={type(self.wo_b).__name__}")
|
||||
print(f"[wo_b-inspect L{layer_idx}] z (input) nan_frac="
|
||||
f"{torch.isnan(z.to(torch.float32)).float().mean().item():.4f} "
|
||||
f"abs_max={z.to(torch.float32).abs().max().item():.4e}")
|
||||
for attr in dir(self.wo_b):
|
||||
if attr.startswith('_'):
|
||||
continue
|
||||
try:
|
||||
v = getattr(self.wo_b, attr)
|
||||
except Exception:
|
||||
continue
|
||||
if torch.is_tensor(v):
|
||||
vf = v.to(torch.float32) if v.dtype not in (torch.float32,) else v
|
||||
nf = torch.isnan(vf).float().mean().item()
|
||||
inf = torch.isinf(vf).any().item()
|
||||
try:
|
||||
vmin = vf.min().item()
|
||||
vmax = vf.max().item()
|
||||
except Exception:
|
||||
vmin = vmax = float('nan')
|
||||
print(f"[wo_b-inspect L{layer_idx}] {attr}: "
|
||||
f"dtype={v.dtype} shape={tuple(v.shape)} "
|
||||
f"nan_frac={nf:.4f} inf={inf} min={vmin:.4e} max={vmax:.4e}")
|
||||
|
||||
result = self.wo_b(z.flatten(1))
|
||||
|
||||
# NaN-trace: check final wo_b output
|
||||
if _debug:
|
||||
rf32 = result.to(torch.float32)
|
||||
nf = torch.isnan(rf32).float().mean().item()
|
||||
if nf > 0:
|
||||
print(f"[NAN @ L{layer_idx} wo_b/result] nan_frac={nf:.4f} "
|
||||
f"inf_any={torch.isinf(rf32).any().item()} "
|
||||
f"shape={tuple(result.shape)} dtype={result.dtype}")
|
||||
|
||||
return result
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
assert self.aux_stream_list is not None
|
||||
|
||||
Reference in New Issue
Block a user