Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.
Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
- Runners registered in global dict with integer IDs
- Custom op takes (tensors, runner_id, shape_hint) -> tensor
- Dynamo calls fake impl for shape inference, never touches the runner
- At execution time, real impl looks up runner and calls _run_impl
Changes:
- New: cutedsl/custom_ops.py (custom op definitions + registry)
- New: tests/test_custom_op.py (local unit tests, no GPU needed)
- Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
- Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
to use custom ops instead of autograd.Function
- Updated: cutedsl_quant_method.py to use custom op + registry
allow_in_graph doesn't work — Dynamo can't create proxies for Python
objects (the runner). The custom op approach requires only tensor args.
This time the GEMM impl correctly:
- Uses quantize_activation_nvfp4 for activation quantization
- Pads x_fp4 via uint8 + view(float4) for torch.zeros compat
- Assembles A-side scales with pad + swizzle
- Uses int32 expert_offsets (CuTeDSL requirement)
- Passes runner's pre-assembled mat_b, scale_b, gsb tensors
The custom op approach required reimplementing the GEMM (wrong scale
assembly, wrong tensor formats, cudaErrorIllegalAddress). Instead,
use torch.autograd.Function + torch._dynamo.allow_in_graph which
tells Dynamo to treat the function as an opaque kernel call, while
still using the runner's battle-tested _run_impl for the actual GEMM.
allow_in_graph is the proper way to register opaque ops for Dynamo
without reimplementing the computation.
The warmup custom op call hit cudaErrorIllegalAddress because our
custom op GEMM implementation doesn't match the runner's call convention.
Skip warmup for now — MoE kernel warmup handles CuTeDSL JIT cleanup.
- pad_and_swizzle_single takes 1 arg (2D tensor), not 4
- Inline the scale assembly logic: pad x_sf → swizzle → unsqueeze for 1 group
- Remove unused CuTeDSLNvfp4Linear import from custom op impl
Dynamo in fullgraph mode traces through torch.autograd.Function, hitting
CuTeDSL JIT internals (Path.cwd) and crashing. Registering as a custom op
makes it opaque to Dynamo — tracing calls the fake impl, real impl only
runs during inference.
Custom op: cutedsl::nvfp4_gemm(x, mat_b, scale_b, global_scale_b,
in_features, out_features, activation_global_scale) -> Tensor
Store finalized weight tensors on the layer (from runner._mat_b etc.)
instead of the runner object, since custom ops can only accept tensors.
- Create CuTeDSLNvFp4LinearKernel extending NvFp4LinearKernel base class
- Register it via init_nvfp4_linear_kernel() selection mechanism
(inserted at top of _POSSIBLE_NVFP4_KERNELS, before FlashInfer)
- process_weights_after_loading: uint8→FP4, permute, create CuTeDSL runner
- apply_weights: route through CuTeDSL GEMM
- Update Dockerfile: copy kernel + registration script
- Fix attention: always use forward() for quantized compressor/indexer
layers (dtype check was fragile after kernel swaps weights to dummy BF16)