Commit Graph

6 Commits

Author SHA1 Message Date
35fab6cff3 Replace autograd.Function with torch.library.custom_op for Dynamo compat
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.

Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
  - Runners registered in global dict with integer IDs
  - Custom op takes (tensors, runner_id, shape_hint) -> tensor
  - Dynamo calls fake impl for shape inference, never touches the runner
  - At execution time, real impl looks up runner and calls _run_impl

Changes:
  - New: cutedsl/custom_ops.py (custom op definitions + registry)
  - New: tests/test_custom_op.py (local unit tests, no GPU needed)
  - Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
  - Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
    to use custom ops instead of autograd.Function
  - Updated: cutedsl_quant_method.py to use custom op + registry
2026-05-19 01:54:48 +00:00
48386e34ad Fix torch.compile: use custom autograd Function instead of @torch.compiler.disable
torch.compile fullgraph mode can't handle @torch.compiler.disable (skips
the function and refuses to compile). Custom autograd Functions are treated
as opaque ops by torch.compile — they execute eagerly without the compiler
trying to trace into CuTeDSL internals (JIT, Path.cwd, etc).
2026-05-18 21:38:28 +00:00
85e1cd3b69 Fix torch.compile crash: @torch.compiler.disable on all CuTeDSL run()
CuTeDSL internals (Path.cwd, threading, JIT) are incompatible with
torch.dynamo tracing. Marking run() as compiler-disabled makes the
runners opaque to torch.compile — they execute eagerly while the
rest of the model gets compiled.
2026-05-18 21:07:35 +00:00
a94011ec92 Fix torch.compile crash: remove threading.Lock from LUT cache path
The _NVFP4_STEP_LUT_LOCK caused 'Unsupported context manager' under
torch.compile/cudagraph. LUT is now pre-populated during warmup so
the fast path (cache hit) never hits a lock.

Also removed all init/warmup debug prints from CuTeDSL kernels.
2026-05-18 20:54:55 +00:00
450793311c Wire CuTeDSL kernels into vLLM: replace all BF16 dequant with native NVFP4
- CuTeDSLNvfp4Method: custom quant method that creates CuTeDSL runners
  during process_weights_after_loading, then swaps to CuTeDSLNvfp4LinearMethod
  for forward dispatch
- Attention projections (fused_wqa_wkv, wq_b, wo_b) now route through
  CuTeDSLNvfp4Linear (cosine 0.992-0.996 vs BF16 reference)
- Shared expert now uses CuTeDSLSharedExpertRunner (cosine 0.992 vs BF16)
  with monkey-patched forward for fused L1+SiLU+L2 pipeline
- Deleted all BF16 dequant code (_dequant_nvfp4_to_bf16, _post_quant_fix,
  input_scale fixes)
- Deleted _post_quant_fix hook from utils.py
- Fixed SwiGLU clamp: gate clamped BEFORE SiLU (matching SiluAndMulWithClamp)
- Cleaned up all debug prints
- Updated Dockerfile with new kernel files
2026-05-18 20:27:42 +00:00
6ce6a47be9 Add NVFP4 linear runner + attention projection test
- CuTeDSLNvfp4Linear: generic single-GEMM runner for any NVFP4 projection
- test_attention.py: tests q_a_proj, q_b_proj, kv_proj, o_b_proj vs BF16
- Same pad+swizzle pattern as shared expert, but no SiLU/fusion
2026-05-18 20:14:03 +00:00