First call: cute.compile() with real tensors (warmup).
Subsequent calls: just invoke compiled() with new CuTe views.
No cute.compile() in the forward path = cudagraph-safe.
The CuTeDSL kernel's TMA descriptors are bound to the
compilation-time tensor addresses. Caching the compiled kernel
and reusing it with different tensor allocations produces wrong
memory access patterns (cosine 0.5 instead of 0.99).
Fresh compilation is proven correct (cosine 0.989). We can
optimize later with proper TMA descriptor reinitialization.
quantize_to_nvfp4() only packs the last dimension, but for weight
matrices (K, N), K is the packed dimension. The weight quantizer
reshapes (k_blocks, block_size, N) and computes block scales along
the K block dimension. This was accidentally replaced with a simple
delegation to quantize_to_nvfp4, producing wrong tensor shapes.
The kernel's TMA descriptors are sized from compilation-time shapes.
Dummy 256x256 caused wrong memory access for real 3584x6144 data.
Now compiles with actual runtime tensors on first use, cached by
(num_experts, K, N). Compilation happens once during warmup.
Forward call remains cudagraph-safe.
The compiled kernel's TMA descriptors are sized based on compilation
shapes. Using dummy 256x256 shapes caused wrong memory access patterns
for the real 3584x6144 data. Now uses actual K_packed and N_packed
from the runtime tensors.
float4_e2m1fn_x2 packs 2 values per byte along K, not N.
The GEMM output N dimension is the logical N from mat_b.shape[2],
not 2x packed. Previous n_dim*2 was wrong — it accidentally worked
in the test because intermediate_size*2 == 2*intermediate_size.
Real model with N=9216 exposed the bug.
torch.tensor() creates on CPU then copies to CUDA, which is forbidden
during cudagraph capture. new_tensor() creates directly on the
source tensor's device.
- Removed all [:total_slots] dynamic slicing with GPU scalars
- slot_hidden gathers from hidden_states directly using sorted_token_ids
- scatter_add uses full sorted_token_ids (padding slots have zero weight)
- _assemble_scales_cudagraph_safe returns 2D via padded_scales.shape[0]
- Fixed padded_scales_buf allocation via float16->float8 cast
- GEMM output size: n_dim * 2 for float4_e2m1fn_x2 packed format
- assemble_activation_scales_gpu: builds padded+swizzled scale tensor
without .item() or .tolist() CPU syncs. Uses GPU index arange + cat
+ single scatter instead of per-expert Python slicing.
- Still has a for e in range(num_experts) loop but num_experts is
compile-time constant so torch.compile unrolls it.
- Added tests/cudagraph_test.py: attempts CUDA graph capture on the
MoE runner, diagnoses sync violations with patched torch functions.
- Removed the if total_slots == 0 early return (Python control flow
on GPU data)
Key changes for cudagraph compatibility:
- No .item() or .tolist() calls (zero CPU-GPU syncs)
- Pre-allocated buffers at max_num_tokens size
- GPU-only expert offsets via bincount+cumsum
- searchsorted to map rows to experts (no Python for-loop with GPU indices)
- Single scatter operation for scale padding
- Pre-allocated token_indices reused for searchsorted row mapping
- quantize_activation_nvfp4 with fixed global scale (no .max() sync)
- Cached CuTeDSL kernel (no cute.compile per forward)
- No torch.cuda.synchronize() in forward path
The fully GPU-vectorized _assemble_scales_gpu() caused index out of
bounds errors because tensor slicing with GPU-computed indices from
Python is undefined behavior.
Went back to .item() on expert_offsets for the per-expert scale split.
This forces CPU-GPU syncs (breaks cudagraph) but produces correct results.
The path to cudagraph compatibility is either:
1. Modify CuTeDSL scale assembly API to accept flat tensor + offsets
2. Use the CUTLASS kernel (already verified working)
hc_head_fuse_tilelang expects fn shape[0]=hc_mult (4) but we passed
hc_mult*(2+hc_mult) (24). Since --enforce-eager disables @torch.compile
anyway, hc_head runs eagerly and doesn't need warmup.
After _ensure_stacked frees per-expert lists, code that accesses
l1_fp4 or w13_weight.device crashes with NoneType errors. Fix:
- _check_runtime_supported: fall back to _l1_mat_b.device
- _run_mega_moe assertion: check _l1_mat_b as alternative
- finalize_weights guard: check _l1_mat_b as alternative
_ensure_stacked() creates stacked copies of all weights but never freed
the per-expert lists. For 256 experts on a 175GB model, this doubles
weight memory to ~350GB, causing OOM.
Now the per-expert lists (l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)
are set to None after stacking, keeping only the single stacked copy.
Force-compile all lazy tilelang JIT kernels (mhc_pre, mhc_post)
and torch.compile'd hc_head during model loading, BEFORE the HTTP
server comes up. This eliminates the crash when eager mode inference
hits the model before tilelang compilation finishes.
Fixes the core issue: cudagraph capture forced eager compilation but
ate all GPU memory. Now we can run eager mode safely.