Cleanup Step 2: Archive Lineage P code, fix broken imports

- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
This commit is contained in:
2026-06-02 19:27:07 +00:00
parent 8de47e26ce
commit f3b551956d
55 changed files with 2881 additions and 306 deletions

View File

@@ -0,0 +1,368 @@
"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half).
wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups.
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank)
The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd".
We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts,
where every token goes to every "expert" (group).
wo_a can be supplied as BF16 and quantized to NVFP4 here (set_bf16_weight), or loaded as already-quantized NVFP4 straight from the checkpoint (load_nvfp4_weight).
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_nvfp4_gpu_fused,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class Nvfp4GroupedLinear:
"""Grouped NVFP4 linear for wo_a (o-projection first half).
Handles the "bhr,hdr->bhd" einsum pattern:
- o: (tokens, n_local_heads, head_dim) → reshape to (tokens, n_local_groups, heads_per_group * head_dim)
- wo_a: (n_local_groups, heads_per_group * head_dim, o_lora_rank) → NVFP4 per group
- z: (tokens, n_local_groups, o_lora_rank)
Uses ScaledGroupedGemm with num_groups=n_local_groups.
Every token goes to every group (no routing).
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
def __init__(
self,
n_local_groups: int,
heads_per_group: int,
head_dim: int,
o_lora_rank: int,
max_num_tokens: int = 8192,
device: str = "cuda",
):
self.n_local_groups = n_local_groups
self.heads_per_group = heads_per_group
self.head_dim = head_dim
self.o_lora_rank = o_lora_rank
self.max_num_tokens = max_num_tokens
self.device = device
# Per-group dimensions
self.group_in_features = heads_per_group * head_dim # 8192
self.group_out_features = o_lora_rank # 1536
# NVFP4 weight storage: lists of per-group tensors
self._weight_fp4 = None # list of (K//2, N) float4_e2m1fn_x2
self._weight_sf = None # list of (K//16, N) float8_e4m3fn
self._weight_gs = None # list of float32
# Processed weights (set by finalize_weights)
self._mat_b = None
self._scale_b = None
self._gsb = None
# Activation global scale
self._activation_global_scale = 1.0 / (6.0 * 448.0)
# Pre-allocated buffers
self._padded_x_fp4_buf = None
self._gsa_buf = None
self._expert_offsets_buf = None
self._buffers_allocated = False
def set_bf16_weight(self, wo_a_bf16: torch.Tensor):
    """Quantize a BF16 ``wo_a`` weight to NVFP4, one group at a time.

    Two input layouts are accepted:

    * 3-D "bmm" layout ``(n_local_groups, in_features, o_lora_rank)``:
      each ``wo_a_bf16[g]`` is already ``(K, N)`` and is quantized as is.
    * Otherwise a dense ``(n_local_groups * o_lora_rank, in_features)``
      layout: each group's ``o_lora_rank`` rows are sliced out and
      transposed to ``(K, N)`` so the contraction dimension comes first.

    The per-group results are stored in ``_weight_fp4`` / ``_weight_sf`` /
    ``_weight_gs`` for ``finalize_weights`` to consume.

    Args:
        wo_a_bf16: The weight in one of the two layouts above.
    """
    group_ids = range(self.n_local_groups)
    if wo_a_bf16.ndim == 3:
        # Already (K, N) per group.
        group_mats = (wo_a_bf16[g] for g in group_ids)
    else:
        rank = self.o_lora_rank
        # Rows are (out, in); transpose each slice to (in, out) = (K, N).
        group_mats = (wo_a_bf16[g * rank:g * rank + rank, :].T for g in group_ids)

    packed_list, block_scale_list, global_scale_list = [], [], []
    for mat in group_mats:
        packed, block_scale, global_scale = quantize_weight_to_nvfp4(mat)
        packed_list.append(packed)
        block_scale_list.append(block_scale)
        global_scale_list.append(global_scale)

    self._weight_fp4 = packed_list
    self._weight_sf = block_scale_list
    self._weight_gs = global_scale_list
def load_nvfp4_weight(self, weight, weight_scale, weight_scale_2=None, input_scale=None):
"""Load NVFP4 weights directly from checkpoint — no dequant/re-quant.
The checkpoint stores weights in (out_features, in_features) layout:
weight: (n_groups * o_rank, group_in_features // 2) uint8
weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn
weight_scale_2: scalar or (n_groups * o_rank,) float
input_scale: scalar or (n_groups * o_rank,) float (unused for weight dequant)
Each group's chunk is (o_rank, K_packed) = (N, K_packed) in row-major.
Our GEMM expects (K_packed, N) per group, so we transpose each group.
Block scales follow the same transpose.
Args:
weight: (n_groups * o_rank, group_in_features // 2) uint8
weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn
weight_scale_2: scalar or per-row scale tensor (optional)
input_scale: scalar or per-row (unused — for activation quantization)
"""
fp4_list = []
sf_list = []
gs_list = []
K_packed = self.group_in_features // 2
N = self.o_lora_rank
K_sf = self.group_in_features // 16 # block scale dim along K
for g in range(self.n_local_groups):
# Extract this group's weight: (o_rank, K_packed) = (N, K_packed)
start = g * N
end = start + N
w_g = weight[start:end] # (N, K_packed) uint8
ws_g = weight_scale[start:end] # (N, K_sf) float8_e4m3fn
# Transpose to (K_packed, N) — the layout quantize_weight_to_nvfp4 produces
w_g_t = w_g.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
ws_g_t = ws_g.permute(1, 0).contiguous()
fp4_list.append(w_g_t)
sf_list.append(ws_g_t)
# Global scale: weight_scale_2
if weight_scale_2 is not None:
if weight_scale_2.numel() == 1:
gs_list.append(weight_scale_2.float().item())
else:
# Per-row: take mean of this group's rows
gs_list.append(weight_scale_2[start:end].float().mean().item())
else:
gs_list.append(1.0)
self._weight_fp4 = fp4_list
self._weight_sf = sf_list
self._weight_gs = gs_list
def finalize_weights(self):
"""Process NVFP4 weights for CuTeDSL GEMM."""
if self._weight_fp4 is None:
raise RuntimeError("Call set_bf16_weight() before finalize_weights()")
self._mat_b = make_b_k_major(torch.stack(self._weight_fp4)) # (groups, K_packed, N_packed)
self._scale_b = assemble_scales_3d_side(self._weight_sf)
self._gsb = torch.tensor(self._weight_gs, dtype=torch.float32, device=self.device)
# Free raw weights
self._weight_fp4 = None
self._weight_sf = None
self._weight_gs = None
def _allocate_buffers(self):
    """Create the activation-side buffers once, sized for ``max_num_tokens``.

    The packed activation buffer has one 128-row-aligned region per group;
    the global-scale and offset buffers hold one entry per group.
    """
    groups = self.n_local_groups
    rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
    packed_cols = self.group_in_features // 2

    # Allocate as bytes, then reinterpret as packed FP4 (same element width).
    act_bytes = torch.zeros(
        rows_per_group * groups, packed_cols, dtype=torch.uint8, device=self.device
    )
    self._padded_x_fp4_buf = act_bytes.view(torch.float4_e2m1fn_x2)

    self._gsa_buf = torch.zeros(groups, dtype=torch.float32, device=self.device)
    self._expert_offsets_buf = torch.zeros(groups, dtype=torch.int32, device=self.device)
    self._buffers_allocated = True
def _ensure_initialized(self):
if self._mat_b is None:
self.finalize_weights()
if not self._buffers_allocated:
self._allocate_buffers()
def _assemble_scales_single_group(self, x_sf):
    """Pad a 2-D block-scale tensor and swizzle it for a one-group GEMM.

    Rows are padded up to a multiple of 128 and columns to a multiple of 4
    with zeros before ``pad_and_swizzle_single`` is applied.

    Args:
        x_sf: 2-D block-scale tensor.

    Returns:
        The swizzled scales reshaped to ``(padded_rows, padded_cols)``.
    """
    rows, cols = x_sf.shape
    rows_padded = cutedsl_ceil_div(rows, 128) * 128
    cols_padded = cutedsl_ceil_div(cols, 4) * 4

    # Zeros are created in FP16 and then cast to FP8, as in the original.
    staging = torch.zeros(rows_padded, cols_padded, dtype=torch.float16, device=x_sf.device)
    staging = staging.to(torch.float8_e4m3fn)
    staging[:rows, :cols] = x_sf

    return pad_and_swizzle_single(staging).reshape(rows_padded, cols_padded)
def compute_activation_global_scale(self, o_sample: torch.Tensor):
    """Calibrate the activation global scale from a warm-up sample.

    The sample is regrouped into rows of ``group_in_features`` (one row per
    token per group) and quantized once; the global scale returned by
    ``quantize_to_nvfp4`` is stored and shared by all groups.

    Args:
        o_sample: Attention-output sample, documented by the original as
            ``(tokens, n_local_heads, head_dim)`` BF16.
    """
    self._ensure_initialized()
    # Going through the (tokens, groups, features) shape first makes a sample
    # whose size is not a multiple of groups * features fail loudly here.
    o_grouped = o_sample.reshape(-1, self.n_local_groups, self.group_in_features)
    from dsv4.ops.quantize import (
        quantize_to_nvfp4,
    )
    # A single scale is used for every group for simplicity; it is taken from
    # quantizing all groups' rows together.
    with torch.no_grad():
        _, _, gs = quantize_to_nvfp4(o_grouped.reshape(-1, self.group_in_features))
    self._activation_global_scale = gs
def run(self, o: torch.Tensor) -> torch.Tensor:
    """Forward entry point: dispatch through the ``nvfp4_linear_gemm`` op.

    The instance registers itself with ``register_runner`` on first use and
    reuses the returned id afterwards.

    Args:
        o: Attention output, documented by the original as
            ``(num_tokens, n_local_heads, head_dim)`` BF16 with inverse
            RoPE already applied.

    Returns:
        Whatever ``nvfp4_linear_gemm`` returns; the original documents
        ``(num_tokens, n_local_groups, o_lora_rank)`` BF16.
    """
    try:
        runner_id = self._runner_id
    except AttributeError:
        # First call: register once and cache the id on the instance.
        runner_id = self._runner_id = register_runner(self)
    total_out_features = self.n_local_groups * self.o_lora_rank
    return nvfp4_linear_gemm(o, runner_id, total_out_features)
def _run_impl(self, o: torch.Tensor) -> torch.Tensor:
    """Quantize the attention output and run the grouped NVFP4 GEMM.

    ``o`` is reshaped to ``(tokens, n_local_groups, group_in_features)`` and
    each group's ``(tokens, group_in_features)`` slab is treated as one
    "expert" of the grouped GEMM; every token goes to every group.

    Each group's rows sit contiguously at a 128-aligned offset in the packed
    activation buffer (``padded_T`` = tokens rounded up to 128):
      - Group 0:   rows [0, padded_T)
      - Group 1:   rows [padded_T, 2*padded_T)
      - ...
      - Group G-1: rows [(G-1)*padded_T, G*padded_T)

    Args:
        o: ``(tokens, n_local_heads, head_dim)`` attention output.

    Returns:
        ``(tokens, n_local_groups, o_lora_rank)`` BF16 tensor.

    Raises:
        ValueError: If the padded layout for this batch does not fit in the
            pre-allocated activation buffer.
    """
    self._ensure_initialized()
    num_groups = self.n_local_groups
    in_features = self.group_in_features
    num_tokens = o.shape[0]
    padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128

    # Fail with a clear message instead of an opaque slice shape mismatch
    # when the batch exceeds what _allocate_buffers() sized the buffer for.
    padded_x_fp4 = self._padded_x_fp4_buf
    required_rows = padded_rows_per_group * num_groups
    if required_rows > padded_x_fp4.shape[0]:
        raise ValueError(
            f"num_tokens={num_tokens} needs {required_rows} padded rows but the "
            f"pre-allocated buffer holds {padded_x_fp4.shape[0]} "
            f"(max_num_tokens={self.max_num_tokens})"
        )

    # (T, heads, head_dim) -> (T, G, D) -> groups-first (G, T, D) -> (G*T, D),
    # so all groups are quantized in one call.
    o_flat = (
        o.reshape(num_tokens, num_groups, in_features)
        .permute(1, 0, 2)
        .reshape(num_groups * num_tokens, in_features)
    )

    if getattr(self, '_use_runtime_gsa', False):
        # Fused amax + quantize; the global scale stays on the GPU.
        x_fp4_flat, x_sf_flat, gsa_gpu = quantize_nvfp4_gpu_fused(o_flat)
        # Every group shares the first gsa entry: broadcast it into all
        # slots with one GPU->GPU copy (no .item(), no CPU sync).
        self._gsa_buf.copy_(gsa_gpu[:1].expand(num_groups))
    else:
        self._gsa_buf.fill_(self._activation_global_scale)
        x_fp4_flat, x_sf_flat = quantize_activation_nvfp4(
            o_flat, self._activation_global_scale
        )

    # Scatter each group's packed rows to its padded offset. The buffer is
    # cleared first so the padding rows are zero.
    buf_u8 = padded_x_fp4.view(torch.uint8)
    buf_u8.zero_()
    x_fp4_grouped = x_fp4_flat.reshape(num_groups, num_tokens, in_features // 2)
    for g in range(num_groups):
        offset = g * padded_rows_per_group
        buf_u8[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)

    # Per-group block scales, assembled for the A side of the GEMM.
    x_sf_grouped = x_sf_flat.reshape(num_groups, num_tokens, in_features // 16)
    scale_a = assemble_scales_2d_side([x_sf_grouped[g] for g in range(num_groups)])

    # Cumulative end offsets: [padded_T, 2*padded_T, ..., G*padded_T].
    expert_offsets = self._expert_offsets_buf
    for g in range(num_groups):
        expert_offsets[g] = (g + 1) * padded_rows_per_group

    out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4,
        mat_b=self._mat_b,
        scale_a=scale_a,
        scale_b=self._scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=self._gsa_buf,
        global_scale_b=self._gsb,
    )

    # The output uses the same groups-first padded row layout as mat_a;
    # gather the real rows of each group back into (T, G, N).
    z = torch.empty(num_tokens, num_groups, self.o_lora_rank,
                    dtype=torch.bfloat16, device=o.device)
    for g in range(num_groups):
        offset = g * padded_rows_per_group
        z[:, g, :] = out[offset:offset + num_tokens, :]
    return z
def __call__(self, o: torch.Tensor) -> torch.Tensor:
    """Make the layer callable; equivalent to ``self.run(o)``."""
    return self.run(o)

View File

@@ -0,0 +1,267 @@
"""CuTeDSL NVFP4 Linear (single GEMM)
Generic NVFP4 GEMM runner for attention projections and any single
linear layer. Uses ScaledGroupedGemmKernel with num_groups=1.
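CUDA-graph-compatible after warm-up: activation buffers are allocated on first use and grown on demand (see _ensure_buffer_size); the steady-state forward path has no CPU-GPU syncs.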
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from dsv4.kernels.gemm.grouped import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class Nvfp4Linear:
"""Single NVFP4 GEMM using CuTeDSL (num_groups=1).
Handles any (K, N) weight matrix in NVFP4 format.
Simple: quantize activation → GEMM → BF16 output.
No SiLU, no fusion, no routing.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
def __init__(
self,
in_features: int,
out_features: int,
max_num_tokens: int = 8192,
device: str = "cuda",
):
self.in_features = in_features
self.out_features = out_features
self.max_num_tokens = max_num_tokens
self.device = device
# Weights (set after construction, then call finalize_weights)
self.fp4 = None # list of 1 tensor
self.sf = None # list of 1 tensor
self.gs = None # list of 1 float
self.ws2 = None # list of 1 tensor — weight_scale_2 (scalar, folded into global_scale_b)
# Processed weights
self._mat_b = None
self._scale_b = None
self._gsb = None
# Activation global scale
self._activation_global_scale = 1.0 / (6.0 * 448.0)
# Pre-allocated buffers
self._padded_x_fp4_buf = None
self._expert_offsets_buf = None
self._gsa_buf = None
self._buffers_allocated = False
def finalize_weights(self):
"""Process weights for CuTeDSL GEMM."""
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4]
# Checkpoint weight is (out_features//2, in_features//2) = (N_packed, K_packed)
# make_b_k_major expects (E, K_packed, N_packed), so we need to permute
stacked = torch.stack(fp4_view).permute(0, 2, 1).contiguous() # (1, K_packed, N_packed)
self._mat_b = make_b_k_major(stacked)
# Checkpoint scale is (N_packed, K_sf) — already in the right row order for the
# kernel's swizzle. Use assemble_raw_scales_2d3d_3d_side (no transpose),
# NOT assemble_scales_3d_side (which transposes K_sf↔N).
from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side
self._scale_b = assemble_raw_scales_2d3d_3d_side(self.sf)
self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)
# Fold weight_scale_2 into global_scale_b
# Dequant formula: w = lut[w_packed] * weight_scale * weight_scale_2
# Production GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb)
# So gsb = input_scale * weight_scale_2
if self.ws2 is not None and len(self.ws2) > 0 and self.ws2[0] is not None:
ws2_val = self.ws2[0].float().item()
self._gsb = self._gsb * ws2_val
# Free raw weights
self.fp4 = None
self.sf = None
self.gs = None
self.ws2 = None
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
# Uses num_groups=1 since this is a single linear layer.
K_packed = self.in_features // 2
N_packed = self.out_features // 2
# warmup_compilation(1, K_packed, N_packed, self.device) # Lazy compile on first real forward
def _ensure_buffer_size(self, num_tokens: int):
    """Allocate (or grow) the activation-side buffers for ``num_tokens``.

    Does nothing when the existing packed buffer already has at least the
    128-aligned number of rows needed. Otherwise all three buffers are
    re-created, and the global-scale buffer is re-seeded with the current
    ``_activation_global_scale``.
    """
    rows_needed = cutedsl_ceil_div(num_tokens, 128) * 128
    existing = self._padded_x_fp4_buf
    if existing is not None and existing.shape[0] >= rows_needed:
        return

    # Allocate as bytes, then reinterpret as packed FP4 (same element width).
    act_bytes = torch.zeros(
        rows_needed, self.in_features // 2, dtype=torch.uint8, device=self.device
    )
    self._padded_x_fp4_buf = act_bytes.view(torch.float4_e2m1fn_x2)
    self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
    self._gsa_buf = torch.full(
        (1,), self._activation_global_scale, dtype=torch.float32, device=self.device
    )
def _ensure_initialized(self):
if self._mat_b is None:
self.finalize_weights()
def _assemble_scales_single_group(self, x_sf):
    """Pad activation block scales and swizzle them for the one-group GEMM.

    Args:
        x_sf: 2-D block-scale tensor.

    Returns:
        Swizzled scales of shape ``(rows padded to 128, cols padded to 4)``.
    """
    n_rows, n_cols = x_sf.shape
    padded_shape = (
        cutedsl_ceil_div(n_rows, 128) * 128,
        cutedsl_ceil_div(n_cols, 4) * 4,
    )
    # Zero-filled FP8 staging area (created in FP16 and cast, as in the original).
    padded = torch.zeros(*padded_shape, dtype=torch.float16, device=x_sf.device).to(
        torch.float8_e4m3fn
    )
    padded[:n_rows, :n_cols] = x_sf
    swizzled = pad_and_swizzle_single(padded)
    return swizzled.reshape(*padded_shape)
def compute_activation_global_scale(self, hidden_states_sample):
    """Calibrate the activation global scale from a warm-up sample.

    Stores the global scale returned by ``quantize_to_nvfp4`` and, if the
    GPU-side scale buffer already exists, refreshes it too.

    Args:
        hidden_states_sample: Representative input activations.
    """
    self._ensure_initialized()
    with torch.no_grad():
        _, _, gs = quantize_to_nvfp4(hidden_states_sample)
    self._activation_global_scale = gs
    # _run_impl's static path quantizes with _activation_global_scale but
    # hands _gsa_buf to the GEMM without refilling it. If the buffer was
    # allocated before calibration it would keep the old value, so keep the
    # two in sync here (off the hot path).
    gsa_buf = getattr(self, '_gsa_buf', None)
    if gsa_buf is not None:
        gsa_buf.fill_(gs)
def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward entry point: BF16 input -> NVFP4 GEMM -> BF16 output.

    Dispatches through ``nvfp4_linear_gemm``. Per the original notes this is
    a ``torch.library`` custom op that torch.compile treats as opaque and
    that calls ``_run_impl`` internally. The instance registers itself on
    first use and reuses the id afterwards.
    """
    try:
        runner_id = self._runner_id
    except AttributeError:
        # First call: register once and cache the id on the instance.
        runner_id = self._runner_id = register_runner(self)
    return nvfp4_linear_gemm(hidden_states, runner_id, self.out_features)
def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Quantize the input, stage it in the padded buffer and run the GEMM.

    Two quantization paths exist:

    * ``_use_runtime_gsa`` set: a fused amax+quantize call returns the
      global scale on the GPU, which is copied into ``_gsa_buf``
      (GPU -> GPU, no host sync).
    * otherwise: the input is quantized with the static
      ``_activation_global_scale`` and ``_gsa_buf`` is NOT refilled. This
      path relies on the buffer already holding that value, so anything
      else that writes ``_gsa_buf`` invalidates it.

    Args:
        hidden_states: Input activations; dim 0 is the token count.

    Returns:
        The first ``num_tokens`` rows of the GEMM output.
    """
    self._ensure_initialized()
    num_tokens = hidden_states.shape[0]
    padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
    self._ensure_buffer_size(num_tokens)

    use_runtime_gsa = getattr(self, '_use_runtime_gsa', False)
    if use_runtime_gsa:
        from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
        x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states)
        self._gsa_buf.copy_(gsa_gpu[:1].reshape(1))
    else:
        from dsv4.ops.quantize import quantize_nvfp4_gpu
        x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale)

    # Clear the staging buffer so padding rows are zero, then write the
    # packed activation bytes at the top.
    staged = self._padded_x_fp4_buf
    staged_bytes = staged.view(torch.uint8)
    staged_bytes.zero_()
    staged_bytes[:x_fp4.shape[0]] = x_fp4.view(torch.uint8)

    scale_a = self._assemble_scales_single_group(x_sf)

    # Single group: its end offset is the padded row count.
    offsets = self._expert_offsets_buf
    offsets.fill_(padded_rows)

    gemm_out = run_nvfp4_grouped_gemm(
        mat_a=staged,
        mat_b=self._mat_b,
        scale_a=scale_a,
        scale_b=self._scale_b,
        expert_offsets=offsets,
        global_scale_a=self._gsa_buf,
        global_scale_b=self._gsb,
    )
    return gemm_out[:num_tokens]
def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor:
    """Run the GEMM on an activation that is already quantized.

    Used when an upstream fused kernel has produced the packed activation,
    its block scales and its global scale, so the quantize step is skipped.

    The activation's global scale is staged in a buffer dedicated to this
    path. ``_gsa_buf`` is deliberately left untouched, because
    ``_run_impl``'s static path assumes it still holds
    ``_activation_global_scale``.

    Args:
        quant: ``QuantizedActivation`` providing ``x_fp4``, ``x_sf``,
            ``gsa`` and ``num_tokens``.

    Returns:
        The first ``num_tokens`` rows of the GEMM output.

    Raises:
        TypeError: If ``quant`` is not a ``QuantizedActivation``.
    """
    from dsv4.ops.quantize import QuantizedActivation
    if not isinstance(quant, QuantizedActivation):
        # Explicit check: an assert would vanish under `python -O`.
        raise TypeError(
            f"quant must be a QuantizedActivation, got {type(quant).__name__}"
        )
    self._ensure_initialized()
    num_tokens = quant.num_tokens
    padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
    self._ensure_buffer_size(num_tokens)

    # Scatter pre-quantized x_fp4 into the zeroed padded buffer.
    padded_x_fp4 = self._padded_x_fp4_buf
    padded_bytes = padded_x_fp4.view(torch.uint8)
    padded_bytes.zero_()
    padded_bytes[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8)

    # Assemble A-side scales from the pre-quantized block scales.
    scale_a = self._assemble_scales_single_group(quant.x_sf)

    # Single group: its end offset is the padded row count.
    expert_offsets = self._expert_offsets_buf
    expert_offsets.fill_(padded_rows)

    # One shared scale, or one scale per row truncated to the real tokens.
    gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens]
    # Stage gsa in a buffer owned by this path. It is re-created only when
    # the shape or device changes, and always written via copy_ so it never
    # aliases the caller's tensor.
    gsa_buf = getattr(self, '_quant_gsa_buf', None)
    if gsa_buf is None or gsa_buf.shape != gsa.shape or gsa_buf.device != gsa.device:
        gsa_buf = torch.empty(gsa.shape, dtype=torch.float32, device=gsa.device)
        self._quant_gsa_buf = gsa_buf
    gsa_buf.copy_(gsa)

    out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4,
        mat_b=self._mat_b,
        scale_a=scale_a,
        scale_b=self._scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=gsa_buf,
        global_scale_b=self._gsb,
    )
    return out[:num_tokens]
def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Make the layer callable; equivalent to ``self.run(hidden_states)``."""
    return self.run(hidden_states)

549
dsv4/_archive/layers/mhc.py Normal file
View File

@@ -0,0 +1,549 @@
"""
mHC (Manifold-Constrained Hyper-Connections) — Inference Layer.
Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only.
Verified against HuggingFace DeepseekV4HyperConnection (transformers main,
modeling_deepseek_v4.py). The ordering of fn/base/scale outputs is
[pre(4), post(4), comb(16)] — NOT [pre, comb, post]. The comb matrix is
consumed TRANSPOSED in post_block. Sinkhorn starts from softmax (not exp).
pre (A_l) has an hc_eps additive guard.
---------------------------------------------------------------------
V4-Pro reference dimensions (Section 4.2.1)
---------------------------------------------------------------------
d = 7168 hidden dim
n_hc = 4 hyper-connection expansion factor
N_proj = 24 fused output of W_pre(4) + W_post(4) + W_comb(16)
K_proj = 4*7168 = 28672 = n_hc * d (flattened residual)
t_max = 20 Sinkhorn iterations
---------------------------------------------------------------------
Checkpoint layout (fn / base / scale)
---------------------------------------------------------------------
fn: (24, 28672) — rows ordered [pre(4), post(4), comb(16)]
base: (24,) — ordered [pre(4), post(4), comb(16)]
scale: (3,) — [alpha_pre, alpha_post, alpha_comb]
This matches the HuggingFace split:
pre_w, post_w, comb_w = F.linear(flat, fn).split([4, 4, 16])
pre_b, post_b, comb_b = base.split([4, 4, 16])
pre_scale, post_scale, comb_scale = scale.unbind(0)
---------------------------------------------------------------------
Kernel dependency
---------------------------------------------------------------------
tf32_hc_prenorm_gemm (DeepGEMM, SM90/SM100)
a: (T, K) BF16 — flattened residual X_flat
b: (N, K) FP32 — stacked weight [W_pre; W_post; W_comb]
d: (S, T, N) or (T, N) FP32 — raw projection outputs (pre-normalised)
sqr_sum: (S, T) or (T,) FP32 — Σ a² per token (for RMSNorm denominator)
num_splits = S (16 recommended for K=28672)
After the call:
d = d.sum(0) → (T, N)
sqr_sum = sqr_sum.sum(0) → (T,)
rms_scale = sqrt(K / (sqr_sum + eps))
d_norm = d * rms_scale[:,None] — equivalent to RMSNorm(X_flat) @ W_stacked
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Try importing DeepGEMM; fall back to plain BF16 matmul if unavailable.
# ---------------------------------------------------------------------------
# Optional dependency: record whether DeepGEMM imported so the projection
# code can choose between the DeepGEMM kernel and the plain-PyTorch path.
try:
    import deep_gemm
    _HAS_DEEP_GEMM = True
except ImportError:
    _HAS_DEEP_GEMM = False
NUM_SPLITS = 16 # K-split count for tf32_hc_prenorm_gemm numerical stability
EPS_RMSN = 1e-6 # added to the sum of squares when forming the RMS scale
HC_EPS = 1e-6 # eps guard on pre (A_l) and Sinkhorn, matching HF reference
# ---------------------------------------------------------------------------
# Sinkhorn-Knopp projection (T batched 4×4 matrices)
# ---------------------------------------------------------------------------
def sinkhorn_knopp(
    logits: torch.Tensor,  # (T, n, n) raw logits (NOT exp'd)
    t_max: int = 20,
    eps: float = HC_EPS,
) -> torch.Tensor:
    """
    Project each (n x n) logit matrix towards a doubly stochastic matrix
    using the ``mhc_sinkhorn`` CUDA extension.

    Per the original notes, the kernel is meant to mirror the HuggingFace
    DeepseekV4HyperConnection procedure:
      1. softmax along the last dim (row-normalise the logits)
      2. add eps
      3. column-normalise
      4. (t_max - 1) further alternating row/column normalisations

    There is deliberately no Python fallback: if the extension cannot be
    loaded or run, the error propagates to the caller.

    Args:
        logits: Raw logits; cast to FP32 before the kernel call.
        t_max: Iteration count forwarded to the kernel.
        eps: Epsilon forwarded to the kernel.

    Returns:
        Whatever ``mhc_sinkhorn`` returns for the FP32 logits.
    """
    from dsv4.kernels.cuda.loader import get_cuda_module
    sinkhorn_mod = get_cuda_module("mhc_sinkhorn", ["mhc_sinkhorn.cu"])
    logits_f32 = logits.float()
    return sinkhorn_mod.mhc_sinkhorn(logits_f32, t_max, eps)
# ---------------------------------------------------------------------------
# Context carried between pre_block and post_block
# ---------------------------------------------------------------------------
@dataclass
class mHCContext:
    """Per-token mixing matrices computed in pre_block and consumed by post_block."""
    # (T, n_hc, n_hc) doubly stochastic residual transform
    B_l: torch.Tensor
    # (T, n_hc) output mapping (2*sigmoid)
    C_l: torch.Tensor
# ---------------------------------------------------------------------------
# mHC layer
# ---------------------------------------------------------------------------
class mHCLayer:
"""
Wraps one transformer sub-layer (attention *or* MoE) with the mHC
residual update.
Typical call pattern per layer:
x_in, ctx = mhc.pre_block(X_l)
F_out = transformer_sublayer(x_in) # (T, d)
X_next = mhc.post_block(X_l, F_out, ctx)
where X_l has shape (T, n_hc, d) — the expanded residual state.
The first call at layer 0 should use X_0 initialised via `init_state`.
"""
def __init__(
self,
hidden_dim: int = 7168,
n_hc: int = 4,
t_max_sinkhorn: int = 20,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.d = hidden_dim
self.n_hc = n_hc
self.K_proj = n_hc * hidden_dim # 28672 for V4-Pro
self.N_proj = n_hc + n_hc + n_hc * n_hc # 4 + 4 + 16 = 24
self.t_max = t_max_sinkhorn
self.device = device
self.dtype = dtype
# ── Learnable weights (set via load_weights) ──────────────────
# Checkpoint fn ordering: [pre(4), post(4), comb(16)]
# We store them in this order and build W_stacked = [pre, post, comb]
self.W_pre = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K)
self.W_post = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K)
self.W_comb = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32) # (16, K)
# Checkpoint base ordering: [pre(4), post(4), comb(16)]
self.S_pre = self._buf(1, n_hc) # (1, 4) — pre bias
self.S_post = self._buf(n_hc, 1) # (4, 1) — post bias
self.S_comb = self._buf(n_hc, n_hc) # (4, 4) — comb bias
# Checkpoint scale ordering: [alpha_pre, alpha_post, alpha_comb]
self.alpha_pre = torch.zeros(1, device=device, dtype=torch.float32)
self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32)
self.alpha_comb = torch.zeros(1, device=device, dtype=torch.float32)
# Pre-allocated split buffers (set in _ensure_buffers)
self._d_split = None # (NUM_SPLITS, max_T, N_proj) FP32
self._sqr_sum_split = None # (NUM_SPLITS, max_T) FP32
self._max_T = 0
# Fused stacked weight for DeepGEMM (built once in _build_stacked)
self._W_stacked = None # (N_proj, K_proj) FP32
# ── Construction helpers ──────────────────────────────────────────
def _buf(self, *shape, dtype=None):
dt = dtype or self.dtype
return torch.empty(*shape, dtype=dt, device=self.device)
def load_weights(
self,
W_pre: torch.Tensor, # (n_hc, K) FP32
W_post: torch.Tensor, # (n_hc, K) FP32
W_comb: torch.Tensor, # (n_hc², K) FP32
S_pre: torch.Tensor, # (1, n_hc)
S_post: torch.Tensor, # (n_hc, 1)
S_comb: torch.Tensor, # (n_hc, n_hc)
alpha_pre: float,
alpha_post: float,
alpha_comb: float,
):
"""
Load all mHC parameters from the checkpoint.
The W tensors must be FP32 — they are loaded as FP32 in the prenorm
GEMM (BF16 input × FP32 weight). Everything else can be BF16 in the
checkpoint and will be cast here.
"""
def _f32(t): return t.to(device=self.device, dtype=torch.float32).contiguous()
def _cvt(t): return t.to(device=self.device, dtype=self.dtype).contiguous()
self.W_pre = _f32(W_pre)
self.W_post = _f32(W_post)
self.W_comb = _f32(W_comb)
self.S_pre = _cvt(S_pre)
self.S_post = _cvt(S_post)
self.S_comb = _cvt(S_comb)
self.alpha_pre = torch.tensor(alpha_pre, dtype=torch.float32, device=self.device)
self.alpha_post = torch.tensor(alpha_post, dtype=torch.float32, device=self.device)
self.alpha_comb = torch.tensor(alpha_comb, dtype=torch.float32, device=self.device)
self._W_stacked = None # invalidate cache
def _build_stacked(self):
"""Fuse W_pre / W_post / W_comb into one (N_proj, K_proj) FP32 tensor.
Order: [pre(4), post(4), comb(16)] — matches checkpoint fn layout.
"""
self._W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb], dim=0)
# Must be K-major (contiguous along K) for DeepGEMM
self._W_stacked = self._W_stacked.contiguous()
def _ensure_buffers(self, T: int):
    """Grow the split-K scratch buffers so they can hold ``T`` tokens.

    Does nothing when the current capacity already covers ``T``; otherwise
    both buffers are re-created at exactly ``T`` tokens.
    """
    if self._max_T >= T:
        return
    fp32_on_device = dict(dtype=torch.float32, device=self.device)
    self._d_split = torch.empty(NUM_SPLITS, T, self.N_proj, **fp32_on_device)
    self._sqr_sum_split = torch.empty(NUM_SPLITS, T, **fp32_on_device)
    self._max_T = T
# ── Forward ──────────────────────────────────────────────────────
def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor:
    """
    Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj), returned in self.dtype.

    The unweighted RMSNorm is applied *after* the projection: the raw GEMM
    output of each token is multiplied by rsqrt(mean(x²) + eps).  Since that
    factor is one scalar per row, the result equals normalising the input
    first and projecting afterwards.

    With DeepGEMM available, ``tf32_hc_prenorm_gemm`` writes per-split partial
    products and partial squared sums, which are reduced here.  Otherwise the
    fallback up-casts the input and runs a plain FP32 matmul.

    X_flat: (T, K_proj) BF16
    """
    T = X_flat.shape[0]
    K = self.K_proj

    # Both paths need the fused weight; build it lazily, once.
    if self._W_stacked is None:
        self._build_stacked()

    if _HAS_DEEP_GEMM:
        self._ensure_buffers(T)
        d_s = self._d_split[:, :T, :]           # view, no copy
        ss_s = self._sqr_sum_split[:, :T]
        deep_gemm.tf32_hc_prenorm_gemm(
            X_flat.contiguous(),                # a
            self._W_stacked,                    # b        (N, K) FP32
            d_s,                                # d        (S, T, N)
            ss_s,                               # sqr_sum  (S, T)
            num_splits=NUM_SPLITS,
        )
        d_out = d_s.sum(dim=0)                  # (T, N)
        sqr_sum = ss_s.sum(dim=0)               # (T,)
    else:
        x_f32 = X_flat.float()
        d_out = x_f32 @ self._W_stacked.T       # (T, N)
        sqr_sum = x_f32.pow(2).sum(dim=-1)      # (T,)

    # RMSNorm scale: rsqrt(mean(x²) + eps).  eps is added to the *mean*, not
    # to the raw sum of squares (which would shrink it to eps / K).
    rms_scale = torch.rsqrt(sqr_sum / K + EPS_RMSN)            # (T,)
    return (d_out * rms_scale.unsqueeze(-1)).to(self.dtype)    # (T, N)
def _dynamic_params(
    self, X_l: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-token A_l, B_l, C_l from the current residual state.

    Intended to follow HuggingFace DeepseekV4HyperConnection.forward:
        1. UnweightedRMSNorm on the flattened residual
        2. F.linear(flat, fn) → split [pre, post, comb]
        3. pre  = sigmoid(pre_w * scale[0] + base[:4]) + eps
        4. post = 2 * sigmoid(post_w * scale[1] + base[4:8])
        5. comb = Sinkhorn(softmax(comb_w * scale[2] + base[8:]), iters)

    Steps 1 and 2 are both done by ``_project_and_rms``.  The norm is a
    per-token scalar, so scaling the projection output is the same as
    normalising the input; no separate input norm is applied here.

    X_l: (T, n_hc, d)

    Returns:
        A_l: (T, n_hc)         sigmoid-constrained input mapping (+ eps)
        B_l: (T, n_hc, n_hc)   output of sinkhorn_knopp (residual transform)
        C_l: (T, n_hc)         2*sigmoid-constrained output mapping
    """
    T, n, d = X_l.shape
    assert n == self.n_hc and d == self.d

    # Flatten the streams: (T, n_hc * d)
    X_flat = X_l.reshape(T, self.K_proj).to(self.dtype)

    # Fused RMSNorm + projection, then work in FP32 for the constraints.
    proj = self._project_and_rms(X_flat).float()        # (T, N_proj)

    # Split: [pre(n), post(n), comb(n*n)]
    pre_raw = proj[:, 0:n]                              # (T, n_hc)
    post_raw = proj[:, n:2 * n]                         # (T, n_hc)
    comb_raw = proj[:, 2 * n:2 * n + n * n]             # (T, n_hc²)

    # Apply scale and bias (raw * alpha + S)
    S_pre = self.S_pre.float()                          # (1, n_hc)
    S_post = self.S_post.float()                        # (n_hc, 1)
    S_comb = self.S_comb.float()                        # (n_hc, n_hc)
    pre_tilde = self.alpha_pre * pre_raw + S_pre                                 # (T, n_hc)
    post_tilde = self.alpha_post * post_raw + S_post.flatten().unsqueeze(0)      # (T, n_hc)
    comb_tilde = self.alpha_comb * comb_raw + S_comb.flatten().unsqueeze(0)      # (T, n_hc²)

    # Constraints
    A_l = torch.sigmoid(pre_tilde) + HC_EPS             # (T, n_hc)
    C_l = 2.0 * torch.sigmoid(post_tilde)               # (T, n_hc)
    # NOTE(review): softmax / eps handling is presumably inside
    # sinkhorn_knopp — confirm against its definition.
    comb_logits = comb_tilde.reshape(T, n, n)
    B_l = sinkhorn_knopp(comb_logits, t_max=self.t_max) # (T, n_hc, n_hc)

    # B_l is returned as produced by sinkhorn_knopp; A_l / C_l are cast.
    return A_l.to(self.dtype), B_l, C_l.to(self.dtype)
# ----------------------------------------------------------------
# Public API: pre_block / post_block
# ----------------------------------------------------------------
def pre_block(
    self,
    X_l: torch.Tensor,              # (T, n_hc, d) BF16
) -> Tuple[torch.Tensor, mHCContext]:
    """
    Derive the dynamic mixing parameters and collapse the streams into the
    sub-layer input.

    Returns:
        x_in: (T, d) — weighted sum of the n_hc streams, x_in = Σ_j A_l[j] · X_l[j]
        ctx:  mHCContext carrying B_l and C_l for the matching post_block call
    """
    mix_in, mix_res, mix_out = self._dynamic_params(X_l)

    # (T, 1, n_hc) @ (T, n_hc, d) → (T, 1, d); drop the singleton axis.
    # Original author's note: same as HF
    #   (pre.unsqueeze(-1) * hidden_streams).sum(dim=2)
    row_weights = mix_in.unsqueeze(1)
    collapsed = torch.bmm(row_weights, X_l)
    x_in = collapsed.squeeze(1)

    ctx = mHCContext(B_l=mix_res, C_l=mix_out)
    return x_in, ctx
def post_block(
    self,
    X_l: torch.Tensor,      # (T, n_hc, d) BF16 — residual state BEFORE sub-layer
    F_out: torch.Tensor,    # (T, d)       BF16 — sub-layer output
    ctx: mHCContext,
) -> torch.Tensor:
    """
    Apply the mHC residual update: X_next = C_l * F_out + B_l.T @ X_l.

    B_l is consumed TRANSPOSED; per the original author this mirrors the HF
    reference ``torch.matmul(comb.transpose(-1, -2), hidden_streams)``.

    This method performs no host-side reads of tensor values, so it does not
    force a device synchronisation.  (A previous ``.item()``-based blow-up
    diagnostic was removed: it synchronised on every call and its result was
    never used.)

    Returns:
        X_next: (T, n_hc, d) in self.dtype
    """
    # Residual mixing in FP32: (T, n_hc, n_hc)ᵀ @ (T, n_hc, d)
    BX = torch.bmm(ctx.B_l.transpose(-1, -2), X_l.float())
    # Broadcast the sub-layer output to every stream, scaled by C_l.
    CF = ctx.C_l.unsqueeze(-1) * F_out.unsqueeze(1)     # (T, n_hc, d)
    return (CF.float() + BX).to(self.dtype)             # (T, n_hc, d)
# ----------------------------------------------------------------
# Utility
# ----------------------------------------------------------------
@staticmethod
def init_state(
embeddings: torch.Tensor, # (T, d) BF16 — token embeddings
n_hc: int = 4,
) -> torch.Tensor:
"""
Initialise X_0 for the first layer.
Returns: (T, n_hc, d) BF16
"""
return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()
@staticmethod
def read_out(X_L: torch.Tensor) -> torch.Tensor:
"""
Extract the final hidden state from the last residual state.
Stream 0 is the primary output stream.
Returns: (T, d) BF16
"""
return X_L[:, 0, :]
# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: runs two mHC layers on random data and checks the
    # constraint ranges of A_l / B_l / C_l plus batch-vs-token equivalence.
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    D, N_HC = 7168, 4
    K = N_HC * D                          # 28672
    N_PROJ = N_HC + N_HC + N_HC ** 2      # 4 + 4 + 16 = 24
    mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype)

    # Random weights matching the expected shapes (fn ordering: pre, post, comb)
    mhc.load_weights(
        W_pre=torch.randn(N_HC, K, dtype=torch.float32),
        W_post=torch.randn(N_HC, K, dtype=torch.float32),
        W_comb=torch.randn(N_HC**2, K, dtype=torch.float32),
        S_pre=torch.zeros(1, N_HC, dtype=dtype),
        S_post=torch.zeros(N_HC, 1, dtype=dtype),
        S_comb=torch.eye(N_HC, dtype=dtype),     # identity: pure residual
        alpha_pre=0.01,
        alpha_post=0.01,
        alpha_comb=0.01,
    )
    T = 4   # 4 tokens

    # ── Forward pass ────────────────────────────────────────────────
    embeddings = torch.randn(T, D, dtype=dtype, device=device)
    X = mHCLayer.init_state(embeddings, n_hc=N_HC)
    print(f"X_0: {X.shape}  (T={T}, n_hc={N_HC}, d={D})")
    for layer_idx in range(2):
        x_in, ctx = mhc.pre_block(X)
        print(f"\nLayer {layer_idx}:")
        print(f"  x_in (to sub-layer): {x_in.shape}")
        print(f"  B_l: {ctx.B_l.shape}")
        print(f"  C_l: {ctx.C_l.shape}")
        # Identity sub-layer: feed the collapsed input straight back.
        F_out = x_in
        X = mhc.post_block(X, F_out, ctx)
        print(f"  X_next: {X.shape}")
    hidden = mHCLayer.read_out(X)
    print(f"\nFinal hidden: {hidden.shape}")

    # ── B_l is doubly stochastic check ──────────────────────────────
    print("\n=== Doubly stochastic check ===")
    B = ctx.B_l
    row_sums = B.sum(dim=-1)
    col_sums = B.sum(dim=-2)
    print(f"  row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}]  (want ≈ 1.0)")
    print(f"  col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}]  (want ≈ 1.0)")
    assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1"
    assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1"
    print("  PASSED")

    # ── A_l and C_l bounds ────────────────────────────────────────
    A_l, _, C_l = mhc._dynamic_params(X)
    print(f"\n=== A_l ∈ (eps, 1+eps) check ===")
    print(f"  A_l range: [{A_l.min():.4f}, {A_l.max():.4f}]  (want ∈ (eps, 1+eps))")
    # A_l has been rounded to the layer dtype, so allow one BF16 ulp of
    # slack (2**-7 near 1.0) on the upper bound.
    assert A_l.min() > 0 and A_l.float().max() <= 1 + HC_EPS + 2 ** -7, \
        "A_l out of sigmoid+eps range"
    print("  PASSED")
    print(f"\n=== C_l ∈ (0, 2) check ===")
    print(f"  C_l range: [{C_l.min():.4f}, {C_l.max():.4f}]  (want ∈ (0, 2))")
    assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range"
    print("  PASSED")

    # ── Equivalence: T=1 decode vs T=N prefill ──────────────────────
    print("\n=== Token-by-token decode == batch prefill ===")
    T_big = 8
    h_big = torch.randn(T_big, D, dtype=dtype, device=device)
    X_batch = mHCLayer.init_state(h_big, n_hc=N_HC)
    x_in_batch, ctx_batch = mhc.pre_block(X_batch)
    x_in_tokens = []
    for t in range(T_big):
        X_t = X_batch[t:t+1]
        x_in_t, _ = mhc.pre_block(X_t)
        x_in_tokens.append(x_in_t)
    x_in_seq = torch.cat(x_in_tokens, dim=0)
    diff = (x_in_batch - x_in_seq).abs().max().item()
    print(f"  max |batch - sequential| on x_in: {diff:.6f}")
    assert diff < 1e-2, f"Mismatch too large: {diff}"
    print("  PASSED")
    print("\nAll checks done.")
    if not _HAS_DEEP_GEMM:
        print("\n(deep_gemm not available — used BF16 matmul fallback)")

700
dsv4/_archive/layers/moe.py Normal file
View File

@@ -0,0 +1,700 @@
"""
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
CUDA-graph-compatible design:
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
- Extra slots (beyond real tokens) are zero and contribute nothing to output
- Fixed-shape tensors throughout the forward pass
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
During capture, num_tokens equals the budget — all shapes are fixed.
During replay, inputs are padded to the budget size. Our runner always
processes max_slots = budget * top_k rows; padding rows are zeros.
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_to_nvfp4,
quantize_nvfp4_gpu,
deinterleave_quantize_nvfp4_cuda,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_3d_side,
interleave_l1_weights,
deinterleave_l1_weights,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
warmup_fused_swiglu_compilation,
)
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm
class Nvfp4MoE:
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
"""
def __init__(self, num_experts, hidden_size, intermediate_size,
             max_num_tokens=8192, top_k=8, device="cuda",
             experts_start_idx=0):
    """Record the MoE geometry and declare every attribute used later.

    No GPU memory is allocated here; weights arrive through the
    ``prepare_weights_*`` methods and buffers are created from
    ``_ensure_stacked`` / ``_allocate_buffers``.

    Args:
        num_experts: number of experts handled by this runner.
        hidden_size: model hidden dimension.
        intermediate_size: per-expert FFN intermediate dimension.
        max_num_tokens: largest token count buffers are sized for.
        top_k: experts selected per token.
        device: device on which buffers will be allocated.
        experts_start_idx: offset subtracted from incoming expert ids to
            obtain local ids (see ``_run_impl``).
    """
    self.num_experts = num_experts
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = max_num_tokens
    self.top_k = top_k
    self.device = device
    self.experts_start_idx = experts_start_idx
    self._swiglu_limit = None       # Set via set_swiglu_limit()
    self._fused_swiglu = False      # Set via set_fused_swiglu()

    # Weight storage (set before _ensure_stacked)
    self.l1_fp4 = None
    self.l1_sf = None
    self.l1_gs = None
    self.l2_fp4 = None
    self.l2_sf = None
    self.l2_gs = None
    # Checkpoint-format (pre-stacked) weights; set by
    # prepare_weights_from_stacked().
    self.l1_fp4_stacked = None
    self.l1_sf_stacked = None
    self.l2_fp4_stacked = None
    self.l2_sf_stacked = None
    # Optional per-expert weight_scale_2 values folded into global_scale_b
    # by _ensure_stacked.  They are read unconditionally there, so they
    # must exist even when the loader never assigns them.
    self.l1_ws2 = None
    self.l2_ws2 = None

    # Stacked weight tensors (set in _ensure_stacked)
    self._l1_mat_b = None
    self._l2_mat_b = None
    self._l1_scale_b = None
    self._l2_scale_b = None
    self._l1_gsb = None
    self._l2_gsb = None

    # Default: 1/2688 ≈ 0.000372  (amax=1 → gs=1/2688)
    # Overridden later with a checkpoint input_scale or a warmup value.
    self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
    self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)

    # Pre-allocated cudagraph buffers (set in _allocate_buffers)
    self._token_indices = None
    self._expert_offsets_buf = None
    self._per_expert_scale_bufs_l1 = None
    self._per_expert_scale_bufs_l2 = None
    self._padded_x_sf_buf_l1 = None
    self._padded_x_sf_buf_l2 = None
    self._l1_gsa_buf = None
    self._l2_gsa_buf = None
    self._output_buf = None
    self._row_indices_buf = None
    self._padded_hidden_buf = None
    self._padded_activated_buf = None   # unused, using shared
    self._padded_expert_offsets_buf = None

    # NOTE(review): this sizes the padded buffers for a perfectly even
    # spread of slots over experts.  _run_impl pads every expert up to a
    # multiple of 128 rows, so with an uneven spread at (or near)
    # max_num_tokens the padded total can exceed
    # num_experts * _max_chunks_per_expert * 128 — confirm the callers'
    # token budget leaves headroom, or add one extra chunk per expert.
    self._max_chunks_per_expert = cutedsl_ceil_div(
        self.max_num_tokens * self.top_k, self.num_experts * 128
    )
    self._buffers_allocated = False
def set_swiglu_limit(self, limit: float | None):
"""Set the swiglu_limit for activation clamping."""
self._swiglu_limit = limit
def set_fused_swiglu(self, enabled: bool):
    """Select the fused L1-GEMM + SwiGLU kernel path (``True``) or the
    separate GEMM followed by PyTorch SwiGLU (``False``)."""
    self._fused_swiglu = enabled
def _fill_token_indices(self):
    """Populate ``_token_indices`` with each token id repeated ``top_k``
    times: [0,0,..,0, 1,1,..,1, ...].

    The pattern is built in host memory and then copied into the existing
    buffer in place.
    """
    token_ids = torch.arange(self.max_num_tokens, dtype=torch.int32)
    self._token_indices.copy_(token_ids.repeat_interleave(self.top_k))
def _allocate_buffers(self):
    """Create every fixed-size buffer the forward pass writes into.

    Large padded buffers (activation scales, FP4 activations, BF16
    hidden/activated, output) live in a class-level dict keyed by
    ``str(self.device)`` and are allocated only by the first runner on that
    device; later runners reuse them.  Small per-runner buffers are always
    allocated fresh.
    """
    dev = self.device
    n_experts = self.num_experts

    def _fp8_zeros(rows, cols):
        # Zeros are created as FP16 and then converted to FP8 (e4m3).
        return torch.zeros(rows, cols, dtype=torch.float16, device=dev).to(torch.float8_e4m3fn)

    def _fp4_zeros(rows, packed_cols):
        # Two FP4 values per byte: allocate bytes, reinterpret as packed FP4.
        return torch.zeros(rows, packed_cols, dtype=torch.uint8, device=dev).view(
            torch.float4_e2m1fn_x2
        )

    # Scale-factor column counts (one scale per 16 elements, padded to 4).
    # L1 and L2 differ because their K dimensions differ.
    sf_cols_l1 = cutedsl_ceil_div(cutedsl_ceil_div(self.hidden_size, 16), 4) * 4
    sf_cols_l2 = cutedsl_ceil_div(cutedsl_ceil_div(self.intermediate_size, 16), 4) * 4

    self._per_expert_scale_bufs_l1 = [_fp8_zeros(128, sf_cols_l1) for _ in range(n_experts)]
    self._per_expert_scale_bufs_l2 = [_fp8_zeros(128, sf_cols_l2) for _ in range(n_experts)]

    # Class-level registry of shared buffers, one sub-dict per device.
    if not hasattr(Nvfp4MoE, '_shared_padded_bufs'):
        Nvfp4MoE._shared_padded_bufs = {}
    shared = Nvfp4MoE._shared_padded_bufs.setdefault(str(dev), {})

    rows_per_expert = self._max_chunks_per_expert * 128
    padded_rows = n_experts * rows_per_expert

    # Shared activation-scale staging buffers and the output buffer.
    if 'xsf_l1' not in shared:
        shared.update({
            'xsf_l1': _fp8_zeros(padded_rows, sf_cols_l1),
            'xsf_l2': _fp8_zeros(padded_rows, sf_cols_l2),
            'output': torch.zeros(
                self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=dev
            ),
        })
    self._padded_x_sf_buf_l1 = shared['xsf_l1']
    self._padded_x_sf_buf_l2 = shared['xsf_l2']
    self._output_buf = shared['output']

    # Per-expert global_scale_a values (one FP32 per expert, per layer).
    self._l1_gsa_buf = torch.zeros(n_experts, dtype=torch.float32, device=dev)
    self._l2_gsa_buf = torch.zeros(n_experts, dtype=torch.float32, device=dev)

    # 0..max_slots-1, sliced at run time to index the sorted slots.
    self._row_indices_buf = torch.arange(self.max_num_tokens * self.top_k, device=dev)

    # Shared padded activation buffers (BF16 and packed FP4).
    if 'hidden' not in shared:
        shared.update({
            'hidden': torch.zeros(
                padded_rows, self.hidden_size, dtype=torch.bfloat16, device=dev
            ),
            'hidden_fp4': _fp4_zeros(padded_rows, self.hidden_size // 2),
            'activated': torch.zeros(
                padded_rows, self.intermediate_size, dtype=torch.bfloat16, device=dev
            ),
            'activated_fp4': _fp4_zeros(padded_rows, self.intermediate_size // 2),
        })
    self._shared_bufs = shared

    # Initial padded offsets: [0, rows, 2*rows, ..., E*rows] as int32.
    self._padded_expert_offsets_buf = (
        torch.arange(n_experts + 1, dtype=torch.int32, device=dev) * rows_per_expert
    )
    self._buffers_allocated = True
def _ensure_stacked(self):
    """Convert stored weights to kernel layout, warm up the GEMMs and
    allocate run-time buffers.  Runs once; later calls return immediately.

    Two input forms are handled:

    * Fast path — pre-stacked checkpoint tensors set by
      ``prepare_weights_from_stacked`` ((E, N, K) weights, (E, N, K_sf) scales).
    * Legacy path — per-expert lists set by ``prepare_weights_direct`` or
      ``prepare_weights_from_dequantized``.

    In both cases the L1 gate/up weights and their scales are interleaved
    along N, weights are made K-major, and the source tensors are released.
    ``l1_ws2`` / ``l2_ws2`` (optional per-expert weight_scale_2) are folded
    into the global B scales when present.
    """
    if self._l1_mat_b is not None:
        return

    # Convert weights to kernel format
    if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
        # Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K).
        # Permute to (E, K, N) then make K-major.
        l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
        l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
        # Interleave L1 gate/up weights.  Per the original author this
        # pairs gate/up within the MMA accumulator so SwiGLU can be fused
        # without runtime conditionals.
        l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn)
        # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
        if l1_fp4_ekn.dtype == torch.uint8:
            l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2)
        if l2_fp4_ekn.dtype == torch.uint8:
            l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2)
        # Free stacked checkpoints before make_b_k_major (saves one copy)
        self.l1_fp4_stacked = None
        self.l2_fp4_stacked = None
        torch.cuda.empty_cache()
        self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
        self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
        del l1_fp4_ekn, l2_fp4_ekn
        torch.cuda.empty_cache()

        # Scales: checkpoint is (E, N, K_sf); split into per-expert
        # (N, K_sf) views (no copy), then assemble.
        l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
        l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
        self.l1_sf_stacked = None
        self.l2_sf_stacked = None
        torch.cuda.empty_cache()

        # Interleave L1 SF along N to match the interleaved weight layout.
        # interleave_l1_weights works on the last dim, so transpose to
        # (K_sf, N), interleave, and transpose back to (N, K_sf).
        l1_sf_il = []
        for sf_nk in l1_sf_list:
            sf_kn = sf_nk.T.contiguous().unsqueeze(0)       # (1, K_sf, N)
            sf_kn = interleave_l1_weights(sf_kn)            # interleaved along N
            l1_sf_il.append(sf_kn[0].T.contiguous())        # (N, K_sf)
        del l1_sf_list
        l1_sf_list = l1_sf_il

        # Scales are already (N, K_sf), so use the raw assembler rather
        # than assemble_scales_3d_side (which expects (K_sf, N) input).
        from dsv4.ops.layouts import (
            assemble_raw_scales_2d3d_3d_side,
        )
        self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
        self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
        del l1_sf_list, l2_sf_list
    else:
        # Legacy path: per-expert lists
        l1_stacked = torch.stack(self.l1_fp4)               # (E, K, N)
        l1_stacked = interleave_l1_weights(l1_stacked)      # interleave gate/up
        if l1_stacked.dtype == torch.uint8:
            l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2)
        l2_stacked = torch.stack(self.l2_fp4)
        if l2_stacked.dtype == torch.uint8:
            l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2)
        self._l1_mat_b = make_b_k_major(l1_stacked)
        self._l2_mat_b = make_b_k_major(l2_stacked)

        # Interleave L1 SF to match the weight interleave.  These scales
        # are (K_sf, N); assemble_scales_3d_side handles the transpose.
        l1_sf_il = []
        for sf in self.l1_sf:
            sf_ekn = sf.unsqueeze(0)                        # (1, K_sf, N)
            sf_ekn = interleave_l1_weights(sf_ekn)          # interleaved along N
            l1_sf_il.append(sf_ekn[0])                      # (K_sf, N)
        self._l1_scale_b = assemble_scales_3d_side(l1_sf_il)
        self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
        del l1_stacked, l1_sf_il

    # Per-expert source lists are no longer needed.
    self.l1_fp4 = None
    self.l1_sf = None
    self.l2_fp4 = None
    self.l2_sf = None

    self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
    self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)

    # Fold weight_scale_2 into global_scale_b (gsb *= weight_scale_2).
    # The attributes are optional: nothing in this class assigns them, so
    # read them defensively instead of raising AttributeError.
    l1_ws2 = getattr(self, 'l1_ws2', None)
    if l1_ws2 is not None:
        for i, ws2 in enumerate(l1_ws2):
            if ws2 is not None:
                self._l1_gsb[i] *= ws2.float().item()
    l2_ws2 = getattr(self, 'l2_ws2', None)
    if l2_ws2 is not None:
        for i, ws2 in enumerate(l2_ws2):
            if ws2 is not None:
                self._l2_gsb[i] *= ws2.float().item()
    self.l1_gs = None
    self.l2_gs = None
    self.l1_ws2 = None
    self.l2_ws2 = None

    # Slot → token map, filled once.
    self._token_indices = torch.zeros(
        self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
    )
    self._fill_token_indices()

    # Eagerly JIT-compile the GEMM kernels so compilation happens here
    # rather than lazily inside the model's first forward pass.
    from dsv4.ops.gemm_runner import (
        warmup_compilation,
        warmup_fused_swiglu_compilation,
    )
    K_packed = self.hidden_size // 2
    N_packed_l1 = (2 * self.intermediate_size) // 2     # gate+up combined
    N_packed_l2 = self.hidden_size // 2                 # down
    warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device)    # L1
    # NOTE(review): the L2 GEMM reduces over intermediate_size, yet this
    # warm-up passes the L1 K (hidden_size // 2).  Confirm against
    # warmup_compilation's expected arguments.
    warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device)    # L2
    if self._fused_swiglu:
        warmup_fused_swiglu_compilation(
            self.num_experts, K_packed, N_packed_l1, self.device,
            swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
        )   # Fused L1

    self._expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=self.device
    )
    self._allocate_buffers()
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
    """DEPRECATED: prefer prepare_weights_from_stacked() for checkpoint weights.

    Stores already-quantized per-expert lists as given and clears the cached
    kernel-layout weight so ``_ensure_stacked`` rebuilds it.
    """
    self.l1_fp4, self.l1_sf, self.l1_gs = l1_fp4, l1_sf, l1_gs
    self.l2_fp4, self.l2_sf, self.l2_gs = l2_fp4, l2_sf, l2_gs
    # Force re-conversion on the next _ensure_stacked() call.
    self._l1_mat_b = None
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
                                 l1_gs, l2_fp4_stacked, l2_sf_stacked,
                                 l2_gs):
    """Register pre-stacked checkpoint tensors for deferred conversion.

    Expects (E, N, K_packed) FP4 weights and (E, N, K_sf) scales.  Nothing is
    converted here: the tensors are only stored, and ``_ensure_stacked``
    later turns them into the K-major / swizzled kernel layout.
    """
    self.l1_fp4_stacked, self.l1_sf_stacked, self.l1_gs = (
        l1_fp4_stacked, l1_sf_stacked, l1_gs,
    )
    self.l2_fp4_stacked, self.l2_sf_stacked, self.l2_gs = (
        l2_fp4_stacked, l2_sf_stacked, l2_gs,
    )
    # Force re-conversion on the next _ensure_stacked() call.
    self._l1_mat_b = None
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
    """DEPRECATED: prefer prepare_weights_from_stacked().

    Quantizes each expert's transposed L1 and L2 weight with
    ``quantize_weight_to_nvfp4`` and collects the resulting
    (fp4, scale-factor, global-scale) triples into the per-expert lists.
    Per the original author the dequantize→requantize round trip is lossless
    for DeepSeek-V4 but wastes memory and compute compared with the direct
    byte path.
    """
    self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
    self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
    destinations = (
        (self.l1_fp4, self.l1_sf, self.l1_gs),
        (self.l2_fp4, self.l2_sf, self.l2_gs),
    )
    for expert_pair in zip(l1_weights_bf16, l2_weights_bf16):
        for (fp4_out, sf_out, gs_out), weight in zip(destinations, expert_pair):
            packed, scales, global_scale = quantize_weight_to_nvfp4(weight.T)
            fp4_out.append(packed)
            sf_out.append(scales)
            gs_out.append(global_scale)
    # Force re-conversion on the next _ensure_stacked() call.
    self._l1_mat_b = None
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
padded_expert_offsets,
padded_x_sf_buf, per_expert_bufs):
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
The buffer is 128-row aligned per expert (from padded_expert_offsets),
so the full-buffer swizzle produces the correct layout. The GEMM reads
scale_a using padded_expert_offsets, matching the scatter layout.
"""
K_sf = x_sf.shape[1]
padded_x_sf = padded_x_sf_buf
padded_x_sf.zero_()
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
total_rows = x_sf.shape[0]
row_indices = self._row_indices_buf[:total_rows]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
dst_rows = padded_expert_offsets[expert_assign] + local_row
padded_x_sf[dst_rows, :K_sf] = x_sf
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
rows = padded_x_sf.shape[0]
cols = padded_x_sf.shape[1]
R = rows // 128
C = cols // 4
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
return swizzled.reshape(rows, cols)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
    """Derive the L1 and L2 activation global scales from a warmup batch.

    Intended to be called before cudagraph capture.  It follows the same
    padded GEMM path as ``_run_impl`` so that L2's scale is measured on the
    real L1 output, and stores the results in
    ``_l1_activation_global_scale`` / ``_l2_activation_global_scale``.

    Expert ids are remapped exactly as in ``_run_impl``
    (``topk_ids - experts_start_idx``, clamped to the local range) so that
    the per-expert counts and offsets stay consistent when
    ``experts_start_idx`` is non-zero.  With ``experts_start_idx == 0`` and
    in-range ids the remap is a no-op.

    ``topk_weights`` is accepted for signature parity with ``run`` and is
    not used: only the magnitudes of the activations matter here.
    """
    self._ensure_stacked()
    device = hidden_states_sample.device
    num_tokens = hidden_states_sample.shape[0]
    top_k = topk_ids.shape[1]
    with torch.no_grad():
        # Global → local expert ids (same as _run_impl).
        local_ids = topk_ids - self.experts_start_idx
        flat_ids = local_ids.clamp(0, self.num_experts - 1).reshape(-1)

        # Build slot mapping (same as _run_impl)
        num_slots = num_tokens * top_k
        token_indices = self._token_indices[:num_slots]
        sort_idx = flat_ids.argsort(stable=True)
        sorted_ids = flat_ids[sort_idx]
        sorted_token_ids = token_indices[sort_idx]
        slot_hidden = hidden_states_sample[sorted_token_ids]

        # L1: global scale from the gathered activations, then quantize.
        _, _, l1_gs = quantize_to_nvfp4(slot_hidden)
        slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)

        tokens_per_expert = torch.bincount(
            sorted_ids, minlength=self.num_experts
        )[:self.num_experts].int()
        expert_offsets = self._expert_offsets_buf
        expert_offsets.zero_()
        expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

        # Each expert's rows are padded up to a multiple of 128.
        padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
        padded_expert_offsets = self._padded_expert_offsets_buf
        padded_expert_offsets.zero_()
        padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)

        # Destination row of every slot in the padded layout.
        row_indices = self._row_indices_buf[:num_slots]
        expert_assign = torch.searchsorted(
            expert_offsets[1:], row_indices, right=True
        ).clamp(max=self.num_experts - 1)
        local_row = row_indices - expert_offsets[expert_assign]
        padded_dst = padded_expert_offsets[expert_assign] + local_row

        # Scatter x_fp4 into the padded layout (as bytes).
        padded_x_fp4 = self._shared_bufs['hidden_fp4']
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
        l1_scale_a = self._assemble_scales_cudagraph_safe(
            slot_x_sf, expert_offsets[:self.num_experts + 1],
            padded_expert_offsets,
            self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
        )
        l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)
        l1_out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
            scale_a=l1_scale_a, scale_b=self._l1_scale_b,
            expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
            global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
        )
        # Keep only the rows that correspond to real slots.
        l1_out_real = l1_out[padded_dst]

        # L2: global scale from SiLU(gate) * up.  The L1 weights are
        # interleaved, so undo the interleave before splitting gate / up.
        l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
        gate = l1_deil[:, :self.intermediate_size]
        up = l1_deil[:, self.intermediate_size:]
        gate_silu = torch.nn.functional.silu(gate)
        if self._swiglu_limit is not None:
            gate_silu = gate_silu.clamp(max=self._swiglu_limit)
            up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
        activated = gate_silu * up
        _, _, l2_gs = quantize_to_nvfp4(activated)

    self._l1_activation_global_scale = l1_gs
    self._l2_activation_global_scale = l2_gs
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Forward pass entry point: dispatch to the ``nvfp4_moe_gemm`` custom op.

    On first use the runner registers itself via ``register_runner`` and
    caches the returned id in ``_runner_id``; the op receives that id and
    ``hidden_size`` alongside the inputs.  Per the original author, the op
    is opaque to torch.compile and calls ``_run_impl`` internally.
    ``expert_indices`` is accepted but not forwarded.
    """
    try:
        runner_id = self._runner_id
    except AttributeError:
        runner_id = self._runner_id = register_runner(self)
    return nvfp4_moe_gemm(
        hidden_states, topk_weights, topk_ids,
        runner_id, self.hidden_size,
    )
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Run the NVFP4 MoE forward pass.
Handles global→local expert ID remapping for expert parallelism.
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
Each expert's slots are padded to multiples of 128 for the GEMM.
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
scale_a is produced at those same offsets.
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
device = hidden_states.device
self._ensure_stacked()
# -- Remap global expert IDs to local IDs --
local_ids = topk_ids - self.experts_start_idx
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
safe_ids = local_ids.clamp(0, self.num_experts - 1)
safe_weights = topk_weights * local_mask.float()
# -- Build slot mapping --
flat_ids = safe_ids.reshape(-1)
flat_weights = safe_weights.reshape(-1)
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx]
# Expert offsets (real token counts)
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
# Pad each expert to 128-row alignment (GPU-only computation)
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = self._padded_expert_offsets_buf
padded_expert_offsets.zero_()
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
total_padded_slots = padded_expert_offsets[self.num_experts]
# -- Gather hidden states into slot order, compute padded_dst --
slot_hidden = hidden_states[sorted_token_ids]
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
# === L1: gate + up ===
# Fused amax + quantize: single kernel, zero CPU-GPU syncs.
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
# gsa written to GPU buffer for GEMM global_scale_a.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
l1_scale_a = self._assemble_scales_cudagraph_safe(
slot_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed)
if self._fused_swiglu:
# === Fused L1 GEMM + SwiGLU in kernel registers ===
l1_out = run_fused_swiglu_grouped_gemm(
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
)
l1_out_real = l1_out[padded_dst]
# Fused deinterleave + amax + quantize: zero CPU syncs.
# Computes gsa from de-interleaved SwiGLU output on GPU,
# quantizes in the same kernel. Writes gsa to GPU buffer.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
l1_out_real, self.intermediate_size)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
)
else:
# === Non-fused L1 GEMM + PyTorch SiLU(gate)*up ===
l1_out = run_nvfp4_grouped_gemm(
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
l1_out_real = l1_out[padded_dst]
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
gate = l1_deil[:, :self.intermediate_size]
up = l1_deil[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
if self._swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
activated = gate_silu * up
# Compute runtime gsa for L2 from activated output (non-fused path)
# Fused amax + quantize: zero CPU syncs.
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
elif not self._fused_swiglu:
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
activated, self._l2_activation_global_scale
)
padded_activated_fp4 = self._shared_bufs['activated_fp4']
padded_activated_fp4.view(torch.uint8).zero_()
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
l2_scale_a = self._assemble_scales_cudagraph_safe(
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed)
l2_out = run_nvfp4_grouped_gemm(
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
l2_out_real = l2_out[padded_dst]
# === Scatter -> final output ===
y = self._output_buf[:num_tokens]
y.zero_()
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
y.scatter_add_(
0,
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
weighted_out,
)
return y

View File

@@ -0,0 +1,345 @@
"""DSV4 Router — token-to-expert assignment.
Two routing modes that share an output shape:
- 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
Used by MoE layers 3+ (the bulk of the network).
- 'hash': deterministic per-token-ID lookup, uniform weights.
Used by the first 3 MoE layers per DSV4 §2.1.
Both modes produce (topk_weights, topk_ids) suitable for direct
consumption by Nvfp4MoE.run().
CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs.
Selection between modes is by layer_idx at construction time —
the kernel path is fixed once the Router is built so the dispatch
is constant-folded by torch.compile.
"""
from __future__ import annotations
from typing import Optional, Literal
import torch
from dsv4.ops.router import (
register_router,
dense_router_op,
hash_router_op,
)
# Routing strategies supported by Router; fixed at construction time.
RouterMode = Literal["dense", "hash"]


class Router:
    """DSV4 expert router.

    Per the DeepSeek-V4 paper (§2.1):
      - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
      - Auxiliary-loss-free strategy: a learned per-expert bias (loaded
        from checkpoint, frozen at inference) is added to the activation
        for SELECTION only. The actual gating weight applied to expert
        outputs uses the UNBIASED activation.
      - First 3 MoE layers use Hash routing (Roller et al. 2021): a
        precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
        No gate GEMM is performed.
      - Sequence-wise balance loss is training-only; not applied here.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension. Must match W_gate's K dimension.
    num_experts : int
        Total routed experts (Flash: 256, Pro: 384). Shared experts are
        handled separately by Nvfp4SharedExpert.
    top_k : int
        Experts activated per token. DSV4 uses 6.
    routed_scaling_factor : float
        Post-renormalization scale on gating weights. DSV3 used 2.5;
        verify against the V4 checkpoint config — may be per-layer.
    mode : {'dense', 'hash'}
        Routing strategy. Decided at construction; cannot change at runtime.
    vocab_size : int, optional
        Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
    max_num_tokens : int
        Upper bound on N for pre-allocated buffer sizing.
    device : str
        CUDA device.

    Raises
    ------
    ValueError
        If ``mode`` is not 'dense'/'hash', or mode='hash' without vocab_size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 6,
        routed_scaling_factor: float = 2.5,
        *,
        mode: RouterMode,
        vocab_size: Optional[int] = None,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        if mode not in ("dense", "hash"):
            raise ValueError(f"unknown router mode: {mode!r}")
        if mode == "hash" and vocab_size is None:
            raise ValueError("vocab_size is required when mode='hash'")

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor
        self.mode = mode
        self.vocab_size = vocab_size
        self.max_num_tokens = max_num_tokens
        self.device = device

        # ---- Parameters (filled by load_* / finalize_weights) ----
        # Dense mode, raw NVFP4 gate tensors (stored by load_nvfp4_fused_gate):
        #   gate_weight:       raw NVFP4 gate weight tensor
        #   gate_weight_scale: weight scale
        #   gate_ws2:          weight_scale_2 (global scale base)
        #   gate_input_scale:  input_scale (activation global scale base)
        # Dense mode, NVFP4 GEMM path:
        #   gate_lin: Nvfp4Linear for the gate projection
        # Dense mode, BF16 fallback:
        #   W_gate: BF16 weight used when gate_lin is not available
        # Hash mode:
        #   hash_lut: [vocab_size, top_k] int32 precomputed expert IDs
        self.gate_weight = None
        self.gate_weight_scale = None
        self.gate_ws2 = None
        self.gate_input_scale = None
        self.gate_lin = None
        self.W_gate: Optional[torch.Tensor] = None
        self.e_bias: Optional[torch.Tensor] = None
        self.hash_lut: Optional[torch.Tensor] = None

        # ---- Pre-allocated output buffers (cudagraph-safe) ----
        self._topk_weights_buf: Optional[torch.Tensor] = None
        self._topk_ids_buf: Optional[torch.Tensor] = None

        # Runner ID assigned on first call (see custom_op pattern).
        self._runner_id: Optional[int] = None

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------
    def load_weights(
        self,
        W_gate: Optional[torch.Tensor] = None,
        e_bias: Optional[torch.Tensor] = None,
        hash_lut: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate router parameters from a checkpoint shard.

        Dense mode requires ``e_bias`` and optionally takes ``W_gate`` (the
        BF16 fallback weight). Hash mode requires ``hash_lut``. Arguments
        that belong to the other mode are ignored.

        Raises
        ------
        ValueError
            If the argument required by ``self.mode`` is missing.
        AssertionError
            If ``e_bias`` / ``hash_lut`` has the wrong shape, or the LUT
            contains expert IDs outside ``[0, num_experts)``.
        """
        if self.mode == "dense":
            if e_bias is None:
                raise ValueError("dense router needs e_bias")
            assert e_bias.shape == (self.num_experts,), \
                f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
            self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
            if W_gate is not None:
                self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
            # gate_lin is set separately via load_nvfp4_gate()
            return

        # hash mode
        if hash_lut is None:
            raise ValueError("hash router needs hash_lut")
        assert hash_lut.shape == (self.vocab_size, self.top_k), \
            f"hash_lut shape {tuple(hash_lut.shape)} != " \
            f"{(self.vocab_size, self.top_k)}"
        assert (hash_lut >= 0).all() and (hash_lut < self.num_experts).all(), \
            "hash_lut contains out-of-range expert IDs"
        self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)

    def load_nvfp4_gate(self, gate_lin) -> None:
        """Attach an already-built NVFP4 gate linear layer.

        When set, ``_run_dense_impl`` uses the NVFP4 GEMM path instead of
        the BF16 fallback.
        """
        self.gate_lin = gate_lin

    @staticmethod
    def _scalar_of(t: torch.Tensor) -> float:
        """Collapse a scale tensor to a Python float (mean if multi-element)."""
        t = t.float()
        return t.item() if t.numel() == 1 else t.mean().item()

    def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale,
                              gate_ws2, gate_input_scale,
                              gate_weight_bf16=None) -> None:
        """Store raw NVFP4 gate tensors and optionally build the Nvfp4Linear.

        The raw tensors are always moved to ``self.device`` and stored. If
        ``gate_weight_bf16`` is given, it is re-quantized with
        ``quantize_to_nvfp4`` and wrapped in an ``Nvfp4Linear`` that becomes
        ``self.gate_lin``.

        Raises
        ------
        ValueError
            If ``gate_weight_bf16`` is given but ``gate_ws2`` is None; the
            Nvfp4Linear construction needs a weight_scale_2 value.
        """
        # Validate before mutating any state. Previously a None gate_ws2 was
        # accepted for storage and then crashed with AttributeError below.
        if gate_weight_bf16 is not None and gate_ws2 is None:
            raise ValueError(
                "gate_ws2 is required when gate_weight_bf16 is provided"
            )

        self.gate_weight = gate_weight.to(device=self.device)
        self.gate_weight_scale = gate_weight_scale.to(device=self.device)
        self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None
        self.gate_input_scale = gate_input_scale.to(self.device)

        if gate_weight_bf16 is None:
            return

        # Create Nvfp4Linear from BF16 weight (handles layout correctly)
        from dsv4.layers.linear import Nvfp4Linear
        from dsv4.ops.quantize import quantize_to_nvfp4

        E = gate_weight_bf16.shape[0]
        gate_lin = Nvfp4Linear(in_features=self.hidden_size, out_features=E, device=self.device)
        g_fp4, g_sf, g_gs = quantize_to_nvfp4(gate_weight_bf16.bfloat16().to(self.device))
        gate_lin.fp4 = [g_fp4]
        gate_lin.sf = [g_sf]
        gate_lin.gs = [g_gs]
        ws2_val = self._scalar_of(gate_ws2)
        gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)]
        gate_lin._activation_global_scale = self._scalar_of(gate_input_scale)
        gate_lin._use_runtime_gsa = True  # compute gsa from actual input to avoid E4M3 overflow
        gate_lin.finalize_weights()
        self.gate_lin = gate_lin

    def finalize_weights(self) -> None:
        """Allocate output buffers and JIT-compile the routing kernel.

        One-time setup step called after all parameters are loaded.
        Triggers kernel compilation so the first forward isn't paying
        that cost.
        """
        self._topk_weights_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.float32, device=self.device,
        )
        self._topk_ids_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.int32, device=self.device,
        )
        # Eager JIT — dispatcher knows our mode and triggers the right
        # kernel's compile path. See dsv4/ops/router.py.
        from dsv4.ops.router import warmup_router_compilation
        warmup_router_compilation(self)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def __call__(
        self,
        hidden_states: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.

        Parameters
        ----------
        hidden_states : Tensor [N, hidden_size] bfloat16
            Required for dense mode. Ignored for hash mode (kept in the
            signature so the call site is mode-agnostic).
        token_ids : Tensor [N] int32, optional
            Required for hash mode. Ignored for dense mode.

        Returns
        -------
        topk_weights : Tensor [N, top_k] float32
        topk_ids : Tensor [N, top_k] int32

        Raises
        ------
        RuntimeError
            If finalize_weights() has not been called.
        ValueError
            If the input required by the current mode is None.

        Notes
        -----
        Both outputs are views into pre-allocated buffers — do not retain
        them across router calls.
        """
        if self._topk_weights_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")
        if self.mode == "dense":
            if hidden_states is None:
                raise ValueError("dense router requires hidden_states")
            return self._run_dense(hidden_states)
        if token_ids is None:
            raise ValueError("hash router requires token_ids")
        return self._run_hash(token_ids)

    # ------------------------------------------------------------------
    # Mode-specific dispatch — each routes through a torch.library.custom_op
    # so Dynamo / torch.compile treats the kernel as opaque.
    # ------------------------------------------------------------------
    def _ensure_registered(self) -> int:
        """Register this router with the custom-op registry on first use."""
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return self._runner_id

    def _run_dense(self, hidden_states: torch.Tensor):
        """Dispatch dense routing through the opaque custom op."""
        return dense_router_op(
            hidden_states,
            self._ensure_registered(),
            self.num_experts,
            self.top_k,
        )

    def _run_hash(self, token_ids: torch.Tensor):
        """Dispatch hash routing through the opaque custom op."""
        return hash_router_op(
            token_ids,
            self._ensure_registered(),
            self.top_k,
        )

    # ------------------------------------------------------------------
    # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
    # ------------------------------------------------------------------
    def _output_views(self, num_tokens: int):
        """Return the first ``num_tokens`` rows of both output buffers.

        Slicing past the end of a tensor silently truncates, which would
        hand the kernel an output smaller than its input. Check the
        capacity explicitly (host-side shape check only, no GPU sync).
        """
        if self._topk_weights_buf is None or self._topk_ids_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")
        capacity = self._topk_weights_buf.shape[0]
        if num_tokens > capacity:
            raise ValueError(
                f"router received {num_tokens} tokens but output buffers "
                f"hold at most {capacity} (max_num_tokens)"
            )
        return self._topk_weights_buf[:num_tokens], self._topk_ids_buf[:num_tokens]

    def _run_dense_impl(self, hidden_states: torch.Tensor):
        """Hot path for dense routing.

        Priority:
          1. NVFP4 GEMM path via ``self.gate_lin`` when it is set.
          2. BF16 fallback via ``self.W_gate`` otherwise. This is also what
             runs when only the raw NVFP4 tensors (``gate_weight`` etc.)
             were loaded; no kernel in this class consumes them directly.

        Raises
        ------
        ValueError
            If the batch exceeds the pre-allocated buffer capacity.
        RuntimeError
            If neither ``gate_lin`` nor ``W_gate`` has been loaded.
        """
        N = hidden_states.shape[0]
        out_w, out_ids = self._output_views(N)

        if self.gate_lin is not None:
            from dsv4.kernels.router import dense_router_dispatch_nvfp4
            dense_router_dispatch_nvfp4(
                hidden_states=hidden_states,
                gate_lin=self.gate_lin,
                e_bias=self.e_bias,
                routed_scaling_factor=self.routed_scaling_factor,
                top_k=self.top_k,
                out_weights=out_w,
                out_ids=out_ids,
            )
            return out_w, out_ids

        if self.W_gate is None:
            raise RuntimeError(
                "dense router has no gate weights: call load_weights(W_gate=...), "
                "load_nvfp4_gate() or load_nvfp4_fused_gate(gate_weight_bf16=...)"
            )

        from dsv4.kernels.router import dense_router_dispatch
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
        return out_w, out_ids

    def _run_hash_impl(self, token_ids: torch.Tensor):
        """Hot-path entry into the hash gather kernel.

        Delegates to ``hash_router_dispatch`` from ``dsv4.kernels.router``,
        which writes into the pre-allocated output views.

        Raises
        ------
        ValueError
            If the batch exceeds the pre-allocated buffer capacity.
        """
        N = token_ids.shape[0]
        out_w, out_ids = self._output_views(N)

        from dsv4.kernels.router import hash_router_dispatch
        hash_router_dispatch(
            token_ids=token_ids,
            hash_lut=self.hash_lut,
            top_k=self.top_k,
            out_weights=out_w,  # filled with 1/k
            out_ids=out_ids,
        )
        return out_w, out_ids

View File

@@ -0,0 +1,409 @@
"""CuTeDSL Shared Expert Pipeline
NVFP4 inference for DeepSeek V4 shared experts.
Uses ScaledGroupedGemmKernel with num_groups=1.
Pipeline:
1. Quantize activation: BF16 → NVFP4 (using warmup gs)
2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16
3. SiLU(gate) * up → BF16
4. Re-quantize: BF16 → NVFP4 (using warmup gs)
5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16
Unlike MoE, there's no routing, no scatter, no expert offsets.
All tokens go through the same expert (the shared expert).
Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output.
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
interleave_l1_weights,
deinterleave_l1_weights,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
)
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
from dsv4.kernels.gemm.grouped import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
class _SharedExpertApply(torch.autograd.Function):
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
@staticmethod
def forward(ctx, runner, hidden_states):
return runner._run_impl(hidden_states)
class Nvfp4SharedExpert:
    """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).

    Pipeline: quantize activation -> L1 GEMM (gate_up) -> SiLU(gate) * up
    -> re-quantize -> L2 GEMM (down).

    Intended to be CUDA-graph-compatible: the FP4 staging buffers, the
    global-scale buffers and the expert-offset buffer are pre-allocated,
    and the forward path performs no host-side reads of GPU data unless
    ``debug_nan_checks`` is switched on.

    Usage: construct, assign ``l1_*`` / ``l2_*`` weights, optionally call
    ``set_fused_swiglu``, then ``finalize_weights()`` (or let the first
    ``run`` do it lazily).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
        swiglu_limit: float = 10.0,
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = max_num_tokens
        self.device = device
        self.swiglu_limit = swiglu_limit
        self._fused_swiglu = False  # Set via set_fused_swiglu()
        # Opt-in diagnostics. The NaN checks read GPU data on the host
        # (a sync), so they must stay off during CUDA graph capture.
        self.debug_nan_checks = False

        # Weights (set after construction, then call finalize_weights)
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None
        # weight_scale_2 per layer (scalar, folded into global_scale_b in finalize_weights)
        self.l1_ws2 = None
        self.l2_ws2 = None

        # Processed weights (set by finalize_weights)
        self._l1_mat_b = None
        self._l2_mat_b = None
        self._l1_scale_b = None
        self._l2_scale_b = None
        self._l1_gsb = None
        self._l2_gsb = None

        # Activation global scales (set by compute_activation_global_scales)
        self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
        self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)

        # Pre-allocated cudagraph buffers (set in _allocate_buffers)
        self._padded_x_fp4_buf_l1 = None
        self._padded_x_sf_buf_l1 = None
        self._padded_x_fp4_buf_l2 = None
        self._padded_x_sf_buf_l2 = None
        self._l1_gsa_buf = None
        self._l2_gsa_buf = None
        self._expert_offsets_buf = None
        self._buffers_allocated = False

    def set_swiglu_limit(self, limit: float):
        """Set the clamp limit applied to gate/up before SiLU(gate) * up."""
        self.swiglu_limit = limit

    def set_fused_swiglu(self, enabled: bool):
        """Enable fused L1 GEMM + SwiGLU kernel (1-group variant of MoE fused kernel).

        Must be called before ``finalize_weights``: that is where the L1
        weights are (or are not) interleaved for the fused epilogue, so
        flipping the flag afterwards would run a kernel against the wrong
        weight layout.

        Raises
        ------
        RuntimeError
            If the weights are already finalized and the flag would change.
        """
        if self._l1_mat_b is not None and bool(enabled) != bool(self._fused_swiglu):
            raise RuntimeError(
                "set_fused_swiglu() must be called before finalize_weights(): "
                "the L1 weight layout depends on it"
            )
        self._fused_swiglu = enabled

    def finalize_weights(self):
        """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights.

        Raises
        ------
        RuntimeError
            If any of l1_fp4/l1_sf/l1_gs/l2_fp4/l2_sf/l2_gs is still None.
        """
        required = {
            "l1_fp4": self.l1_fp4, "l1_sf": self.l1_sf, "l1_gs": self.l1_gs,
            "l2_fp4": self.l2_fp4, "l2_sf": self.l2_sf, "l2_gs": self.l2_gs,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise RuntimeError(
                "Nvfp4SharedExpert.finalize_weights() called before weights "
                f"were set (missing: {', '.join(missing)})"
            )

        # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
        l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4]
        l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4]

        # Checkpoint weight is (N_packed, K_packed), make_b_k_major expects (E, K_packed, N_packed)
        l1_stacked = torch.stack(l1_view).permute(0, 2, 1).contiguous()
        l2_stacked = torch.stack(l2_view).permute(0, 2, 1).contiguous()

        # P1: Interleave L1 gate/up weights for fused SwiGLU kernel compatibility.
        # The fused kernel's SwiGLU epilogue expects granularity-8 interleaved gate/up.
        # Without the fused flag the weights stay in plain [gate | up] order,
        # which is what the unfused path in _run_impl splits on.
        if self._fused_swiglu:
            l1_stacked = interleave_l1_weights(l1_stacked, granularity_bf16=8)

        # Stack weights and convert to K-major
        self._l1_mat_b = make_b_k_major(l1_stacked)  # (1, K_packed, N_packed)
        self._l2_mat_b = make_b_k_major(l2_stacked)

        # Checkpoint scale is (N_packed, K_sf) — use assemble_raw_scales_2d3d_3d_side
        from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side
        self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(self.l1_sf)
        self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(self.l2_sf)

        self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
        self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)

        # Fold weight_scale_2 into global_scale_b
        # gsb = input_scale * weight_scale_2
        if self.l1_ws2 is not None:
            for i, ws2 in enumerate(self.l1_ws2):
                if ws2 is not None:
                    self._l1_gsb[i] *= ws2.float().item()
        if self.l2_ws2 is not None:
            for i, ws2 in enumerate(self.l2_ws2):
                if ws2 is not None:
                    self._l2_gsb[i] *= ws2.float().item()

        # Free raw weights
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None
        self.l1_ws2 = None
        self.l2_ws2 = None

    def _allocate_buffers(self):
        """Pre-allocate all buffers at max size for cudagraph compatibility."""
        max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128  # pad to 128

        # L1: hidden_size packed, L2: intermediate_size packed
        self._padded_x_fp4_buf_l1 = torch.zeros(
            max_rows, self.hidden_size // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._padded_x_fp4_buf_l2 = torch.zeros(
            max_rows, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)

        # Padded scale buffers (need same padded dimensions as pad_and_swizzle_single produces)
        K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
        padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
        K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
        padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
        self._padded_x_sf_buf_l1 = torch.zeros(
            max_rows, padded_cols_l1, dtype=torch.float16, device=self.device
        ).to(torch.float8_e4m3fn)
        self._padded_x_sf_buf_l2 = torch.zeros(
            max_rows, padded_cols_l2, dtype=torch.float16, device=self.device
        ).to(torch.float8_e4m3fn)

        # Global scale buffers
        self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)

        # Expert offsets for num_groups=1: a single element holding the
        # padded row count (filled per call).
        self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily initialize stacked weights and buffers."""
        if self._l1_mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf):
        """Assemble 2D-side activation scales for num_groups=1.

        Steps:
          1. Copy x_sf into a zeroed buffer padded to 128 rows / 4 cols.
          2. Apply pad_and_swizzle_single (Blackwell swizzle).
          3. Reshape back to 2D (kernel expects 2D scale_a).

        The buffer is sized for this call's 128-aligned row count, not for
        max_num_tokens, so a temporary is allocated here. ``num_tokens``
        and ``padded_x_sf_buf`` are accepted for interface stability but
        are not read.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
        # Use a temp buffer sized for this exact token count
        buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scales(self, hidden_states_sample):
        """Compute activation global scales from a warmup forward pass.

        Called BEFORE cudagraph capture (it reads GPU results on the host).
        Uses quantize_to_nvfp4 to get the global scale of the sample, then
        runs the L1 GEMM and derives the L2 scale from SiLU(gate) * up.
        The L2 scale is left unchanged if the L1 output contains NaNs.
        """
        self._ensure_initialized()
        with torch.no_grad():
            # L1: exact gs from quantize_to_nvfp4
            _, _, l1_gs = quantize_to_nvfp4(hidden_states_sample)
            self._l1_activation_global_scale = l1_gs

            # Run L1 GEMM to get intermediate for L2 gs
            l1_out = self._run_l1(hidden_states_sample)
            if l1_out is not None and not torch.isnan(l1_out).any():
                if self._fused_swiglu:
                    # Fused mode interleaves the L1 weights, so the plain
                    # GEMM output is interleaved too; restore [gate | up]
                    # order before splitting into halves.
                    l1_out = deinterleave_l1_weights(l1_out.unsqueeze(0).contiguous())[0]
                gate = l1_out[:, :self.intermediate_size]
                up = l1_out[:, self.intermediate_size:]
                if self.swiglu_limit is not None:
                    gate = gate.clamp(max=self.swiglu_limit)
                    up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
                activated = torch.nn.functional.silu(gate) * up
                _, _, l2_gs = quantize_to_nvfp4(activated)
                self._l2_activation_global_scale = l2_gs

    @staticmethod
    def _store_static_gsa(gsa_buf, scale):
        """Write a static activation global scale into a 1-element buffer.

        Accepts a Python number or a tensor (its first element is used).
        """
        if isinstance(scale, torch.Tensor):
            gsa_buf.fill_(scale.reshape(-1)[0])
        else:
            gsa_buf.fill_(scale)

    def _prepare_gemm_inputs(self, x, layer):
        """Quantize ``x`` and stage everything the 1-group GEMM reads.

        ``layer`` selects the L1 (1) or L2 (anything else) buffers and
        static scale. Returns ``(mat_a, scale_a, expert_offsets, gsa)``.
        """
        if layer == 1:
            fp4_buf = self._padded_x_fp4_buf_l1
            sf_buf = self._padded_x_sf_buf_l1
            gsa_buf = self._l1_gsa_buf
            static_scale = self._l1_activation_global_scale
        else:
            fp4_buf = self._padded_x_fp4_buf_l2
            sf_buf = self._padded_x_sf_buf_l2
            gsa_buf = self._l2_gsa_buf
            static_scale = self._l2_activation_global_scale

        num_tokens = x.shape[0]
        if getattr(self, '_use_runtime_gsa', False):
            # Fused amax + quantize; the scale stays on the GPU.
            from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
            x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(x)
            gsa_buf.copy_(gsa_gpu[:1].reshape(1))  # GPU → GPU, no sync
        else:
            from dsv4.ops.quantize import quantize_activation_nvfp4
            x_fp4, x_sf = quantize_activation_nvfp4(x, static_scale)
            # The GEMM must see the same scale the quantizer used. Without
            # this the buffer keeps whatever it last held (zeros right
            # after allocation).
            self._store_static_gsa(gsa_buf, static_scale)

        # Stage as uint8: zero the whole buffer, then write the real rows.
        raw = fp4_buf.view(torch.uint8)
        raw.zero_()
        raw[:num_tokens] = x_fp4.view(torch.uint8)

        # Assemble A-side scales
        scale_a = self._assemble_scales_single_group(x_sf, num_tokens, sf_buf)

        # Expert offsets: [padded_rows] for 1 group (int32, pre-allocated)
        padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)
        return fp4_buf, scale_a, expert_offsets, gsa_buf

    def _run_l1_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Fused L1 GEMM + SwiGLU + clamp — single kernel launch (1-group variant of MoE fused kernel)."""
        num_tokens = hidden_states.shape[0]
        x_bf16 = hidden_states.reshape(num_tokens, self.hidden_size)
        mat_a, scale_a, expert_offsets, gsa = self._prepare_gemm_inputs(x_bf16, layer=1)

        # Run fused GEMM + SwiGLU
        l1_out = run_fused_swiglu_grouped_gemm(
            mat_a=mat_a,
            mat_b=self._l1_mat_b,
            scale_a=scale_a,
            scale_b=self._l1_scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._l1_gsb,
            swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
        )
        l1_out_real = l1_out[:num_tokens]  # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]

        # Deinterleave to separate gate and up, then take up half (SwiGLU result)
        l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
        intermediate = l1_deil[:, self.intermediate_size:]  # up half = silu(gate)*up
        return intermediate  # (num_tokens, intermediate_size) BF16

    def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """L1 GEMM: activation × gate_up_weight → BF16."""
        num_tokens = hidden_states.shape[0]
        mat_a, scale_a, expert_offsets, gsa = self._prepare_gemm_inputs(hidden_states, layer=1)
        out = run_nvfp4_grouped_gemm(
            mat_a=mat_a,
            mat_b=self._l1_mat_b,
            scale_a=scale_a,
            scale_b=self._l1_scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._l1_gsb,
        )
        # Extract real token outputs
        return out[:num_tokens]

    def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
        """L2 GEMM: intermediate × down_weight → BF16."""
        num_tokens = intermediate.shape[0]
        mat_a, scale_a, expert_offsets, gsa = self._prepare_gemm_inputs(intermediate, layer=2)
        out = run_nvfp4_grouped_gemm(
            mat_a=mat_a,
            mat_b=self._l2_mat_b,
            scale_a=scale_a,
            scale_b=self._l2_scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._l2_gsb,
        )
        return out[:num_tokens]

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Full shared expert forward: L1 → SiLU → L2 → output."""
        return _SharedExpertApply.apply(self, hidden_states)

    def _report_nans(self, l1_out, gate, up):
        """Print NaN diagnostics for the unfused L1 output (forces host syncs)."""
        if torch.isnan(l1_out).any():
            print(f" SE L1 NaN: l1_out nan at {torch.isnan(l1_out).sum().item()} / {l1_out.numel()} positions, shape={l1_out.shape}", flush=True)
        if torch.isnan(gate).any() or torch.isnan(up).any():
            print(f" SE gate nan={torch.isnan(gate).any().item()} up nan={torch.isnan(up).any().item()}", flush=True)

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Actual implementation — called via custom autograd to be torch.compile-safe.

        NaN diagnostics on the unfused path only run when
        ``debug_nan_checks`` is True, because they read GPU data on the host.
        """
        self._ensure_initialized()
        if self._fused_swiglu:
            # P1: Fused L1 GEMM + SwiGLU + clamp in one kernel launch
            intermediate = self._run_l1_fused(hidden_states)
        else:
            l1_out = self._run_l1(hidden_states)
            # Shape check is host-side metadata only; no sync.
            if l1_out.shape[1] < 2 * self.intermediate_size:
                print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True)
            gate = l1_out[:, :self.intermediate_size]
            up = l1_out[:, self.intermediate_size:]
            if getattr(self, 'debug_nan_checks', False):
                self._report_nans(l1_out, gate, up)
            if self.swiglu_limit is not None:
                gate = gate.clamp(max=self.swiglu_limit)
                up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
            intermediate = torch.nn.functional.silu(gate) * up
        output = self._run_l2(intermediate)
        return output

View File

@@ -0,0 +1,138 @@
"""torch.library.custom_op wrappers for CuTeDSL NVFP4 kernels.
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(JIT compilation, cute.compile, etc.). By wrapping the runner calls in
torch.library.custom_op, Dynamo treats them as opaque black boxes.
This is the correct approach per PyTorch's extensibility model:
- custom_op is the supported way to make Dynamo skip tracing
- autograd.Function does NOT work reliably with fullgraph mode
- The runner's _run_impl is already cudagraph-safe
The registry pattern: custom ops can only take tensor/scalar arguments.
We store runners in a global dict keyed by integer ID, and pass the ID
as an int parameter. During Dynamo tracing, the fake impl returns a
correctly-shaped tensor without touching the runner. During execution,
the real impl looks up the runner and calls _run_impl.
"""
import torch
# ---------------------------------------------------------------------------
# Runner registry — maps integer IDs to runner objects
# ---------------------------------------------------------------------------
_next_runner_id = 0
_runner_registry: dict[int, object] = {}


def register_runner(runner) -> int:
    """Store ``runner`` in the module-level registry and return its key.

    Keys are handed out from a module-level counter that starts at 0 and
    grows by one per registration.

    Args:
        runner: object to store; it is kept as-is, nothing is called on it.

    Returns:
        The integer key under which ``runner`` was stored.
    """
    global _next_runner_id
    assigned = _next_runner_id
    _runner_registry[assigned] = runner
    # Advance the counter so the next registration gets a fresh key.
    _next_runner_id = assigned + 1
    return assigned
def get_runner(rid: int):
    """Return the runner stored under ``rid`` in the module registry.

    Raises:
        KeyError: if nothing is stored under ``rid``.
    """
    registry = _runner_registry
    return registry[rid]
# ---------------------------------------------------------------------------
# NVFP4 Linear GEMM custom op (single linear layer)
# ---------------------------------------------------------------------------
@torch.library.custom_op("nvfp4::linear_gemm", mutates_args=())
def nvfp4_linear_gemm(
    x: torch.Tensor,
    runner_id: int,
    out_features: int,
) -> torch.Tensor:
    """Run a registered runner's linear GEMM behind an opaque custom op.

    The runner is fetched from the registry by ``runner_id`` and its
    ``_run_impl`` is called on ``x``.

    Args:
        x: input tensor, documented as (M, K) BF16.
        runner_id: integer key into the runner registry.
        out_features: output dimension. Not read by this implementation;
            the fake implementation uses it for shape inference.

    Returns:
        Whatever ``runner._run_impl(x)`` returns, documented as
        (M, out_features) BF16.
    """
    return get_runner(runner_id)._run_impl(x)
@nvfp4_linear_gemm.register_fake
def _(x, runner_id, out_features):
    """Fake implementation: allocate an output of the right shape only.

    Returns an uninitialised BF16 tensor of shape
    ``(x.shape[0], out_features)`` on ``x``'s device; the runner registry
    is never consulted.
    """
    num_rows = x.shape[0]
    return torch.empty(
        num_rows,
        out_features,
        dtype=torch.bfloat16,
        device=x.device,
    )
# ---------------------------------------------------------------------------
# NVFP4 MoE custom op (L1 + SiLU + L2 grouped GEMM)
# ---------------------------------------------------------------------------
@torch.library.custom_op("nvfp4::moe_gemm", mutates_args=())
def nvfp4_moe_gemm(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    runner_id: int,
    hidden_size: int,
) -> torch.Tensor:
    """Run a registered MoE runner behind an opaque custom op.

    Args:
        hidden_states: input tensor, documented as (M, K) BF16.
        topk_weights: routing weights, documented as (M, top_k) float32.
        topk_ids: expert IDs, documented as (M, top_k) int32.
        runner_id: integer key into the runner registry.
        hidden_size: output dimension. Not read by this implementation;
            the fake implementation uses it for shape inference.

    Returns:
        Whatever the runner's ``_run_impl`` returns, documented as
        (M, hidden_size) BF16.
    """
    moe_runner = get_runner(runner_id)
    moe_out = moe_runner._run_impl(hidden_states, topk_weights, topk_ids)
    return moe_out
@nvfp4_moe_gemm.register_fake
def _(hidden_states, topk_weights, topk_ids, runner_id, hidden_size):
    """Fake implementation: allocate an output of the right shape only.

    Returns an uninitialised BF16 tensor of shape
    ``(hidden_states.shape[0], hidden_size)`` on the input's device.
    """
    num_tokens = hidden_states.shape[0]
    target_device = hidden_states.device
    return torch.empty(
        num_tokens,
        hidden_size,
        dtype=torch.bfloat16,
        device=target_device,
    )
# ---------------------------------------------------------------------------
# DSV4 Sparse FMHA custom op (attention with SWA + sink bias)
# ---------------------------------------------------------------------------
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=())
def dsv4_sparse_fmha(
    q: torch.Tensor,  # (n_q_heads, T, hd) BF16
    k: torch.Tensor,  # (n_kv_heads, N, hd) or (N, hd) BF16
    v: torch.Tensor,  # same as k
    sink_bias: torch.Tensor,  # (n_q_heads,) FP32 — can be zeros if unused
    scale: float,
    swa_len: int,
    is_causal: bool,
    n_comp: int,
) -> torch.Tensor:
    """Opaque custom-op front end for ``dsv4_attention``.

    ``sink_bias`` is always a tensor in this signature (callers pass zeros
    when they have none), and ``swa_len`` is always an int. Both are mapped
    to ``None`` before the call when they carry no information:

    - ``swa_len <= 0`` is forwarded as ``swa_len=None``.
    - ``sink_bias`` is forwarded only when ``n_comp > 0`` and the bias has
      a non-zero absolute sum; otherwise ``sink_bias=None`` is passed.

    Returns:
        The result of ``dsv4_attention`` for the given arguments.
    """
    from dsv4.kernels.attention.production import dsv4_attention as _dsv4_attention

    # The bias is dropped when there is no compressed prefix (n_comp == 0)
    # or when every entry is zero.
    # NOTE(review): .item() reads the reduced value back as a Python number;
    # confirm that is acceptable on paths that must avoid host round-trips.
    if n_comp > 0 and sink_bias.abs().sum().item() > 0:
        effective_sink = sink_bias
    else:
        effective_sink = None

    if swa_len > 0:
        effective_swa = swa_len
    else:
        effective_swa = None

    return _dsv4_attention(
        q,
        k,
        v,
        scale=scale,
        swa_len=effective_swa,
        is_causal=is_causal,
        n_comp=n_comp,
        sink_bias=effective_sink,
    )
@dsv4_sparse_fmha.register_fake
def _(q, k, v, sink_bias, scale, swa_len, is_causal, n_comp):
    """Fake implementation: the output mirrors ``q`` (shape, dtype, device)."""
    placeholder = torch.empty_like(q)
    return placeholder

View File

@@ -0,0 +1,93 @@
"""torch.library.custom_op wrappers and dispatch for the Router kernels.
Mirrors the pattern in dsv4/ops/custom_ops.py:
- Routers are registered into an integer-keyed table.
- The custom_op takes the integer ID and tensor args only.
- Dynamo can't trace through the kernel; the op is opaque.
"""
import torch
from dsv4.kernels.router import (
dense_router_dispatch, # picks decode vs prefill internally
hash_router_dispatch,
)
_next_router_id = 0
_router_registry: dict[int, object] = {}


def register_router(router) -> int:
    """Store ``router`` in the module-level registry and return its key.

    Keys come from a module-level counter that starts at 0 and grows by
    one per registration.
    """
    global _next_router_id
    assigned = _next_router_id
    _router_registry[assigned] = router
    # Advance the counter so the next registration gets a fresh key.
    _next_router_id = assigned + 1
    return assigned
def get_router(rid: int):
    """Return the router stored under ``rid``; raises KeyError if absent."""
    registry = _router_registry
    return registry[rid]
def warmup_router_compilation(router) -> None:
    """Trigger eager JIT compilation for the router's kernel path.

    Runs one dummy forward with a single token so the kernel is compiled
    before the hot path needs it.

    - ``router.mode == "dense"``: calls ``router._run_dense_impl`` on a
      zero BF16 tensor of shape ``(1, router.hidden_size)``. This call is
      best-effort: any exception is logged as a warning and suppressed.
    - any other mode: calls ``router._run_hash_impl`` on a zero int32
      tensor of shape ``(1,)``. Exceptions from this call propagate.

    Args:
        router: object exposing ``mode``, ``device``, and the ``_run_*_impl``
            method for its mode (plus ``hidden_size`` for the dense mode).
    """
    import logging

    if router.mode == "dense":
        # Dummy forward at small N triggers decode-path compile.
        # CuTeDSL fused kernel is WIP — falls through to prefill path.
        dummy = torch.zeros(
            1, router.hidden_size,
            dtype=torch.bfloat16, device=router.device,
        )
        try:
            router._run_dense_impl(dummy)
        except Exception:
            # Warmup stays best-effort, but the failure is recorded rather
            # than silently discarded so a broken kernel is visible in logs.
            logging.getLogger(__name__).warning(
                "Dense router warmup failed; continuing without a "
                "pre-compiled kernel",
                exc_info=True,
            )
    else:
        dummy = torch.zeros(1, dtype=torch.int32, device=router.device)
        router._run_hash_impl(dummy)
# ----- Dense router custom op -----
@torch.library.custom_op("dsv4::dense_router", mutates_args=())
def dense_router_op(
    hidden_states: torch.Tensor,
    router_id: int,
    num_experts: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run a registered router's dense path behind an opaque custom op.

    The router is fetched by ``router_id`` and its ``_run_dense_impl`` is
    called on ``hidden_states``. ``num_experts`` and ``top_k`` are not read
    by this implementation.

    Returns:
        The pair returned by ``router._run_dense_impl(hidden_states)``.
    """
    return get_router(router_id)._run_dense_impl(hidden_states)
@dense_router_op.register_fake
def _(hidden_states, router_id, num_experts, top_k):
    """Fake implementation: two uninitialised ``(N, top_k)`` outputs.

    ``N`` is ``hidden_states.shape[0]``. The first tensor is float32, the
    second int32; both live on ``hidden_states``'s device.
    """
    num_tokens = hidden_states.shape[0]
    dev = hidden_states.device
    fake_weights = torch.empty(num_tokens, top_k, dtype=torch.float32, device=dev)
    fake_ids = torch.empty(num_tokens, top_k, dtype=torch.int32, device=dev)
    return (fake_weights, fake_ids)
# ----- Hash router custom op -----
@torch.library.custom_op("dsv4::hash_router", mutates_args=())
def hash_router_op(
    token_ids: torch.Tensor,
    router_id: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run a registered router's hash path behind an opaque custom op.

    The router is fetched by ``router_id`` and its ``_run_hash_impl`` is
    called on ``token_ids``. ``top_k`` is not read by this implementation.

    Returns:
        The pair returned by ``router._run_hash_impl(token_ids)``.
    """
    return get_router(router_id)._run_hash_impl(token_ids)
@hash_router_op.register_fake
def _(token_ids, router_id, top_k):
    """Fake implementation: two uninitialised ``(N, top_k)`` outputs.

    ``N`` is ``token_ids.shape[0]``. The first tensor is float32, the
    second int32; both live on ``token_ids``'s device.
    """
    num_tokens = token_ids.shape[0]
    dev = token_ids.device
    fake_weights = torch.empty(num_tokens, top_k, dtype=torch.float32, device=dev)
    fake_ids = torch.empty(num_tokens, top_k, dtype=torch.int32, device=dev)
    return (fake_weights, fake_ids)

View File

@@ -1,180 +1,6 @@
"""DSV4 Attention kernels — public integration API.
====================================================================
STATUS: SKELETON — not yet connected to model
====================================================================
These functions define the API that AttentionSubBlock will call.
They're correct in structure but depend on:
1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
2. The production FMHA wrapper supporting sink_bias and n_comp
3. Custom op registration for torch.compile compatibility
See ROADMAP.md Priority 5 for the full Stage E checklist.
====================================================================
These functions bridge the model's AttentionSubBlock to the production
FMHA kernel wrapper. Each function handles the cache → dense-tensor
materialization that the kernel requires.
The model's attention layer calls these after:
1. Projection (q_down, q_up, kv_down)
2. RoPE application
3. Compression + cache writes
4. Indexer + top-k (CSA only)
These functions handle:
- Gathering sparse/dense KV from cache into dense tensors
- Calling the production FMHA wrapper
- Returning attention output for inverse RoPE + wo_a/wo_b
The live inference path uses dsv4.kernels.attention.production directly.
See production.py for the dsv4_attention function used by single_shot_inference.py.
"""
from dsv4.kernels.attention.production import dsv4_attention
import torch
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
def sparse_fmha_with_swa(
    q: torch.Tensor,  # (T, n_h * hd) BF16, post-RoPE
    cache: "LayerCacheHandle",  # provides compressed + SWA KV
    selected_indices: torch.Tensor,  # (T, top_k) int64 — which compressed blocks
    sink_logits: Optional[torch.Tensor] = None,  # (n_h,) FP32
    sliding_window: int = 128,
) -> torch.Tensor:
    """CSA attention: sparse top-k compressed KV + sliding window, fused sink merge.

    Gathers the selected compressed KV blocks and the SWA window from the
    cache, concatenates them along the sequence axis, and calls
    ``dsv4_attention`` once over the combined KV.

    Args:
        q: (T, n_h * hd) BF16 query (post-RoPE, pre-reshape).
        cache: LayerCacheHandle with CSA compressed entries + SWA window.
            ``num_query_heads`` and ``head_dim`` are read from it.
        selected_indices: (T, top_k) int64 block indices from the indexer.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output (pre inverse-RoPE).
    """
    # n_h and hd come from the cache's config.
    n_h = cache.num_query_heads
    hd = cache.head_dim
    T = q.shape[0]
    # Split the flat head axis and move heads first: (n_h, T, hd).
    q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)

    # Compressed KV for the selected blocks, documented as
    # (1, n_comp_kv, hd) or (n_kv, n_comp_kv, hd).
    k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)

    # SWA window KV, documented as (1, swa_len, hd).
    k_swa, v_swa = cache.gather_swa_kv()

    # Concatenate [compressed, SWA] so a single softmax covers both.
    k_full = torch.cat([k_compressed, k_swa], dim=-2)
    v_full = torch.cat([v_compressed, v_swa], dim=-2)

    # Length of the compressed prefix (used for the sink bias offset).
    n_comp = k_compressed.shape[-2]

    output = dsv4_attention(
        q_heads, k_full, v_full,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=n_comp,
        sink_bias=sink_logits,
    )  # (n_h, T, hd)

    # Back to the caller's flat layout: (T, n_h * hd).
    return output.permute(1, 0, 2).reshape(T, n_h * hd)
def dense_fmha_with_swa(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """HCA attention: every compressed KV entry plus the SWA window.

    No indexer is involved: all compressed entries returned by the cache
    are attended together with the sliding-window KV in one
    ``dsv4_attention`` call.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with HCA compressed entries + SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]

    # Heads-first view of the query: (n_h, T, hd).
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    # Dense: gather ALL compressed KV, then the sliding window.
    k_comp, v_comp = cache.gather_all_compressed_kv()
    k_win, v_win = cache.gather_swa_kv()

    keys = torch.cat([k_comp, k_win], dim=-2)
    values = torch.cat([v_comp, v_win], dim=-2)

    attn = dsv4_attention(
        q_by_head,
        keys,
        values,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=k_comp.shape[-2],
        sink_bias=sink_logits,
    )

    flat = attn.permute(1, 0, 2)
    return flat.reshape(num_tokens, num_heads * head_dim)
def swa_only_fmha(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """SWA-only attention: local attention over the sliding window alone.

    There is no compressed branch and no indexer; ``dsv4_attention`` is
    called with ``n_comp=0`` on the sliding-window KV only.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias, or None.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output.
    """
    num_heads = cache.num_query_heads
    head_dim = cache.head_dim
    num_tokens = q.shape[0]

    # Heads-first view of the query: (n_h, T, hd).
    q_by_head = q.reshape(num_tokens, num_heads, head_dim).permute(1, 0, 2)

    k_win, v_win = cache.gather_swa_kv()

    attn = dsv4_attention(
        q_by_head,
        k_win,
        v_win,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=0,  # no compressed prefix in front of the window
        sink_bias=sink_logits,
    )

    flat = attn.permute(1, 0, 2)
    return flat.reshape(num_tokens, num_heads * head_dim)

View File

@@ -1,56 +1,5 @@
"""CSA/HCA compressor — Python API bridge.
Wraps the compression functions with the interface that
AttentionSubBlock and flush.py expect.
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
See dsv4/kernels/compressor/production_compress.py for the live path.
See dsv4/kernels/cuda/compressor_reduce.cu for the CUDA kernel.
"""
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
def csa_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16 — current KV (goes to tail)
    cache: "LayerCacheHandle",  # reads tail, writes compressed to paged pool
) -> None:
    """CSA: compress KV entries and store into the classical paged cache.

    This function is currently a deliberate no-op: neither argument is
    read and nothing is written.

    Intended steps (not performed here):
        1. Check if tail has enough entries (tail_len >= m=4)
        2. If so, run compression (csa_compress_tail)
        3. Write compressed output to paged pool via flush_write
        4. Update tail buffer (a-stream becomes next b-stream)

    NOTE: This function is called from AttentionSubBlock.forward, which
    writes the raw KV to the tail buffer first (via cache.write_swa).
    The compression + flush is triggered separately by the flush pipeline;
    see dsv4/cache/flush.py for the flush orchestration.
    """
    # Compression is handled by flush.py, not directly here. The unused
    # function-local import of the flush-write kernel was dropped so that
    # calling this no-op cannot fail on (or pay for) an import it never uses.
    return None
def hca_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16
    cache: "LayerCacheHandle",  # reads tail, writes compressed to paged pool
) -> None:
    """HCA counterpart of the CSA compress-and-store entry point.

    Documented as the same structure as CSA but with no b-stream, no
    overlap, and m'=128. The body is a no-op: neither argument is read.
    See flush.py.
    """
    return None
# Make compress_tail functions importable from this package
__all__ = [
    "csa_compress_and_store",
    "hca_compress_and_store",
    "csa_compress_tail",
    "hca_compress_tail",
]

View File

@@ -1,2 +1,2 @@
"""CUDA kernel loader — re-exports from loader.py for convenience."""
from dsv4.kernels.cuda.loader import get_cuda_module, preload_all
from dsv4.kernels.cuda.loader import get_cuda_module

View File

@@ -7,7 +7,7 @@ being called on every kernel invocation (was ~100ms per call, called ~500x per t
Usage:
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
result = mod.fused_amax_quantize_nvfp4(x, divisor)
result = mod.quantize_nvfp4_from_buffer(x, divisor)
"""
import os
import hashlib
@@ -65,17 +65,4 @@ def get_cuda_module(name, sources, extra_cuda_cflags=None):
return mod
def preload_all():
    """Preload all CUDA kernels at startup (before the hot path).

    Calls ``get_cuda_module`` once per entry below, in order, each with a
    single-element source list.
    """
    kernel_table = (
        # amax_gsa — computes gsa on GPU (no .item())
        ("amax_gsa", "amax_gsa.cu"),
        # quantize-from-buffer — reads gsa from GPU buffer (no .item())
        ("fused_amax_quantize", "fused_amax_quantize.cu"),
        # Standalone quantize (for when gsa is known, not hot path)
        ("quantize_nvfp4", "quantize_nvfp4.cu"),
        # Sampler
        ("sampler", "sampler.cu"),
        # Dequant NVFP4
        ("dequant_nvfp4", "dequant_nvfp4.cu"),
        # Fused compress + quantize
        ("compressor_reduce_quant", "compressor_reduce_quant.cu"),
    )
    for module_name, source_file in kernel_table:
        get_cuda_module(module_name, [source_file])

View File

@@ -1,63 +1,5 @@
"""CSA indexer — Python API bridge.
Wraps the CUDA indexer score+topk kernel with the interface that
AttentionSubBlock expects.
The indexer (paper §2.3.5, eq. 16) scores each query against
compressed blocks via weighted ReLU MQA logits, then selects
top-k blocks for sparse attention.
Currently uses scalar FP32 CUDA cores after FP4 dequant.
The FP4 tensor-core path (Stage F / E7) is a future optimization.
See dsv4/kernels/cuda/indexer_score_topk.cu for the live CUDA kernel.
The live inference path uses the inline indexer in single_shot_inference.py.
"""
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
def compute_index_scores_topk(
    q_indexer: torch.Tensor,  # (T, n_I_h * c_I) BF16 — indexer query
    w_indexer: torch.Tensor,  # (T, n_I_h) FP32 — per-head weights
    cache: "LayerCacheHandle",  # provides FP4 indexer keys
    top_k: int = 512,  # number of blocks to select
    num_indexer_heads: int = 64,  # n_I_h; previously hard-coded in the body
) -> torch.Tensor:  # (T, top_k) int64 — selected block indices
    """CSA: score compressed entries and select top-k blocks.

    Reads the indexer view and schema fields from ``cache`` and forwards
    them, with the query and per-head weights, to ``run_indexer_score_topk``.

    Args:
        q_indexer: (T, n_I_h * c_I) BF16 indexer query, already flat.
        w_indexer: (T, n_I_h) per-head weights; converted to float32.
        cache: handle exposing ``read_indexer_view()`` and a ``schema`` with
            ``indexer_head_dim`` and ``entries_per_block``.
        top_k: number of blocks to select.
        num_indexer_heads: number of indexer heads (n_I_h). Defaults to 64,
            the value the body used to hard-code.
            TODO: source this from the schema or handle instead.

    Returns:
        (T, top_k) int64 indices, for use with ``gather_compressed_kv``.
    """
    from dsv4.kernels.indexer.score_topk import run_indexer_score_topk

    # Read the indexer view from the cache.
    indexer_view = cache.read_indexer_view()

    # c_I is the indexer head dimension from the schema.
    c_I = cache.schema.indexer_head_dim
    entries_per_block = cache.schema.entries_per_block

    indices = run_indexer_score_topk(
        q_I=q_indexer,
        # .float() hands back the same tensor when it is already float32,
        # so no explicit dtype check is needed.
        w_h=w_indexer.float(),
        indexer_view=indexer_view,
        num_heads=num_indexer_heads,
        head_dim=c_I,
        top_k=top_k,
        entries_per_block=entries_per_block,
    )

    # The kernel's indices are documented as int32; gather_compressed_kv
    # takes int64.
    return indices.to(torch.int64)

View File

@@ -1,5 +1,6 @@
# helpers/import_closure.py — list dsv4 modules NOT reachable from the entry points.
# Usage: python helpers/import_closure.py (run from repo root, PYTHONPATH=repo root)
# Usage: python3 helpers/import_closure.py (run from repo root)
# NOTE: handles lazy imports inside functions (single_shot uses these heavily)
import ast, pathlib, sys
ROOT = pathlib.Path(__file__).resolve().parent.parent
ENTRYPOINTS = ["single_shot_inference.py"] # vLLM has 0 imports of dsv4 (Step 0 confirmed)
@@ -11,6 +12,7 @@ def module_to_path(mod):
return p if p.exists() else None
def imports_of(path):
"""Parse ALL imports including lazy ones inside functions."""
tree = ast.parse(path.read_text())
out = set()
for n in ast.walk(tree):