fix: tqdm over MoE layer warmup, compile every layer, no print spam
The outer tqdm loop now covers the full finalize_weights + warmup for each MoE layer. CuTeDSL caches compiled kernels by (M, N, K), so every layer shape gets compiled during warmup — no RPC timeouts during inference. Example progress output: (JIT compile)NVFP4 MoE layers: 50%|██████████░░░░░░░░░░| 31/61
This commit is contained in:
@@ -221,7 +221,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
It handles NVFP4 natively with full Blackwell pipeline overlap (TMA → MMA → Epilogue).
|
||||
This replaces the broken C++ CUTLASS kernel (see README.md for the full story).
|
||||
"""
|
||||
_cutedsl_compiled: bool = False
|
||||
_cutedsl_runner: 'CuTeDSLMoERunner | None' = None
|
||||
|
||||
# NVFP4 E2M1 lookup table (positive values, sign from bit 3)
|
||||
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
@@ -330,7 +330,6 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
set_weight_attrs(self.w2_input_scale, weight_attrs)
|
||||
|
||||
self._cutedsl_runner = None
|
||||
self._cutedsl_compiled = False
|
||||
|
||||
# Register in the static forward context so the custom-op wrapper
|
||||
# can look up this module by name from within a torch.compile graph.
|
||||
@@ -506,18 +505,15 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.w2_input_scale = None
|
||||
|
||||
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
|
||||
# CuTeDSL caches compiled kernels by (M, N, K) shape, so different
|
||||
# layer shapes may trigger additional compiles. We only print the
|
||||
# compile message once to avoid spam (61 layers × 8 ranks).
|
||||
# CuTeDSL caches by (M, N, K) shape — different shapes trigger new
|
||||
# compiles. Running warmup on every layer ensures all kernels are
|
||||
# compiled before inference, preventing vLLM RPC timeouts.
|
||||
# The MMA tiler needs >= 128 tokens; using 128 for the warmup.
|
||||
if not DeepseekV4MegaMoEExperts._cutedsl_compiled:
|
||||
DeepseekV4MegaMoEExperts._cutedsl_compiled = True
|
||||
print(" Compiling CuTeDSL NVFP4 MegaMoE kernels (one-time JIT, ~1-2 min)...", flush=True)
|
||||
device = self._cutedsl_runner.l1_fp4[0].device
|
||||
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
|
||||
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
|
||||
try:
|
||||
device = self._cutedsl_runner.l1_fp4[0].device
|
||||
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
|
||||
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
|
||||
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
|
||||
except Exception as exc:
|
||||
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
|
||||
@@ -1607,7 +1603,7 @@ class DeepseekV4Model(nn.Module):
|
||||
def finalize_mega_moe_weights(self) -> None:
|
||||
from tqdm import tqdm
|
||||
layers = list(islice(self.layers, self.start_layer, self.end_layer))
|
||||
for layer in tqdm(layers, desc=" (view-cast)uint8→NVFP4 experts", unit="layer"):
|
||||
for layer in tqdm(layers, desc=" (JIT compile)NVFP4 MoE layers", unit="layer"):
|
||||
layer.ffn.finalize_mega_moe_weights()
|
||||
|
||||
def _convert_nvfp4_post_load(self):
|
||||
|
||||
Reference in New Issue
Block a user