fix: tqdm over MoE layer warmup, compile every layer, no print spam

The outer-loop tqdm progress bar now covers the full finalize_weights + warmup
for each MoE layer. CuTeDSL caches compiled kernels by (M, N, K) shape, so every
layer shape gets compiled during warmup — no RPC timeouts during inference.

  (JIT compile)NVFP4 MoE layers:  50%|██████████░░░░░░░░░░| 31/61
This commit is contained in:
2026-05-16 05:21:11 +00:00
parent 3838561c19
commit 4d4cfa6b28

View File

@@ -221,7 +221,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
It handles NVFP4 natively with full Blackwell pipeline overlap (TMA → MMA → Epilogue).
This replaces the broken C++ CUTLASS kernel (see README.md for the full story).
"""
_cutedsl_compiled: bool = False
_cutedsl_runner: 'CuTeDSLMoERunner | None' = None
# NVFP4 E2M1 lookup table (positive values, sign from bit 3)
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
@@ -330,7 +330,6 @@ class DeepseekV4MegaMoEExperts(nn.Module):
set_weight_attrs(self.w2_input_scale, weight_attrs)
self._cutedsl_runner = None
self._cutedsl_compiled = False
# Register in the static forward context so the custom-op wrapper
# can look up this module by name from within a torch.compile graph.
@@ -506,18 +505,15 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.w2_input_scale = None
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
# CuTeDSL caches compiled kernels by (M, N, K) shape, so different
# layer shapes may trigger additional compiles. We only print the
# compile message once to avoid spam (61 layers × 8 ranks).
# CuTeDSL caches by (M, N, K) shape, so different shapes trigger new
# compiles. Running warmup on every layer ensures all kernels are
# compiled before inference, preventing vLLM RPC timeouts.
# The MMA tiler needs >= 128 tokens; using 128 for the warmup.
if not DeepseekV4MegaMoEExperts._cutedsl_compiled:
DeepseekV4MegaMoEExperts._cutedsl_compiled = True
print(" Compiling CuTeDSL NVFP4 MegaMoE kernels (one-time JIT, ~1-2 min)...", flush=True)
device = self._cutedsl_runner.l1_fp4[0].device
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
try:
device = self._cutedsl_runner.l1_fp4[0].device
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
except Exception as exc:
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
@@ -1607,7 +1603,7 @@ class DeepseekV4Model(nn.Module):
def finalize_mega_moe_weights(self) -> None:
from tqdm import tqdm
layers = list(islice(self.layers, self.start_layer, self.end_layer))
for layer in tqdm(layers, desc=" (view-cast)uint8→NVFP4 experts", unit="layer"):
for layer in tqdm(layers, desc=" (JIT compile)NVFP4 MoE layers", unit="layer"):
layer.ffn.finalize_mega_moe_weights()
def _convert_nvfp4_post_load(self):