- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
- model/{dsv4,mtp,layer,layer_schedule}
- layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
- cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
- kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
- ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
- reference/{attention,compressor,csa_attention,moe_pipeline}
- kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
- test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
"""torch.library.custom_op wrappers for CuTeDSL NVFP4 kernels.

Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(JIT compilation, cute.compile, etc.). By wrapping the runner calls in
torch.library.custom_op, Dynamo treats them as opaque black boxes.

This is the correct approach per PyTorch's extensibility model:
- custom_op is the supported way to make Dynamo skip tracing
- autograd.Function does NOT work reliably with fullgraph mode
- The runner's _run_impl is already cudagraph-safe

The registry pattern: custom ops can only take tensor/scalar arguments,
so runner objects cannot be passed to them directly. Instead, we store
runners in a global dict keyed by integer ID and pass the ID as an int
parameter. During Dynamo tracing, the fake impl returns a correctly-shaped
tensor without touching the runner. During execution, the real impl looks
up the runner and calls _run_impl.
"""
|
|
|
|
import torch
|
|
|
|
# ---------------------------------------------------------------------------
# Runner registry — maps integer IDs to runner objects
# ---------------------------------------------------------------------------
|
|
# Next ID to hand out; only ever incremented by register_runner().
_next_runner_id = 0
# Integer ID -> runner object. Custom ops receive the ID instead of the object.
_runner_registry: dict[int, object] = {}


def register_runner(runner) -> int:
    """Store ``runner`` in the module-level registry and return its integer ID.

    IDs come from a module-level counter that starts at 0 and is incremented
    on every call, so each registration gets a distinct key. The returned
    value can be passed to a custom op as a plain ``int`` and resolved again
    with :func:`get_runner`.

    Args:
        runner: the runner object to store. The registry keeps a reference
            to it.

    Returns:
        The integer key under which ``runner`` was stored.
    """
    global _next_runner_id
    rid = _next_runner_id
    _next_runner_id += 1
    _runner_registry[rid] = runner
    return rid


def get_runner(rid: int):
    """Return the runner previously registered under ``rid``.

    Args:
        rid: an ID returned by :func:`register_runner`.

    Returns:
        The object that was registered under ``rid``.

    Raises:
        KeyError: if nothing is registered under ``rid``. The message names
            the offending ID so a stale or wrong ID is easy to diagnose.
    """
    try:
        return _runner_registry[rid]
    except KeyError:
        # Same exception type as a plain dict lookup, but with context.
        raise KeyError(
            f"no runner registered under id {rid!r} "
            f"({len(_runner_registry)} runner(s) registered)"
        ) from None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# NVFP4 Linear GEMM custom op (single linear layer)
# ---------------------------------------------------------------------------
|
|
@torch.library.custom_op("nvfp4::linear_gemm", mutates_args=())
def nvfp4_linear_gemm(
    x: torch.Tensor,
    runner_id: int,
    out_features: int,
) -> torch.Tensor:
    """Run a registered NVFP4 linear runner as an opaque custom op.

    The runner is resolved from the registry via ``runner_id`` and its
    ``_run_impl`` is invoked on ``x``. ``out_features`` is not read by this
    implementation; it is part of the signature so the registered fake
    implementation can size its output.

    Args:
        x: input activations, documented as (M, K) BF16.
        runner_id: integer key into the runner registry.
        out_features: output dimension (used for shape inference only).

    Returns:
        Whatever the runner's ``_run_impl`` returns, documented as an
        (M, out_features) BF16 tensor.
    """
    return get_runner(runner_id)._run_impl(x)
|
|
|
|
|
|
@nvfp4_linear_gemm.register_fake
def _(x, runner_id, out_features):
    """Fake impl for tracing: allocate the output without touching the runner.

    Returns an uninitialized BF16 tensor of shape ``(x.shape[0], out_features)``
    on the same device as ``x``. ``runner_id`` is ignored.
    """
    out_shape = (x.shape[0], out_features)
    return torch.empty(out_shape, dtype=torch.bfloat16, device=x.device)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# NVFP4 MoE custom op (L1 + SiLU + L2 grouped GEMM)
# ---------------------------------------------------------------------------
|
|
@torch.library.custom_op("nvfp4::moe_gemm", mutates_args=())
def nvfp4_moe_gemm(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    runner_id: int,
    hidden_size: int,
) -> torch.Tensor:
    """Run a registered NVFP4 MoE runner as an opaque custom op.

    The runner is resolved from the registry via ``runner_id`` and its
    ``_run_impl`` is invoked with the activations and routing tensors.
    ``hidden_size`` is not read by this implementation; it is part of the
    signature so the registered fake implementation can size its output.

    Args:
        hidden_states: input activations, documented as (M, K) BF16.
        topk_weights: routing weights, documented as (M, top_k) float32.
        topk_ids: expert IDs, documented as (M, top_k) int32.
        runner_id: integer key into the runner registry.
        hidden_size: output dimension (used for shape inference only).

    Returns:
        Whatever the runner's ``_run_impl`` returns, documented as an
        (M, hidden_size) BF16 tensor.
    """
    moe_runner = get_runner(runner_id)
    routing = (topk_weights, topk_ids)
    return moe_runner._run_impl(hidden_states, *routing)
|
|
|
|
|
|
@nvfp4_moe_gemm.register_fake
def _(hidden_states, topk_weights, topk_ids, runner_id, hidden_size):
    """Fake impl for tracing: allocate the output without touching the runner.

    Returns an uninitialized BF16 tensor of shape
    ``(hidden_states.shape[0], hidden_size)`` on the same device as
    ``hidden_states``. The routing tensors and ``runner_id`` are ignored.
    """
    num_rows = hidden_states.shape[0]
    target_device = hidden_states.device
    return torch.empty(
        num_rows,
        hidden_size,
        dtype=torch.bfloat16,
        device=target_device,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# DSV4 Sparse FMHA custom op (attention with SWA + sink bias)
# ---------------------------------------------------------------------------
|
|
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=())
def dsv4_sparse_fmha(
    q: torch.Tensor,  # (n_q_heads, T, hd) BF16
    k: torch.Tensor,  # (n_kv_heads, N, hd) or (N, hd) BF16
    v: torch.Tensor,  # same as k
    sink_bias: torch.Tensor,  # (n_q_heads,) FP32 — can be zeros if unused
    scale: float,
    swa_len: int,
    is_causal: bool,
    n_comp: int,
) -> torch.Tensor:
    """Run DSV4 attention as an opaque custom op for torch.compile.

    Thin adapter around ``dsv4_attention``: it translates the sentinel values
    this signature uses into the ``None`` arguments passed to that function.

    Args:
        q: query tensor.
        k: key tensor.
        v: value tensor.
        sink_bias: always a tensor so every argument of the op is a tensor or
            a scalar; pass zeros when no sink bias is wanted.
        scale: forwarded unchanged as ``scale``.
        swa_len: forwarded as ``swa_len`` when positive, otherwise ``None``
            is passed instead.
        is_causal: forwarded unchanged as ``is_causal``.
        n_comp: forwarded unchanged as ``n_comp``; also gates the sink bias.

    Returns:
        The tensor returned by ``dsv4_attention``.
    """
    # Imported lazily, at call time rather than at module import.
    from dsv4.kernels.attention.production import dsv4_attention as _dsv4_attention

    # The sink bias is forwarded only when n_comp > 0 AND the sum of its
    # absolute values is positive; in every other case None is passed.
    # NOTE(review): `.item()` reads the value back into Python, which on a CUDA
    # tensor presumably forces a host sync — confirm this is acceptable on
    # paths that run under CUDA graph capture.
    use_sink = False
    if n_comp > 0:
        use_sink = sink_bias.abs().sum().item() > 0

    window = swa_len if swa_len > 0 else None
    bias = sink_bias if use_sink else None

    return _dsv4_attention(
        q,
        k,
        v,
        scale=scale,
        swa_len=window,
        is_causal=is_causal,
        n_comp=n_comp,
        sink_bias=bias,
    )
|
|
|
|
|
|
@dsv4_sparse_fmha.register_fake
def _(q, k, v, sink_bias, scale, swa_len, is_causal, n_comp):
    """Fake impl for tracing: the output is allocated like ``q``.

    Returns an uninitialized tensor created with ``torch.empty_like(q)``;
    every other argument is ignored.
    """
    placeholder = torch.empty_like(q)
    return placeholder
|