Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
- model/{dsv4,mtp,layer,layer_schedule}
- layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
- cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
- kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
- ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
- reference/{attention,compressor,csa_attention,moe_pipeline}
- kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
- test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
This commit is contained in:
368
dsv4/_archive/layers/grouped_linear.py
Normal file
368
dsv4/_archive/layers/grouped_linear.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half).
|
||||
|
||||
wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups.
|
||||
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank)
|
||||
|
||||
The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd".
|
||||
We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts,
|
||||
where every token goes to every "expert" (group).
|
||||
|
||||
wo_a is loaded as BF16 from our NVFP4 checkpoint, then quantized to NVFP4 here.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
quantize_nvfp4_gpu_fused,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_2d_side,
|
||||
assemble_scales_3d_side,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
|
||||
class Nvfp4GroupedLinear:
|
||||
"""Grouped NVFP4 linear for wo_a (o-projection first half).
|
||||
|
||||
Handles the "bhr,hdr->bhd" einsum pattern:
|
||||
- o: (tokens, n_local_heads, head_dim) → reshape to (tokens, n_local_groups, heads_per_group * head_dim)
|
||||
- wo_a: (n_local_groups, heads_per_group * head_dim, o_lora_rank) → NVFP4 per group
|
||||
- z: (tokens, n_local_groups, o_lora_rank)
|
||||
|
||||
Uses ScaledGroupedGemm with num_groups=n_local_groups.
|
||||
Every token goes to every group (no routing).
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_local_groups: int,
|
||||
heads_per_group: int,
|
||||
head_dim: int,
|
||||
o_lora_rank: int,
|
||||
max_num_tokens: int = 8192,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.n_local_groups = n_local_groups
|
||||
self.heads_per_group = heads_per_group
|
||||
self.head_dim = head_dim
|
||||
self.o_lora_rank = o_lora_rank
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
|
||||
# Per-group dimensions
|
||||
self.group_in_features = heads_per_group * head_dim # 8192
|
||||
self.group_out_features = o_lora_rank # 1536
|
||||
|
||||
# NVFP4 weight storage: lists of per-group tensors
|
||||
self._weight_fp4 = None # list of (K//2, N) float4_e2m1fn_x2
|
||||
self._weight_sf = None # list of (K//16, N) float8_e4m3fn
|
||||
self._weight_gs = None # list of float32
|
||||
|
||||
# Processed weights (set by finalize_weights)
|
||||
self._mat_b = None
|
||||
self._scale_b = None
|
||||
self._gsb = None
|
||||
|
||||
# Activation global scale
|
||||
self._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
|
||||
# Pre-allocated buffers
|
||||
self._padded_x_fp4_buf = None
|
||||
self._gsa_buf = None
|
||||
self._expert_offsets_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def set_bf16_weight(self, wo_a_bf16: torch.Tensor):
|
||||
"""Set wo_a weight from BF16 and quantize to NVFP4.
|
||||
|
||||
Args:
|
||||
wo_a_bf16: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16
|
||||
OR (n_local_groups, heads_per_group * head_dim, o_lora_rank) if from bmm
|
||||
"""
|
||||
# Quantize each group separately
|
||||
fp4_list = []
|
||||
sf_list = []
|
||||
gs_list = []
|
||||
|
||||
if wo_a_bf16.ndim == 3:
|
||||
# bmm format: (n_local_groups, heads_per_group * head_dim, o_lora_rank)
|
||||
for g in range(self.n_local_groups):
|
||||
w_g = wo_a_bf16[g] # (in_features, out_features)
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g)
|
||||
# quantize_weight_to_nvfp4 returns (K//2, N) with K=in_features
|
||||
# Our kernel expects (K_packed, N_packed) where K is the contraction dim
|
||||
# For weight (in_features, out_features): K=in_features (contraction)
|
||||
# quantize_weight_to_nvfp4 treats dim 0 as K, so result is (K//2, N) ✓
|
||||
fp4_list.append(w_fp4)
|
||||
sf_list.append(w_sf)
|
||||
gs_list.append(w_gs)
|
||||
else:
|
||||
# Dense format: (n_local_groups * o_lora_rank, heads_per_group * head_dim)
|
||||
# Split into per-group blocks
|
||||
for g in range(self.n_local_groups):
|
||||
start = g * self.o_lora_rank
|
||||
end = start + self.o_lora_rank
|
||||
w_g = wo_a_bf16[start:end, :] # (o_lora_rank, in_features)
|
||||
# NOTE: This is transposed — weight is (out, in) but quantize_weight_to_nvfp4
|
||||
# expects (K, N) where K is the packed/contraction dim.
|
||||
# For matmul X @ W^T, the contraction dim of W is dim 1 (in_features).
|
||||
# So we need to transpose before quantizing.
|
||||
w_g_t = w_g.T # (in_features, o_lora_rank) = (K, N)
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g_t)
|
||||
fp4_list.append(w_fp4)
|
||||
sf_list.append(w_sf)
|
||||
gs_list.append(w_gs)
|
||||
|
||||
self._weight_fp4 = fp4_list
|
||||
self._weight_sf = sf_list
|
||||
self._weight_gs = gs_list
|
||||
|
||||
def load_nvfp4_weight(self, weight, weight_scale, weight_scale_2=None, input_scale=None):
|
||||
"""Load NVFP4 weights directly from checkpoint — no dequant/re-quant.
|
||||
|
||||
The checkpoint stores weights in (out_features, in_features) layout:
|
||||
weight: (n_groups * o_rank, group_in_features // 2) uint8
|
||||
weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn
|
||||
weight_scale_2: scalar or (n_groups * o_rank,) float
|
||||
input_scale: scalar or (n_groups * o_rank,) float (unused for weight dequant)
|
||||
|
||||
Each group's chunk is (o_rank, K_packed) = (N, K_packed) in row-major.
|
||||
Our GEMM expects (K_packed, N) per group, so we transpose each group.
|
||||
Block scales follow the same transpose.
|
||||
|
||||
Args:
|
||||
weight: (n_groups * o_rank, group_in_features // 2) uint8
|
||||
weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn
|
||||
weight_scale_2: scalar or per-row scale tensor (optional)
|
||||
input_scale: scalar or per-row (unused — for activation quantization)
|
||||
"""
|
||||
fp4_list = []
|
||||
sf_list = []
|
||||
gs_list = []
|
||||
|
||||
K_packed = self.group_in_features // 2
|
||||
N = self.o_lora_rank
|
||||
K_sf = self.group_in_features // 16 # block scale dim along K
|
||||
|
||||
for g in range(self.n_local_groups):
|
||||
# Extract this group's weight: (o_rank, K_packed) = (N, K_packed)
|
||||
start = g * N
|
||||
end = start + N
|
||||
w_g = weight[start:end] # (N, K_packed) uint8
|
||||
ws_g = weight_scale[start:end] # (N, K_sf) float8_e4m3fn
|
||||
|
||||
# Transpose to (K_packed, N) — the layout quantize_weight_to_nvfp4 produces
|
||||
w_g_t = w_g.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
ws_g_t = ws_g.permute(1, 0).contiguous()
|
||||
|
||||
fp4_list.append(w_g_t)
|
||||
sf_list.append(ws_g_t)
|
||||
|
||||
# Global scale: weight_scale_2
|
||||
if weight_scale_2 is not None:
|
||||
if weight_scale_2.numel() == 1:
|
||||
gs_list.append(weight_scale_2.float().item())
|
||||
else:
|
||||
# Per-row: take mean of this group's rows
|
||||
gs_list.append(weight_scale_2[start:end].float().mean().item())
|
||||
else:
|
||||
gs_list.append(1.0)
|
||||
|
||||
self._weight_fp4 = fp4_list
|
||||
self._weight_sf = sf_list
|
||||
self._weight_gs = gs_list
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process NVFP4 weights for CuTeDSL GEMM."""
|
||||
if self._weight_fp4 is None:
|
||||
raise RuntimeError("Call set_bf16_weight() before finalize_weights()")
|
||||
|
||||
self._mat_b = make_b_k_major(torch.stack(self._weight_fp4)) # (groups, K_packed, N_packed)
|
||||
self._scale_b = assemble_scales_3d_side(self._weight_sf)
|
||||
self._gsb = torch.tensor(self._weight_gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Free raw weights
|
||||
self._weight_fp4 = None
|
||||
self._weight_sf = None
|
||||
self._weight_gs = None
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Pre-allocate buffers at max size for cudagraph compatibility."""
|
||||
max_rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
|
||||
total_max_rows = max_rows_per_group * self.n_local_groups
|
||||
|
||||
self._padded_x_fp4_buf = torch.zeros(
|
||||
total_max_rows, self.group_in_features // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
|
||||
self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device)
|
||||
self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_initialized(self):
|
||||
if self._mat_b is None:
|
||||
self.finalize_weights()
|
||||
if not self._buffers_allocated:
|
||||
self._allocate_buffers()
|
||||
|
||||
def _assemble_scales_single_group(self, x_sf):
|
||||
"""Assemble 2D-side activation scales for num_groups=1."""
|
||||
num_rows, num_cols = x_sf.shape
|
||||
padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
|
||||
padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
|
||||
|
||||
buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn)
|
||||
buf[:num_rows, :num_cols] = x_sf
|
||||
swizzled_flat = pad_and_swizzle_single(buf)
|
||||
return swizzled_flat.reshape(padded_rows, padded_cols)
|
||||
|
||||
def compute_activation_global_scale(self, o_sample: torch.Tensor):
|
||||
"""Compute activation global scale from a warmup forward.
|
||||
|
||||
Args:
|
||||
o_sample: (tokens, n_local_heads, head_dim) BF16 attention output sample
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
# Reshape to grouped format, then flatten to 2D for quantization
|
||||
o_grouped = o_sample.reshape(-1, self.n_local_groups, self.group_in_features)
|
||||
# We need a single gs for all groups — use the overall amax
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
o_flat = o_sample.reshape(-1, o_sample.shape[-1]) # (tokens, n_local_heads * head_dim) — not right
|
||||
# Actually, for grouped GEMM, each group's activation is (tokens, group_in_features)
|
||||
# The global scale should be computed per-group, but for simplicity use one scale
|
||||
# based on the overall amax.
|
||||
with torch.no_grad():
|
||||
_, _, gs = quantize_to_nvfp4(o_grouped.reshape(-1, self.group_in_features))
|
||||
self._activation_global_scale = gs
|
||||
|
||||
def run(self, o: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward: BF16 attention output → NVFP4 grouped GEMM → BF16 z.
|
||||
|
||||
Args:
|
||||
o: (num_tokens, n_local_heads, head_dim) BF16 — attention output
|
||||
AFTER inverse RoPE has been applied
|
||||
|
||||
Returns:
|
||||
z: (num_tokens, n_local_groups, o_lora_rank) BF16
|
||||
"""
|
||||
if not hasattr(self, '_runner_id'):
|
||||
self._runner_id = register_runner(self)
|
||||
return nvfp4_linear_gemm(
|
||||
o, self._runner_id, self.n_local_groups * self.o_lora_rank,
|
||||
)
|
||||
|
||||
def _run_impl(self, o: torch.Tensor) -> torch.Tensor:
|
||||
"""Actual implementation.
|
||||
|
||||
Input o is (tokens, n_local_heads, head_dim).
|
||||
We reshape to (tokens, n_local_groups, heads_per_group * head_dim),
|
||||
then treat each group's (tokens, group_in_features) as one "expert"
|
||||
in our grouped GEMM. All tokens go to all groups.
|
||||
|
||||
The grouped GEMM layout requires each group's tokens to be
|
||||
contiguous at their correct offset:
|
||||
- Group 0: rows [0, padded_T)
|
||||
- Group 1: rows [padded_T, 2*padded_T)
|
||||
- ...
|
||||
- Group G: rows [(G-1)*padded_T, G*padded_T)
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
num_tokens = o.shape[0]
|
||||
padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Reshape: (tokens, n_local_heads, head_dim) → (tokens, n_local_groups, group_in_features)
|
||||
o_grouped = o.reshape(num_tokens, self.n_local_groups, self.group_in_features)
|
||||
|
||||
# Permute to groups-first: (G, T, D)
|
||||
o_grouped = o_grouped.permute(1, 0, 2)
|
||||
|
||||
# Flatten all groups into (G*T, D) for batched fused quantize — single kernel launch
|
||||
o_flat = o_grouped.reshape(self.n_local_groups * num_tokens, self.group_in_features)
|
||||
|
||||
# Fused amax + quantize: zero CPU-GPU syncs.
|
||||
# Computes gsa on GPU, quantizes to NVFP4, returns GPU tensor.
|
||||
# Replaces the old path: .item() sync + Python quantize per group.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
x_fp4_flat, x_sf_flat, gsa_gpu = quantize_nvfp4_gpu_fused(o_flat)
|
||||
# gsa_gpu is (G*T,) — all rows share same amax (from max over full tensor)
|
||||
# For the GEMM's global_scale_a, fill all group slots with the same gsa value
|
||||
# Use GPU-only copy: no .item(), no CPU sync
|
||||
self._gsa_buf[:1].copy_(gsa_gpu[:1]) # GPU→GPU scalar copy, no sync
|
||||
# Broadcast to all groups (all get same gsa)
|
||||
if self.n_local_groups > 1:
|
||||
self._gsa_buf[1:].copy_(self._gsa_buf[:1].expand(self.n_local_groups - 1))
|
||||
else:
|
||||
self._gsa_buf.fill_(self._activation_global_scale)
|
||||
x_fp4_flat, x_sf_flat = quantize_activation_nvfp4(
|
||||
o_flat, self._activation_global_scale
|
||||
)
|
||||
|
||||
# Reshape FP4 back to (G, T, D//2) and scatter into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
|
||||
x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
|
||||
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)
|
||||
|
||||
# Reshape scales back to (G, T, D//16) and assemble
|
||||
x_sf_grouped = x_sf_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 16)
|
||||
all_x_sf = [x_sf_grouped[g] for g in range(self.n_local_groups)]
|
||||
|
||||
# Assemble A-side scales for all groups
|
||||
from dsv4.ops.layouts import (
|
||||
assemble_scales_2d_side,
|
||||
)
|
||||
scale_a = assemble_scales_2d_side(all_x_sf)
|
||||
|
||||
# Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T]
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
for g in range(self.n_local_groups):
|
||||
expert_offsets[g] = (g + 1) * padded_rows_per_group
|
||||
|
||||
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
|
||||
gsa = self._gsa_buf
|
||||
|
||||
# Run grouped GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._gsb,
|
||||
)
|
||||
|
||||
# Extract real outputs and reshape
|
||||
# GEMM output has the same layout as mat_a: groups-first with padding
|
||||
z = torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank,
|
||||
dtype=torch.bfloat16, device=o.device)
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
z[:, g, :] = out[offset:offset + num_tokens, :]
|
||||
|
||||
return z
|
||||
|
||||
def __call__(self, o: torch.Tensor) -> torch.Tensor:
|
||||
return self.run(o)
|
||||
267
dsv4/_archive/layers/linear.py
Normal file
267
dsv4/_archive/layers/linear.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""CuTeDSL NVFP4 Linear (single GEMM)
|
||||
|
||||
Generic NVFP4 GEMM runner for attention projections and any single
|
||||
linear layer. Uses ScaledGroupedGemmKernel with num_groups=1.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
|
||||
class Nvfp4Linear:
|
||||
"""Single NVFP4 GEMM using CuTeDSL (num_groups=1).
|
||||
|
||||
Handles any (K, N) weight matrix in NVFP4 format.
|
||||
Simple: quantize activation → GEMM → BF16 output.
|
||||
No SiLU, no fusion, no routing.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
max_num_tokens: int = 8192,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
|
||||
# Weights (set after construction, then call finalize_weights)
|
||||
self.fp4 = None # list of 1 tensor
|
||||
self.sf = None # list of 1 tensor
|
||||
self.gs = None # list of 1 float
|
||||
self.ws2 = None # list of 1 tensor — weight_scale_2 (scalar, folded into global_scale_b)
|
||||
|
||||
# Processed weights
|
||||
self._mat_b = None
|
||||
self._scale_b = None
|
||||
self._gsb = None
|
||||
|
||||
# Activation global scale
|
||||
self._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
|
||||
# Pre-allocated buffers
|
||||
self._padded_x_fp4_buf = None
|
||||
self._expert_offsets_buf = None
|
||||
self._gsa_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process weights for CuTeDSL GEMM."""
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4]
|
||||
# Checkpoint weight is (out_features//2, in_features//2) = (N_packed, K_packed)
|
||||
# make_b_k_major expects (E, K_packed, N_packed), so we need to permute
|
||||
stacked = torch.stack(fp4_view).permute(0, 2, 1).contiguous() # (1, K_packed, N_packed)
|
||||
self._mat_b = make_b_k_major(stacked)
|
||||
# Checkpoint scale is (N_packed, K_sf) — already in the right row order for the
|
||||
# kernel's swizzle. Use assemble_raw_scales_2d3d_3d_side (no transpose),
|
||||
# NOT assemble_scales_3d_side (which transposes K_sf↔N).
|
||||
from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side
|
||||
self._scale_b = assemble_raw_scales_2d3d_3d_side(self.sf)
|
||||
self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Fold weight_scale_2 into global_scale_b
|
||||
# Dequant formula: w = lut[w_packed] * weight_scale * weight_scale_2
|
||||
# Production GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb)
|
||||
# So gsb = input_scale * weight_scale_2
|
||||
if self.ws2 is not None and len(self.ws2) > 0 and self.ws2[0] is not None:
|
||||
ws2_val = self.ws2[0].float().item()
|
||||
self._gsb = self._gsb * ws2_val
|
||||
|
||||
# Free raw weights
|
||||
self.fp4 = None
|
||||
self.sf = None
|
||||
self.gs = None
|
||||
self.ws2 = None
|
||||
|
||||
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
|
||||
# Uses num_groups=1 since this is a single linear layer.
|
||||
K_packed = self.in_features // 2
|
||||
N_packed = self.out_features // 2
|
||||
# warmup_compilation(1, K_packed, N_packed, self.device) # Lazy compile on first real forward
|
||||
|
||||
def _ensure_buffer_size(self, num_tokens: int):
|
||||
"""Ensure the padded buffer is large enough for num_tokens."""
|
||||
needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
if self._padded_x_fp4_buf is not None and self._padded_x_fp4_buf.shape[0] >= needed_rows:
|
||||
return # Already big enough
|
||||
|
||||
self._padded_x_fp4_buf = torch.zeros(
|
||||
needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
|
||||
self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
|
||||
self._gsa_buf = torch.full((1,), self._activation_global_scale, dtype=torch.float32, device=self.device)
|
||||
|
||||
def _ensure_initialized(self):
|
||||
if self._mat_b is None:
|
||||
self.finalize_weights()
|
||||
|
||||
def _assemble_scales_single_group(self, x_sf):
|
||||
"""Assemble 2D-side activation scales for num_groups=1."""
|
||||
num_rows, num_cols = x_sf.shape
|
||||
padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
|
||||
padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
|
||||
|
||||
buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn)
|
||||
buf[:num_rows, :num_cols] = x_sf
|
||||
swizzled_flat = pad_and_swizzle_single(buf)
|
||||
return swizzled_flat.reshape(padded_rows, padded_cols)
|
||||
|
||||
def compute_activation_global_scale(self, hidden_states_sample):
|
||||
"""Compute activation global scale from a warmup forward."""
|
||||
self._ensure_initialized()
|
||||
with torch.no_grad():
|
||||
_, _, gs = quantize_to_nvfp4(hidden_states_sample)
|
||||
self._activation_global_scale = gs
|
||||
|
||||
|
||||
def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward: BF16 input → NVFP4 GEMM → BF16 output.
|
||||
|
||||
Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile
|
||||
treats this as an opaque op. The custom op calls _run_impl internally.
|
||||
"""
|
||||
if not hasattr(self, '_runner_id'):
|
||||
self._runner_id = register_runner(self)
|
||||
return nvfp4_linear_gemm(
|
||||
hidden_states, self._runner_id, self.out_features,
|
||||
)
|
||||
|
||||
def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Actual implementation — called via custom autograd to be torch.compile-safe."""
|
||||
self._ensure_initialized()
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Ensure buffer is large enough
|
||||
self._ensure_buffer_size(num_tokens)
|
||||
|
||||
# Fused amax + quantize: single kernel launch, zero CPU-GPU syncs.
|
||||
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
|
||||
# gsa written to GPU buffer for downstream GEMM global_scale_a.
|
||||
#
|
||||
# This replaces the two-step path:
|
||||
# compute_amax_gsa_gpu(hidden_states) → .item() sync
|
||||
# quantize_nvfp4_gpu(hidden_states, gsa_float) → another kernel launch
|
||||
#
|
||||
# Old path: ~2 kernel launches + 1 .item() sync per projection.
|
||||
# New path: 1 kernel launch + 0 .item() syncs per projection.
|
||||
# Total across 61 layers: ~486 .item() syncs eliminated.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states)
|
||||
self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
# P2 FIX: No per-call fill_(). The _gsa_buf already has the correct
|
||||
# value — set either during initialization (via _ensure_buffer_size)
|
||||
# or by the first GPU compute when _use_runtime_gsa was True.
|
||||
# Old path: self._gsa_buf.fill_(self._activation_global_scale)
|
||||
# — H2D transfer every call (~5µs each × 244 calls = ~1.2ms/token).
|
||||
# New path: zero H2D transfers on the hot path.
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu
|
||||
x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale)
|
||||
|
||||
# Scatter x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[:x_fp4.shape[0]] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales
|
||||
scale_a = self._assemble_scales_single_group(x_sf)
|
||||
|
||||
# Expert offsets: [padded_rows] for 1 group
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
|
||||
gsa = self._gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._gsb,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor:
|
||||
"""Run GEMM with pre-quantized activation (skip quantize step).
|
||||
|
||||
Used when the input has already been quantized by a fused
|
||||
RMSNorm+quantize kernel. Saves 2 kernel launches per call.
|
||||
|
||||
Args:
|
||||
quant: QuantizedActivation with x_fp4, x_sf, gsa
|
||||
"""
|
||||
from dsv4.ops.quantize import QuantizedActivation
|
||||
assert isinstance(quant, QuantizedActivation)
|
||||
|
||||
self._ensure_initialized()
|
||||
num_tokens = quant.num_tokens
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
self._ensure_buffer_size(num_tokens)
|
||||
|
||||
# Scatter pre-quantized x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales from pre-quantized sf
|
||||
scale_a = self._assemble_scales_single_group(quant.x_sf)
|
||||
|
||||
# Expert offsets
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — use the per-row gsa from the fused kernel
|
||||
# Reshape to (1,) if scalar, or use per-row (M,) broadcast
|
||||
gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens]
|
||||
if gsa.shape != self._gsa_buf.shape:
|
||||
self._gsa_buf = gsa.contiguous()
|
||||
else:
|
||||
self._gsa_buf.copy_(gsa)
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=self._gsa_buf,
|
||||
global_scale_b=self._gsb,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.run(hidden_states)
|
||||
549
dsv4/_archive/layers/mhc.py
Normal file
549
dsv4/_archive/layers/mhc.py
Normal file
@@ -0,0 +1,549 @@
|
||||
"""
|
||||
mHC (Manifold-Constrained Hyper-Connections) — Inference Layer.
|
||||
|
||||
Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only.
|
||||
|
||||
Verified against HuggingFace DeepseekV4HyperConnection (transformers main,
|
||||
modeling_deepseek_v4.py). The ordering of fn/base/scale outputs is
|
||||
[pre(4), post(4), comb(16)] — NOT [pre, comb, post]. The comb matrix is
|
||||
consumed TRANSPOSED in post_block. Sinkhorn starts from softmax (not exp).
|
||||
pre (A_l) has an hc_eps additive guard.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
V4-Pro reference dimensions (Section 4.2.1)
|
||||
---------------------------------------------------------------------
|
||||
d = 7168 hidden dim
|
||||
n_hc = 4 hyper-connection expansion factor
|
||||
N_proj = 24 fused output of W_pre(4) + W_post(4) + W_comb(16)
|
||||
K_proj = 4*7168 = 28672 = n_hc * d (flattened residual)
|
||||
t_max = 20 Sinkhorn iterations
|
||||
|
||||
---------------------------------------------------------------------
|
||||
Checkpoint layout (fn / base / scale)
|
||||
---------------------------------------------------------------------
|
||||
fn: (24, 28672) — rows ordered [pre(4), post(4), comb(16)]
|
||||
base: (24,) — ordered [pre(4), post(4), comb(16)]
|
||||
scale: (3,) — [alpha_pre, alpha_post, alpha_comb]
|
||||
|
||||
This matches the HuggingFace split:
|
||||
pre_w, post_w, comb_w = F.linear(flat, fn).split([4, 4, 16])
|
||||
pre_b, post_b, comb_b = base.split([4, 4, 16])
|
||||
pre_scale, post_scale, comb_scale = scale.unbind(0)
|
||||
|
||||
---------------------------------------------------------------------
|
||||
Kernel dependency
|
||||
---------------------------------------------------------------------
|
||||
tf32_hc_prenorm_gemm (DeepGEMM, SM90/SM100)
|
||||
a: (T, K) BF16 — flattened residual X_flat
|
||||
b: (N, K) FP32 — stacked weight [W_pre; W_post; W_comb]
|
||||
d: (S, T, N) or (T, N) FP32 — raw projection outputs (pre-normalised)
|
||||
sqr_sum: (S, T) or (T,) FP32 — Σ a² per token (for RMSNorm denominator)
|
||||
num_splits = S (16 recommended for K=28672)
|
||||
|
||||
After the call:
|
||||
d = d.sum(0) → (T, N)
|
||||
sqr_sum = sqr_sum.sum(0) → (T,)
|
||||
rms_scale = sqrt(K / (sqr_sum + eps))
|
||||
d_norm = d * rms_scale[:,None] — equivalent to RMSNorm(X_flat) @ W_stacked
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Try importing DeepGEMM; fall back to plain BF16 matmul if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import deep_gemm
|
||||
_HAS_DEEP_GEMM = True
|
||||
except ImportError:
|
||||
_HAS_DEEP_GEMM = False
|
||||
|
||||
|
||||
NUM_SPLITS = 16 # K-split count for tf32_hc_prenorm_gemm numerical stability
|
||||
EPS_RMSN = 1e-6
|
||||
HC_EPS = 1e-6 # eps guard on pre (A_l) and Sinkhorn, matching HF reference
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sinkhorn-Knopp projection (T batched 4×4 matrices)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sinkhorn_knopp(
|
||||
logits: torch.Tensor, # (T, n, n) raw logits (NOT exp'd)
|
||||
t_max: int = 20,
|
||||
eps: float = HC_EPS,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Project each (n×n) matrix onto the Birkhoff polytope
|
||||
(doubly stochastic matrices) via alternating row/col normalisation.
|
||||
|
||||
Matches HuggingFace DeepseekV4HyperConnection.forward:
|
||||
1. softmax along last dim (row-normalize the logits)
|
||||
2. add eps
|
||||
3. column-normalize
|
||||
4. (t_max - 1) alternating row/col normalizations
|
||||
|
||||
NO PYTHON FALLBACK. If the CUDA kernel fails, the pipeline dies.
|
||||
The kernel MUST compile and run correctly. Period.
|
||||
"""
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("mhc_sinkhorn", ["mhc_sinkhorn.cu"])
|
||||
return mod.mhc_sinkhorn(logits.float(), t_max, eps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context carried between pre_block and post_block
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class mHCContext:
|
||||
"""Holds the per-token mixing matrices computed in pre_block."""
|
||||
B_l: torch.Tensor # (T, n_hc, n_hc) doubly stochastic residual transform
|
||||
C_l: torch.Tensor # (T, n_hc) output mapping (2*sigmoid)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mHC layer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class mHCLayer:
|
||||
"""
|
||||
Wraps one transformer sub-layer (attention *or* MoE) with the mHC
|
||||
residual update.
|
||||
|
||||
Typical call pattern per layer:
|
||||
|
||||
x_in, ctx = mhc.pre_block(X_l)
|
||||
F_out = transformer_sublayer(x_in) # (T, d)
|
||||
X_next = mhc.post_block(X_l, F_out, ctx)
|
||||
|
||||
where X_l has shape (T, n_hc, d) — the expanded residual state.
|
||||
The first call at layer 0 should use X_0 initialised via `init_state`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int = 7168,
|
||||
n_hc: int = 4,
|
||||
t_max_sinkhorn: int = 20,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
self.d = hidden_dim
|
||||
self.n_hc = n_hc
|
||||
self.K_proj = n_hc * hidden_dim # 28672 for V4-Pro
|
||||
self.N_proj = n_hc + n_hc + n_hc * n_hc # 4 + 4 + 16 = 24
|
||||
self.t_max = t_max_sinkhorn
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
# ── Learnable weights (set via load_weights) ──────────────────
|
||||
# Checkpoint fn ordering: [pre(4), post(4), comb(16)]
|
||||
# We store them in this order and build W_stacked = [pre, post, comb]
|
||||
self.W_pre = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K)
|
||||
self.W_post = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K)
|
||||
self.W_comb = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32) # (16, K)
|
||||
|
||||
# Checkpoint base ordering: [pre(4), post(4), comb(16)]
|
||||
self.S_pre = self._buf(1, n_hc) # (1, 4) — pre bias
|
||||
self.S_post = self._buf(n_hc, 1) # (4, 1) — post bias
|
||||
self.S_comb = self._buf(n_hc, n_hc) # (4, 4) — comb bias
|
||||
|
||||
# Checkpoint scale ordering: [alpha_pre, alpha_post, alpha_comb]
|
||||
self.alpha_pre = torch.zeros(1, device=device, dtype=torch.float32)
|
||||
self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32)
|
||||
self.alpha_comb = torch.zeros(1, device=device, dtype=torch.float32)
|
||||
|
||||
# Pre-allocated split buffers (set in _ensure_buffers)
|
||||
self._d_split = None # (NUM_SPLITS, max_T, N_proj) FP32
|
||||
self._sqr_sum_split = None # (NUM_SPLITS, max_T) FP32
|
||||
self._max_T = 0
|
||||
|
||||
# Fused stacked weight for DeepGEMM (built once in _build_stacked)
|
||||
self._W_stacked = None # (N_proj, K_proj) FP32
|
||||
|
||||
# ── Construction helpers ──────────────────────────────────────────
|
||||
|
||||
def _buf(self, *shape, dtype=None):
|
||||
dt = dtype or self.dtype
|
||||
return torch.empty(*shape, dtype=dt, device=self.device)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
W_pre: torch.Tensor, # (n_hc, K) FP32
|
||||
W_post: torch.Tensor, # (n_hc, K) FP32
|
||||
W_comb: torch.Tensor, # (n_hc², K) FP32
|
||||
S_pre: torch.Tensor, # (1, n_hc)
|
||||
S_post: torch.Tensor, # (n_hc, 1)
|
||||
S_comb: torch.Tensor, # (n_hc, n_hc)
|
||||
alpha_pre: float,
|
||||
alpha_post: float,
|
||||
alpha_comb: float,
|
||||
):
|
||||
"""
|
||||
Load all mHC parameters from the checkpoint.
|
||||
|
||||
The W tensors must be FP32 — they are loaded as FP32 in the prenorm
|
||||
GEMM (BF16 input × FP32 weight). Everything else can be BF16 in the
|
||||
checkpoint and will be cast here.
|
||||
"""
|
||||
def _f32(t): return t.to(device=self.device, dtype=torch.float32).contiguous()
|
||||
def _cvt(t): return t.to(device=self.device, dtype=self.dtype).contiguous()
|
||||
|
||||
self.W_pre = _f32(W_pre)
|
||||
self.W_post = _f32(W_post)
|
||||
self.W_comb = _f32(W_comb)
|
||||
self.S_pre = _cvt(S_pre)
|
||||
self.S_post = _cvt(S_post)
|
||||
self.S_comb = _cvt(S_comb)
|
||||
self.alpha_pre = torch.tensor(alpha_pre, dtype=torch.float32, device=self.device)
|
||||
self.alpha_post = torch.tensor(alpha_post, dtype=torch.float32, device=self.device)
|
||||
self.alpha_comb = torch.tensor(alpha_comb, dtype=torch.float32, device=self.device)
|
||||
self._W_stacked = None # invalidate cache
|
||||
|
||||
def _build_stacked(self):
|
||||
"""Fuse W_pre / W_post / W_comb into one (N_proj, K_proj) FP32 tensor.
|
||||
|
||||
Order: [pre(4), post(4), comb(16)] — matches checkpoint fn layout.
|
||||
"""
|
||||
self._W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb], dim=0)
|
||||
# Must be K-major (contiguous along K) for DeepGEMM
|
||||
self._W_stacked = self._W_stacked.contiguous()
|
||||
|
||||
def _ensure_buffers(self, T: int):
|
||||
"""Pre-allocate split buffers if needed (avoids hot-path alloc)."""
|
||||
if T <= self._max_T:
|
||||
return
|
||||
self._d_split = torch.empty(
|
||||
NUM_SPLITS, T, self.N_proj, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._sqr_sum_split = torch.empty(
|
||||
NUM_SPLITS, T, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._max_T = T
|
||||
|
||||
# ── Forward ──────────────────────────────────────────────────────
|
||||
|
||||
def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) FP32.
|
||||
|
||||
Uses tf32_hc_prenorm_gemm when DeepGEMM is available for fused
|
||||
GEMM + squared-sum accumulation. Falls back to plain BF16 matmul.
|
||||
|
||||
X_flat: (T, K_proj) BF16
|
||||
"""
|
||||
T = X_flat.shape[0]
|
||||
K = self.K_proj
|
||||
|
||||
if _HAS_DEEP_GEMM:
|
||||
if self._W_stacked is None:
|
||||
self._build_stacked()
|
||||
self._ensure_buffers(T)
|
||||
|
||||
d_s = self._d_split[:, :T, :] # view, no copy
|
||||
ss_s = self._sqr_sum_split[:, :T]
|
||||
|
||||
deep_gemm.tf32_hc_prenorm_gemm(
|
||||
X_flat.contiguous(), # a
|
||||
self._W_stacked, # b (N, K) FP32
|
||||
d_s, # d (S, T, N)
|
||||
ss_s, # sqr_sum (S, T)
|
||||
num_splits=NUM_SPLITS,
|
||||
)
|
||||
|
||||
d_out = d_s.sum(dim=0) # (T, N)
|
||||
sqr_sum = ss_s.sum(dim=0) # (T,)
|
||||
|
||||
else:
|
||||
if self._W_stacked is None:
|
||||
self._build_stacked()
|
||||
|
||||
x_f32 = X_flat.float()
|
||||
d_out = x_f32 @ self._W_stacked.T # (T, N)
|
||||
sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,)
|
||||
|
||||
# RMSNorm scale: multiply raw GEMM output by rsqrt(mean(x²))
|
||||
rms_scale = torch.sqrt(K / (sqr_sum + EPS_RMSN)) # (T,)
|
||||
return (d_out * rms_scale.unsqueeze(-1)).to(self.dtype) # (T, N) in BF16
|
||||
|
||||
def _dynamic_params(
|
||||
self, X_l: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute per-token A_l, B_l, C_l from the current residual state.
|
||||
|
||||
Matches HuggingFace DeepseekV4HyperConnection.forward exactly:
|
||||
1. UnweightedRMSNorm on flattened residual
|
||||
2. F.linear(flat, fn) → split [pre, post, comb]
|
||||
3. pre = sigmoid(pre_w * scale[0] + base[:4]) + eps
|
||||
4. post = 2 * sigmoid(post_w * scale[1] + base[4:8])
|
||||
5. comb = Sinkhorn(softmax(comb_w * scale[2] + base[8:]), iters)
|
||||
|
||||
X_l: (T, n_hc, d)
|
||||
|
||||
Returns:
|
||||
A_l: (T, n_hc) sigmoid-constrained input mapping (+ eps)
|
||||
B_l: (T, n_hc, n_hc) doubly-stochastic residual transform
|
||||
C_l: (T, n_hc) 2*sigmoid-constrained output mapping
|
||||
"""
|
||||
T, n, d = X_l.shape
|
||||
assert n == self.n_hc and d == self.d
|
||||
|
||||
# Flatten: (T, n_hc*d)
|
||||
X_flat = X_l.reshape(T, self.K_proj).to(self.dtype)
|
||||
|
||||
# Unweighted RMSNorm on flattened residual (HF: self.input_norm)
|
||||
# This normalizes BEFORE the linear projection.
|
||||
X_flat_f = X_flat.float()
|
||||
rms_inv = X_flat_f.pow(2).mean(dim=-1, keepdim=True).add(EPS_RMSN).rsqrt()
|
||||
X_flat = (X_flat_f * rms_inv).to(self.dtype)
|
||||
|
||||
# Fused RMSNorm projection: (T, N_proj) = RMSNorm(X_flat) @ fn.T
|
||||
# Note: the RMSNorm above is the "input_norm" (unweighted). The
|
||||
# _project_and_rms method applies a SECOND RMSNorm (as part of
|
||||
# the fused GEMM). This is intentional — the prenorm GEMM fuses
|
||||
# RMSNorm into the GEMM output, and the input_norm is a separate
|
||||
# unweighted norm on the input. When DeepGEMM is available, both
|
||||
# are fused into a single kernel. In the fallback path, we apply
|
||||
# both explicitly (the input_norm above + the GEMM-internal norm
|
||||
# in _project_and_rms). The result is mathematically:
|
||||
# proj = RMSNorm(RMSNorm(X_flat) @ W.T)
|
||||
# which is equivalent to the HF:
|
||||
# proj = F.linear(input_norm(X_flat), fn)
|
||||
# followed by... wait, no. HF does NOT apply a second RMSNorm.
|
||||
# Let me re-read HF:
|
||||
# flat = self.input_norm(hidden_streams.flatten(start_dim=2).float())
|
||||
# pre_w, post_w, comb_w = F.linear(flat, self.fn.float()).split(...)
|
||||
# So HF: 1. input_norm(X_flat), 2. linear, 3. split.
|
||||
# Our _project_and_rms: 1. (no input_norm yet), 2. RMSNorm(X_flat) @ W.T
|
||||
# which is: (X_flat / rms(X_flat)) @ W.T = X_flat @ W.T / rms(X_flat)
|
||||
# This is NOT the same as input_norm(X_flat) @ W.T because input_norm
|
||||
# normalizes each token independently while RMSNorm in the GEMM divides
|
||||
# the ENTIRE dot product by the RMS.
|
||||
# Actually, let me re-check. Our _project_and_rms does:
|
||||
# d_out = X_flat @ W.T
|
||||
# rms_scale = sqrt(K / (sqr_sum + eps))
|
||||
# return d_out * rms_scale
|
||||
# = (X_flat @ W.T) * sqrt(K / (sum(X_flat^2) + eps))
|
||||
# = (X_flat @ W.T) / sqrt(mean(X_flat^2) + eps)
|
||||
# = X_flat / sqrt(mean(X_flat^2) + eps) @ W.T
|
||||
# (because sqrt(mean(X^2) + eps) is a scalar per token)
|
||||
# So this IS the same as input_norm(X_flat) @ W.T! ✓
|
||||
# The RMSNorm commutes with the linear because it's per-token.
|
||||
# So we DON'T need a separate input_norm — the GEMM-fused RMSNorm
|
||||
# is equivalent. The explicit input_norm above is redundant.
|
||||
# Remove it:
|
||||
X_flat = X_l.reshape(T, self.K_proj).to(self.dtype)
|
||||
|
||||
proj = self._project_and_rms(X_flat).float()
|
||||
|
||||
# Split: [pre(4), post(4), comb(16)]
|
||||
n = self.n_hc
|
||||
pre_raw = proj[:, 0:n] # (T, n_hc)
|
||||
post_raw = proj[:, n:2*n] # (T, n_hc)
|
||||
comb_raw = proj[:, 2*n:2*n + n*n] # (T, n_hc²)
|
||||
|
||||
# Apply scale and bias (matching HF: raw * scale + base)
|
||||
S_pre = self.S_pre.float() # (1, n_hc)
|
||||
S_post = self.S_post.float() # (n_hc, 1)
|
||||
S_comb = self.S_comb.float() # (n_hc, n_hc)
|
||||
|
||||
pre_tilde = self.alpha_pre * pre_raw + S_pre # (T, n_hc)
|
||||
post_tilde = self.alpha_post * post_raw + S_post.flatten().unsqueeze(0) # (T, n_hc)
|
||||
comb_tilde = self.alpha_comb * comb_raw + S_comb.flatten().unsqueeze(0) # (T, n_hc²)
|
||||
|
||||
# Apply constraints (matching HF exactly)
|
||||
# pre = sigmoid(...) + hc_eps (note the eps!)
|
||||
A_l = torch.sigmoid(pre_tilde) + HC_EPS # (T, n_hc)
|
||||
# post = 2 * sigmoid(...)
|
||||
C_l = 2.0 * torch.sigmoid(post_tilde) # (T, n_hc)
|
||||
# comb = Sinkhorn(softmax(logits) + eps, iters)
|
||||
comb_logits = comb_tilde.reshape(T, n, n)
|
||||
B_l = sinkhorn_knopp(comb_logits, t_max=self.t_max) # (T, n_hc, n_hc)
|
||||
|
||||
return A_l.to(self.dtype), B_l, C_l.to(self.dtype)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Public API: pre_block / post_block
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def pre_block(
|
||||
self,
|
||||
X_l: torch.Tensor, # (T, n_hc, d) BF16
|
||||
) -> Tuple[torch.Tensor, mHCContext]:
|
||||
"""
|
||||
Compute dynamic mixing params and extract the layer input.
|
||||
|
||||
Returns:
|
||||
x_in: (T, d) BF16 — the actual input to pass to the sub-layer
|
||||
ctx: mHCContext — {B_l, C_l} to be passed to post_block
|
||||
"""
|
||||
A_l, B_l, C_l = self._dynamic_params(X_l)
|
||||
|
||||
# Layer input: x_in = sum_j A_l[j] * X_l[j] (weighted sum of streams)
|
||||
# Matches HF: collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2)
|
||||
# A_l: (T, n_hc) X_l: (T, n_hc, d)
|
||||
x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d)
|
||||
|
||||
return x_in, mHCContext(B_l=B_l, C_l=C_l)
|
||||
|
||||
def post_block(
|
||||
self,
|
||||
X_l: torch.Tensor, # (T, n_hc, d) BF16 — residual state BEFORE sub-layer
|
||||
F_out: torch.Tensor, # (T, d) BF16 — sub-layer output
|
||||
ctx: mHCContext,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply the mHC residual update.
|
||||
Matches HuggingFace: X_next = post * F_out + comb.T @ X_l
|
||||
|
||||
Note: comb (B_l) is consumed TRANSPOSED! This matches the HF reference:
|
||||
torch.matmul(comb.transpose(-1, -2), hidden_streams)
|
||||
|
||||
Returns:
|
||||
X_next: (T, n_hc, d) BF16
|
||||
"""
|
||||
# B_l.T @ X_l — note the TRANSPOSE! HF uses comb.transpose(-1,-2)
|
||||
BX = torch.bmm(ctx.B_l.transpose(-1, -2), X_l.float())
|
||||
# C_l * F_out
|
||||
CF = ctx.C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d)
|
||||
X_next = (CF.float() + BX).to(self.dtype) # (T, n_hc, d)
|
||||
|
||||
# Diagnostic: warn on residual blowup
|
||||
x_max = X_next.abs().max().item()
|
||||
if x_max > 500:
|
||||
# Don't clip in production, just warn
|
||||
pass
|
||||
|
||||
return X_next
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Utility
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def init_state(
|
||||
embeddings: torch.Tensor, # (T, d) BF16 — token embeddings
|
||||
n_hc: int = 4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Initialise X_0 for the first layer.
|
||||
|
||||
Returns: (T, n_hc, d) BF16
|
||||
"""
|
||||
return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()
|
||||
|
||||
@staticmethod
|
||||
def read_out(X_L: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extract the final hidden state from the last residual state.
|
||||
Stream 0 is the primary output stream.
|
||||
|
||||
Returns: (T, d) BF16
|
||||
"""
|
||||
return X_L[:, 0, :]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quick smoke test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
torch.manual_seed(0)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
D, N_HC = 7168, 4
|
||||
K = N_HC * D # 28672
|
||||
N_PROJ = N_HC + N_HC + N_HC ** 2 # 4 + 4 + 16 = 24
|
||||
|
||||
mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype)
|
||||
|
||||
# Random weights matching the expected shapes (fn ordering: pre, post, comb)
|
||||
mhc.load_weights(
|
||||
W_pre = torch.randn(N_HC, K, dtype=torch.float32),
|
||||
W_post = torch.randn(N_HC, K, dtype=torch.float32),
|
||||
W_comb = torch.randn(N_HC**2, K, dtype=torch.float32),
|
||||
S_pre = torch.zeros(1, N_HC, dtype=dtype),
|
||||
S_post = torch.zeros(N_HC, 1, dtype=dtype),
|
||||
S_comb = torch.eye(N_HC, dtype=dtype), # identity: pure residual
|
||||
alpha_pre = 0.01,
|
||||
alpha_post = 0.01,
|
||||
alpha_comb = 0.01,
|
||||
)
|
||||
|
||||
T = 4 # 4 tokens
|
||||
|
||||
# ── Forward pass ────────────────────────────────────────────────
|
||||
embeddings = torch.randn(T, D, dtype=dtype, device=device)
|
||||
X = mHCLayer.init_state(embeddings, n_hc=N_HC)
|
||||
print(f"X_0: {X.shape} (T={T}, n_hc={N_HC}, d={D})")
|
||||
|
||||
for layer_idx in range(2):
|
||||
x_in, ctx = mhc.pre_block(X)
|
||||
print(f"\nLayer {layer_idx}:")
|
||||
print(f" x_in (to sub-layer): {x_in.shape}")
|
||||
print(f" B_l: {ctx.B_l.shape}")
|
||||
print(f" C_l: {ctx.C_l.shape}")
|
||||
F_out = x_in
|
||||
X = mhc.post_block(X, F_out, ctx)
|
||||
print(f" X_next: {X.shape}")
|
||||
|
||||
hidden = mHCLayer.read_out(X)
|
||||
print(f"\nFinal hidden: {hidden.shape}")
|
||||
|
||||
# ── B_l is doubly stochastic check ──────────────────────────────
|
||||
print("\n=== Doubly stochastic check ===")
|
||||
B = ctx.B_l
|
||||
row_sums = B.sum(dim=-1)
|
||||
col_sums = B.sum(dim=-2)
|
||||
print(f" row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}] (want ≈ 1.0)")
|
||||
print(f" col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}] (want ≈ 1.0)")
|
||||
assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1"
|
||||
assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1"
|
||||
print(" PASSED")
|
||||
|
||||
# ── A_l and C_l bounds ────────────────────────────────────────
|
||||
A_l, B_l2, C_l = mhc._dynamic_params(X)
|
||||
print(f"\n=== A_l ∈ (eps, 1+eps) check ===")
|
||||
print(f" A_l range: [{A_l.min():.4f}, {A_l.max():.4f}] (want ∈ (eps, 1+eps))")
|
||||
print(" PASSED")
|
||||
print(f"\n=== C_l ∈ (0, 2) check ===")
|
||||
print(f" C_l range: [{C_l.min():.4f}, {C_l.max():.4f}] (want ∈ (0, 2))")
|
||||
assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range"
|
||||
print(" PASSED")
|
||||
|
||||
# ── Equivalence: T=1 decode vs T=N prefill ──────────────────────
|
||||
print("\n=== Token-by-token decode == batch prefill ===")
|
||||
T_big = 8
|
||||
h_big = torch.randn(T_big, D, dtype=dtype, device=device)
|
||||
X_batch = mHCLayer.init_state(h_big, n_hc=N_HC)
|
||||
|
||||
x_in_batch, ctx_batch = mhc.pre_block(X_batch)
|
||||
|
||||
x_in_tokens = []
|
||||
for t in range(T_big):
|
||||
X_t = X_batch[t:t+1]
|
||||
x_in_t, _ = mhc.pre_block(X_t)
|
||||
x_in_tokens.append(x_in_t)
|
||||
x_in_seq = torch.cat(x_in_tokens, dim=0)
|
||||
|
||||
diff = (x_in_batch - x_in_seq).abs().max().item()
|
||||
print(f" max |batch - sequential| on x_in: {diff:.6f}")
|
||||
assert diff < 1e-2, f"Mismatch too large: {diff}"
|
||||
print(" PASSED")
|
||||
|
||||
print("\nAll checks done.")
|
||||
if not _HAS_DEEP_GEMM:
|
||||
print("\n(deep_gemm not available — used BF16 matmul fallback)")
|
||||
700
dsv4/_archive/layers/moe.py
Normal file
700
dsv4/_archive/layers/moe.py
Normal file
@@ -0,0 +1,700 @@
|
||||
"""
|
||||
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
|
||||
|
||||
CUDA-graph-compatible design:
|
||||
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
|
||||
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
|
||||
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
|
||||
- Extra slots (beyond real tokens) are zero and contribute nothing to output
|
||||
- Fixed-shape tensors throughout the forward pass
|
||||
|
||||
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
|
||||
During capture, num_tokens equals the budget — all shapes are fixed.
|
||||
During replay, inputs are padded to the budget size. Our runner always
|
||||
processes max_slots = budget * top_k rows; padding rows are zeros.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
quantize_nvfp4_gpu,
|
||||
deinterleave_quantize_nvfp4_cuda,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
run_fused_swiglu_grouped_gemm,
|
||||
warmup_fused_swiglu_compilation,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm
|
||||
|
||||
|
||||
class Nvfp4MoE:
|
||||
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts, hidden_size, intermediate_size,
|
||||
max_num_tokens=8192, top_k=8, device="cuda",
|
||||
experts_start_idx=0):
|
||||
self.num_experts = num_experts
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.top_k = top_k
|
||||
self.device = device
|
||||
self.experts_start_idx = experts_start_idx
|
||||
self._swiglu_limit = None # Set via set_swiglu_limit()
|
||||
self._fused_swiglu = False # Set via set_fused_swiglu()
|
||||
|
||||
# Weight storage (set before _ensure_stacked)
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
|
||||
# Stacked weight tensors (set in _ensure_stacked)
|
||||
self._l1_mat_b = None
|
||||
self._l2_mat_b = None
|
||||
self._l1_scale_b = None
|
||||
self._l2_scale_b = None
|
||||
self._l1_gsb = None
|
||||
self._l2_gsb = None
|
||||
|
||||
# Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688)
|
||||
# Overridden in finalize_weights with checkpoint input_scale or warmup value
|
||||
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
|
||||
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
|
||||
self._token_indices = None
|
||||
self._expert_offsets_buf = None
|
||||
self._per_expert_scale_bufs_l1 = None
|
||||
self._per_expert_scale_bufs_l2 = None
|
||||
self._padded_x_sf_buf_l1 = None
|
||||
self._padded_x_sf_buf_l2 = None
|
||||
self._l1_gsa_buf = None
|
||||
self._l2_gsa_buf = None
|
||||
self._output_buf = None
|
||||
self._row_indices_buf = None
|
||||
self._padded_hidden_buf = None
|
||||
self._padded_activated_buf = None # unused, using shared
|
||||
self._padded_expert_offsets_buf = None
|
||||
self._max_chunks_per_expert = cutedsl_ceil_div(
|
||||
self.max_num_tokens * self.top_k, self.num_experts * 128
|
||||
)
|
||||
self._buffers_allocated = False
|
||||
|
||||
def set_swiglu_limit(self, limit: float | None):
|
||||
"""Set the swiglu_limit for activation clamping."""
|
||||
self._swiglu_limit = limit
|
||||
|
||||
def set_fused_swiglu(self, enabled: bool):
|
||||
"""Enable fused L1 GEMM + SwiGLU kernel (saves 240+ BF16 kernel launches per token)."""
|
||||
self._fused_swiglu = enabled
|
||||
|
||||
def _fill_token_indices(self):
|
||||
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
|
||||
|
||||
Builds on CPU first, then copies to GPU, to ensure correctness
|
||||
regardless of CuTeDSL JIT GPU memory corruption.
|
||||
"""
|
||||
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
|
||||
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
self._token_indices.copy_(cpu_indices)
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Pre-allocate scale buffers at max size for cudagraph compatibility."""
|
||||
# Per-expert scale buffers: separate L1/L2 since K_sf differs
|
||||
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
|
||||
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
|
||||
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
|
||||
|
||||
self._per_expert_scale_bufs_l1 = [
|
||||
torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
self._per_expert_scale_bufs_l2 = [
|
||||
torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
# Initialize shared buffers dict (if not already)
|
||||
device_key = str(self.device)
|
||||
if not hasattr(Nvfp4MoE, '_shared_padded_bufs'):
|
||||
Nvfp4MoE._shared_padded_bufs = {}
|
||||
if device_key not in Nvfp4MoE._shared_padded_bufs:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key] = {}
|
||||
|
||||
# Padded x_sf buffers: SHARED across all runners (not per-layer)
|
||||
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
|
||||
if 'xsf_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
||||
'xsf_l1': torch.zeros(
|
||||
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn),
|
||||
'xsf_l2': torch.zeros(
|
||||
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn),
|
||||
'output': torch.zeros(
|
||||
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
})
|
||||
self._padded_x_sf_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']
|
||||
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
|
||||
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
|
||||
|
||||
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
||||
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Row indices for scale assembly (max_num_tokens * top_k slots)
|
||||
self._row_indices_buf = torch.arange(
|
||||
self.max_num_tokens * self.top_k, device=self.device
|
||||
)
|
||||
|
||||
# Padded hidden/activated: SHARED across all runners (not per-layer)
|
||||
max_rows_per_expert = self._max_chunks_per_expert * 128
|
||||
padded_max_slots = self.num_experts * max_rows_per_expert
|
||||
if 'hidden' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
||||
'hidden': torch.zeros(
|
||||
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
'hidden_fp4': torch.zeros(
|
||||
padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2),
|
||||
'activated': torch.zeros(
|
||||
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
'activated_fp4': torch.zeros(
|
||||
padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2),
|
||||
})
|
||||
self._shared_bufs = Nvfp4MoE._shared_padded_bufs[device_key]
|
||||
|
||||
# Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed)
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
max_rows_per_expert = self._max_chunks_per_expert * 128
|
||||
self._padded_expert_offsets_buf[1:] = torch.arange(
|
||||
1, self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
) * max_rows_per_expert
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_stacked(self):
|
||||
if self._l1_mat_b is not None:
|
||||
return
|
||||
|
||||
# Convert weights to kernel format
|
||||
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
|
||||
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
|
||||
# Permute to (E, K, N) then make K-major
|
||||
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||
# Interleave L1 gate/up weights at granularity 4 BF16.
|
||||
# This pairs gate/up within the MMA accumulator, enabling
|
||||
# fused SwiGLU without runtime conditionals.
|
||||
l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn)
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
if l1_fp4_ekn.dtype == torch.uint8:
|
||||
l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
||||
if l2_fp4_ekn.dtype == torch.uint8:
|
||||
l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
||||
# Free stacked checkpoints before make_b_k_major (saves one copy)
|
||||
self.l1_fp4_stacked = None
|
||||
self.l2_fp4_stacked = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
|
||||
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
|
||||
del l1_fp4_ekn, l2_fp4_ekn
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
|
||||
# per expert for swizzle. Split into views (no copy), then assemble.
|
||||
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
|
||||
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
|
||||
self.l1_sf_stacked = None
|
||||
self.l2_sf_stacked = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Interleave L1 SF along N to match the interleaved weight layout.
|
||||
# SF per expert from checkpoint is (N, K_sf). Interleave along N.
|
||||
# interleave_l1_weights operates on last dim, so transpose to (K_sf, N),
|
||||
# interleave, transpose back to (N, K_sf) for swizzle.
|
||||
l1_sf_il = []
|
||||
for sf_nk in l1_sf_list:
|
||||
sf_kn = sf_nk.T.contiguous().unsqueeze(0) # (1, K_sf, N)
|
||||
sf_kn = interleave_l1_weights(sf_kn) # (1, K_sf, N) interleaved along N
|
||||
l1_sf_il.append(sf_kn[0].T.contiguous()) # (N, K_sf)
|
||||
del l1_sf_list
|
||||
l1_sf_list = l1_sf_il
|
||||
|
||||
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
|
||||
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
|
||||
# the checkpoint! Skip the transpose by calling the assembly directly.
|
||||
from dsv4.ops.layouts import (
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
|
||||
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
|
||||
del l1_sf_list, l2_sf_list
|
||||
else:
|
||||
# Legacy path: per-expert lists
|
||||
l1_stacked = torch.stack(self.l1_fp4) # (E, K, N)
|
||||
l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up
|
||||
if l1_stacked.dtype == torch.uint8:
|
||||
l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2)
|
||||
l2_stacked = torch.stack(self.l2_fp4)
|
||||
if l2_stacked.dtype == torch.uint8:
|
||||
l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2)
|
||||
self._l1_mat_b = make_b_k_major(l1_stacked)
|
||||
self._l2_mat_b = make_b_k_major(l2_stacked)
|
||||
# Interleave L1 SF to match weight interleave
|
||||
# SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N,
|
||||
# then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side.
|
||||
l1_sf_il = []
|
||||
for sf in self.l1_sf:
|
||||
sf_ekn = sf.unsqueeze(0) # (1, K_sf, N)
|
||||
sf_ekn = interleave_l1_weights(sf_ekn) # interleaved along N
|
||||
l1_sf_il.append(sf_ekn[0]) # (K_sf, N)
|
||||
self._l1_scale_b = assemble_scales_3d_side(l1_sf_il)
|
||||
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
||||
del l1_stacked, l1_sf_il
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
|
||||
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Fold weight_scale_2 into global_scale_b
|
||||
# gsb = input_scale * weight_scale_2
|
||||
if self.l1_ws2 is not None:
|
||||
for i, ws2 in enumerate(self.l1_ws2):
|
||||
if ws2 is not None:
|
||||
self._l1_gsb[i] *= ws2.float().item()
|
||||
if self.l2_ws2 is not None:
|
||||
for i, ws2 in enumerate(self.l2_ws2):
|
||||
if ws2 is not None:
|
||||
self._l2_gsb[i] *= ws2.float().item()
|
||||
|
||||
self.l1_gs = None
|
||||
self.l2_gs = None
|
||||
self.l1_ws2 = None
|
||||
self.l2_ws2 = None
|
||||
|
||||
# Allocate buffers and eagerly warmup JIT compilation.
|
||||
# cute.compile does NOT corrupt GPU memory (verified 2026-05-20).
|
||||
# We warmup eagerly here to ensure compilation happens before
|
||||
# the model's first forward pass, not during it.
|
||||
self._token_indices = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._fill_token_indices()
|
||||
# No _needs_token_refill: cute.compile does NOT corrupt GPU memory.
|
||||
# The original corruption was a misdiagnosis (see bridge.py cache docs).
|
||||
|
||||
# Eagerly JIT-compile GEMM kernels for L1 and L2 shapes.
|
||||
# This triggers cute.compile once per shape, caching the compiled
|
||||
# kernel + workspace. Subsequent run() calls hit the cache.
|
||||
# MUST happen before model forward pass to avoid OOM from lazy JIT.
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as bridge_ceil_div,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
warmup_compilation,
|
||||
warmup_fused_swiglu_compilation,
|
||||
)
|
||||
K_packed = self.hidden_size // 2
|
||||
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
|
||||
N_packed_l2 = self.hidden_size // 2 # down
|
||||
warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device) # L1
|
||||
warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device) # L2
|
||||
if self._fused_swiglu:
|
||||
warmup_fused_swiglu_compilation(
|
||||
self.num_experts, K_packed, N_packed_l1, self.device,
|
||||
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
||||
) # Fused L1
|
||||
|
||||
self._expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._allocate_buffers()
|
||||
|
||||
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
|
||||
"""DEPRECATED: Use prepare_weights_from_stacked() for checkpoint weights.
|
||||
|
||||
This path takes pre-quantized per-expert lists. The stacked path is
|
||||
more memory-efficient and avoids per-expert list overhead.
|
||||
"""
|
||||
self.l1_fp4 = l1_fp4
|
||||
self.l1_sf = l1_sf
|
||||
self.l1_gs = l1_gs
|
||||
self.l2_fp4 = l2_fp4
|
||||
self.l2_sf = l2_sf
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
|
||||
l1_gs, l2_fp4_stacked, l2_sf_stacked,
|
||||
l2_gs):
|
||||
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
|
||||
|
||||
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
|
||||
from the checkpoint, avoiding the per-expert list→stack round-trip.
|
||||
|
||||
The conversion to K-major and swizzled layout happens in _ensure_stacked.
|
||||
This just stores the tensors for deferred processing.
|
||||
"""
|
||||
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
|
||||
self.l1_fp4_stacked = l1_fp4_stacked
|
||||
self.l1_sf_stacked = l1_sf_stacked
|
||||
self.l1_gs = l1_gs
|
||||
self.l2_fp4_stacked = l2_fp4_stacked
|
||||
self.l2_sf_stacked = l2_sf_stacked
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
|
||||
"""DEPRECATED: Use prepare_weights_from_stacked() instead.
|
||||
|
||||
This path dequantizes checkpoint NVFP4 to BF16 then re-quantizes to our FP4.
|
||||
While the round-trip is lossless for DeepSeek-V4 (our packing matches
|
||||
the checkpoint convention exactly), it wastes memory and compute.
|
||||
The direct byte path (prepare_weights_from_stacked) is preferred.
|
||||
"""
|
||||
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
|
||||
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
|
||||
for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
|
||||
l1_w_t = l1_w.T
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t)
|
||||
self.l1_fp4.append(w_fp4)
|
||||
self.l1_sf.append(w_sf)
|
||||
self.l1_gs.append(w_gs)
|
||||
l2_w_t = l2_w.T
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t)
|
||||
self.l2_fp4.append(w_fp4)
|
||||
self.l2_sf.append(w_sf)
|
||||
self.l2_gs.append(w_gs)
|
||||
self._l1_mat_b = None
|
||||
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
||||
padded_expert_offsets,
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
|
||||
|
||||
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
|
||||
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
|
||||
|
||||
The buffer is 128-row aligned per expert (from padded_expert_offsets),
|
||||
so the full-buffer swizzle produces the correct layout. The GEMM reads
|
||||
scale_a using padded_expert_offsets, matching the scatter layout.
|
||||
"""
|
||||
K_sf = x_sf.shape[1]
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
|
||||
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = self._row_indices_buf[:total_rows]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
||||
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
|
||||
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
|
||||
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
|
||||
rows = padded_x_sf.shape[0]
|
||||
cols = padded_x_sf.shape[1]
|
||||
R = rows // 128
|
||||
C = cols // 4
|
||||
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
||||
return swizzled.reshape(rows, cols)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
|
||||
"""Compute activation global scales from a warmup forward pass.
|
||||
|
||||
Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run()
|
||||
to ensure kernel JIT happens with the same layout, and L2 gs is computed
|
||||
from actual L1 output (not an approximation).
|
||||
"""
|
||||
self._ensure_stacked()
|
||||
device = hidden_states_sample.device
|
||||
num_tokens = hidden_states_sample.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
with torch.no_grad():
|
||||
# Build slot mapping (same as run())
|
||||
flat_ids = topk_ids.reshape(-1)
|
||||
num_slots = num_tokens * top_k
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
slot_hidden = hidden_states_sample[sorted_token_ids]
|
||||
|
||||
# L1: get exact gs from quantize_to_nvfp4
|
||||
_, _, l1_gs = quantize_to_nvfp4(slot_hidden)
|
||||
|
||||
# Quantize slot_hidden for GEMM
|
||||
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
|
||||
|
||||
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
padded_expert_offsets = self._padded_expert_offsets_buf
|
||||
padded_expert_offsets.zero_()
|
||||
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
||||
|
||||
# Compute padded_dst (same as run())
|
||||
row_indices = self._row_indices_buf[:num_slots]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
|
||||
# Scatter x_fp4 into padded layout
|
||||
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
slot_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)
|
||||
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# Extract real token outputs
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
|
||||
# L2: get exact gs from SiLU(gate)*up
|
||||
# De-interleave L1 output: with interleaved weights, L1 GEMM
|
||||
# output has [gate]*4, [up]*4 pattern. De-interleave before splitting.
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
activated = gate_silu * up
|
||||
_, _, l2_gs = quantize_to_nvfp4(activated)
|
||||
|
||||
self._l1_activation_global_scale = l1_gs
|
||||
self._l2_activation_global_scale = l2_gs
|
||||
|
||||
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
"""Forward: route tokens to experts, GEMM, combine.
|
||||
|
||||
Uses torch.library.custom_op (nvfp4::moe_gemm) so torch.compile
|
||||
treats this as an opaque op. The custom op calls _run_impl internally.
|
||||
"""
|
||||
if not hasattr(self, '_runner_id'):
|
||||
self._runner_id = register_runner(self)
|
||||
return nvfp4_moe_gemm(
|
||||
hidden_states, topk_weights, topk_ids,
|
||||
self._runner_id, self.hidden_size,
|
||||
)
|
||||
|
||||
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
"""Run the NVFP4 MoE forward pass.
|
||||
|
||||
Handles global→local expert ID remapping for expert parallelism.
|
||||
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
||||
|
||||
Each expert's slots are padded to multiples of 128 for the GEMM.
|
||||
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
|
||||
scale_a is produced at those same offsets.
|
||||
"""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
self._ensure_stacked()
|
||||
|
||||
# -- Remap global expert IDs to local IDs --
|
||||
local_ids = topk_ids - self.experts_start_idx
|
||||
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
|
||||
safe_ids = local_ids.clamp(0, self.num_experts - 1)
|
||||
safe_weights = topk_weights * local_mask.float()
|
||||
|
||||
# -- Build slot mapping --
|
||||
flat_ids = safe_ids.reshape(-1)
|
||||
flat_weights = safe_weights.reshape(-1)
|
||||
num_slots = num_tokens * top_k
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_weights = flat_weights[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
|
||||
# Expert offsets (real token counts)
|
||||
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
# Pad each expert to 128-row alignment (GPU-only computation)
|
||||
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
padded_expert_offsets = self._padded_expert_offsets_buf
|
||||
padded_expert_offsets.zero_()
|
||||
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
||||
total_padded_slots = padded_expert_offsets[self.num_experts]
|
||||
|
||||
# -- Gather hidden states into slot order, compute padded_dst --
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
row_indices = self._row_indices_buf[:num_slots]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
|
||||
# === L1: gate + up ===
|
||||
# Fused amax + quantize: single kernel, zero CPU-GPU syncs.
|
||||
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
|
||||
# gsa written to GPU buffer for GEMM global_scale_a.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
|
||||
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
# Scatter x_fp4 into padded layout for the GEMM
|
||||
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
|
||||
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
slot_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
||||
|
||||
if self._fused_swiglu:
|
||||
# === Fused L1 GEMM + SwiGLU in kernel registers ===
|
||||
l1_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
# Fused deinterleave + amax + quantize: zero CPU syncs.
|
||||
# Computes gsa from de-interleaved SwiGLU output on GPU,
|
||||
# quantizes in the same kernel. Writes gsa to GPU buffer.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
|
||||
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
|
||||
l1_out_real, self.intermediate_size)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
|
||||
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
|
||||
)
|
||||
else:
|
||||
# === Non-fused L1 GEMM + PyTorch SiLU(gate)*up ===
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
activated = gate_silu * up
|
||||
|
||||
# Compute runtime gsa for L2 from activated output (non-fused path)
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
elif not self._fused_swiglu:
|
||||
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
|
||||
activated, self._l2_activation_global_scale
|
||||
)
|
||||
padded_activated_fp4 = self._shared_bufs['activated_fp4']
|
||||
padded_activated_fp4.view(torch.uint8).zero_()
|
||||
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
||||
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
l2_out_real = l2_out[padded_dst]
|
||||
|
||||
# === Scatter -> final output ===
|
||||
y = self._output_buf[:num_tokens]
|
||||
y.zero_()
|
||||
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
|
||||
y.scatter_add_(
|
||||
0,
|
||||
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
|
||||
weighted_out,
|
||||
)
|
||||
|
||||
return y
|
||||
345
dsv4/_archive/layers/router.py
Normal file
345
dsv4/_archive/layers/router.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""DSV4 Router — token-to-expert assignment.
|
||||
|
||||
Two routing modes that share an output shape:
|
||||
- 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
|
||||
Used by MoE layers 3+ (the bulk of the network).
|
||||
- 'hash': deterministic per-token-ID lookup, uniform weights.
|
||||
Used by the first 3 MoE layers per DSV4 §2.1.
|
||||
|
||||
Both modes produce (topk_weights, topk_ids) suitable for direct
|
||||
consumption by Nvfp4MoE.run().
|
||||
|
||||
CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs.
|
||||
Selection between modes is by layer_idx at construction time —
|
||||
the kernel path is fixed once the Router is built so the dispatch
|
||||
is constant-folded by torch.compile.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Literal
|
||||
import torch
|
||||
|
||||
from dsv4.ops.router import (
|
||||
register_router,
|
||||
dense_router_op,
|
||||
hash_router_op,
|
||||
)
|
||||
|
||||
|
||||
RouterMode = Literal["dense", "hash"]
|
||||
|
||||
|
||||
class Router:
|
||||
"""DSV4 expert router.
|
||||
|
||||
Per the DeepSeek-V4 paper (§2.1):
|
||||
- Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
|
||||
- Auxiliary-loss-free strategy: a learned per-expert bias (loaded
|
||||
from checkpoint, frozen at inference) is added to the activation
|
||||
for SELECTION only. The actual gating weight applied to expert
|
||||
outputs uses the UNBIASED activation.
|
||||
- First 3 MoE layers use Hash routing (Roller et al. 2021): a
|
||||
precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
|
||||
No gate GEMM is performed.
|
||||
- Sequence-wise balance loss is training-only; not applied here.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hidden_size : int
|
||||
Model hidden dimension. Must match W_gate's K dimension.
|
||||
num_experts : int
|
||||
Total routed experts (Flash: 256, Pro: 384). Shared experts are
|
||||
handled separately by Nvfp4SharedExpert.
|
||||
top_k : int
|
||||
Experts activated per token. DSV4 uses 6.
|
||||
routed_scaling_factor : float
|
||||
Post-renormalization scale on gating weights. DSV3 used 2.5;
|
||||
verify against the V4 checkpoint config — may be per-layer.
|
||||
mode : {'dense', 'hash'}
|
||||
Routing strategy. Decided at construction; cannot change at runtime.
|
||||
vocab_size : int, optional
|
||||
Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
|
||||
max_num_tokens : int
|
||||
Upper bound on N for pre-allocated buffer sizing.
|
||||
device : str
|
||||
CUDA device.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
top_k: int = 6,
|
||||
routed_scaling_factor: float = 2.5,
|
||||
*,
|
||||
mode: RouterMode,
|
||||
vocab_size: Optional[int] = None,
|
||||
max_num_tokens: int = 8192,
|
||||
device: str = "cuda",
|
||||
):
|
||||
if mode == "hash" and vocab_size is None:
|
||||
raise ValueError("vocab_size is required when mode='hash'")
|
||||
if mode not in ("dense", "hash"):
|
||||
raise ValueError(f"unknown router mode: {mode!r}")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.mode = mode
|
||||
self.vocab_size = vocab_size
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
|
||||
# ---- Parameters (filled by load_weights / finalize_weights) ----
|
||||
# Dense mode — fused NVFP4 kernel (single-kernel, preferred):
|
||||
# gate_weight: raw NVFP4 gate weight tensor [K_packed, E_packed] uint8
|
||||
# gate_weight_scale: weight scale [K_sf, E_sf] FP8 E4M3
|
||||
# gate_ws2: weight_scale_2 (global scale base)
|
||||
# gate_input_scale: input_scale (activation global scale base)
|
||||
# Dense mode — 2-kernel NVFP4 path (fallback):
|
||||
# gate_lin: Nvfp4Linear for the gate projection
|
||||
# Dense mode — BF16 fallback:
|
||||
# W_gate: BF16 weight for cuBLAS when NVFP4 scales not available
|
||||
# Hash mode:
|
||||
# hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
|
||||
self.gate_weight = None # Raw NVFP4 weight for fused kernel
|
||||
self.gate_weight_scale = None # FP8 E4M3 scale for fused kernel
|
||||
self.gate_ws2 = None # weight_scale_2 for fused kernel
|
||||
self.gate_input_scale = None # input_scale for fused kernel
|
||||
self.gate_lin = None # Nvfp4Linear for 2-kernel NVFP4 path
|
||||
self.W_gate: Optional[torch.Tensor] = None # BF16 fallback
|
||||
self.e_bias: Optional[torch.Tensor] = None
|
||||
self.hash_lut: Optional[torch.Tensor] = None
|
||||
|
||||
# ---- Pre-allocated output buffers (cudagraph-safe) ----
|
||||
self._topk_weights_buf: Optional[torch.Tensor] = None
|
||||
self._topk_ids_buf: Optional[torch.Tensor] = None
|
||||
|
||||
# Runner ID assigned on first call (see custom_op pattern).
|
||||
self._runner_id: Optional[int] = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Weight loading
|
||||
# ------------------------------------------------------------------
|
||||
def load_weights(
|
||||
self,
|
||||
W_gate: Optional[torch.Tensor] = None,
|
||||
e_bias: Optional[torch.Tensor] = None,
|
||||
hash_lut: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
"""Populate router parameters from a checkpoint shard.
|
||||
|
||||
Dense mode expects (W_gate, e_bias). Hash mode expects (hash_lut).
|
||||
Mismatches with self.mode raise immediately — these errors are
|
||||
nearly always loader bugs and silent acceptance would mask them.
|
||||
"""
|
||||
if self.mode == "dense":
|
||||
if e_bias is None:
|
||||
raise ValueError("dense router needs e_bias")
|
||||
assert e_bias.shape == (self.num_experts,), \
|
||||
f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
|
||||
self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
|
||||
if W_gate is not None:
|
||||
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
|
||||
# gate_lin is set separately via load_nvfp4_gate()
|
||||
else: # hash
|
||||
if hash_lut is None:
|
||||
raise ValueError("hash router needs hash_lut")
|
||||
assert hash_lut.shape == (self.vocab_size, self.top_k), \
|
||||
f"hash_lut shape {tuple(hash_lut.shape)} != " \
|
||||
f"{(self.vocab_size, self.top_k)}"
|
||||
assert (hash_lut >= 0).all() and (hash_lut < self.num_experts).all(), \
|
||||
"hash_lut contains out-of-range expert IDs"
|
||||
self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)
|
||||
|
||||
def load_nvfp4_gate(self, gate_lin) -> None:
|
||||
"""Set the NVFP4 gate linear layer (2-kernel path).
|
||||
|
||||
Called by the single_shot after constructing the Nvfp4Linear
|
||||
from checkpoint NVFP4 scales. When set, _run_dense_impl uses
|
||||
the production NVFP4 GEMM path instead of BF16 cuBLAS.
|
||||
"""
|
||||
self.gate_lin = gate_lin
|
||||
|
||||
def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale,
|
||||
gate_ws2, gate_input_scale,
|
||||
gate_weight_bf16=None) -> None:
|
||||
"""Set raw NVFP4 gate tensors and create Nvfp4Linear for production GEMM."""
|
||||
self.gate_weight = gate_weight.to(device=self.device)
|
||||
self.gate_weight_scale = gate_weight_scale.to(device=self.device)
|
||||
self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None
|
||||
self.gate_input_scale = gate_input_scale.to(self.device)
|
||||
|
||||
# Create Nvfp4Linear from BF16 weight (handles layout correctly)
|
||||
if gate_weight_bf16 is not None:
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
E = gate_weight_bf16.shape[0]
|
||||
gate_lin = Nvfp4Linear(in_features=self.hidden_size, out_features=E, device=self.device)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(gate_weight_bf16.bfloat16().to(self.device))
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item()
|
||||
gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item()
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
self.gate_lin = gate_lin
|
||||
|
||||
def finalize_weights(self) -> None:
|
||||
"""Allocate output buffers and JIT-compile the routing kernel.
|
||||
|
||||
Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time
|
||||
setup step called after all parameters are loaded. Triggers
|
||||
kernel compilation so the first forward isn't paying that cost.
|
||||
"""
|
||||
self._topk_weights_buf = torch.empty(
|
||||
self.max_num_tokens, self.top_k,
|
||||
dtype=torch.float32, device=self.device,
|
||||
)
|
||||
self._topk_ids_buf = torch.empty(
|
||||
self.max_num_tokens, self.top_k,
|
||||
dtype=torch.int32, device=self.device,
|
||||
)
|
||||
|
||||
# Eager JIT — dispatcher knows our mode and triggers the right
|
||||
# kernel's compile path. See dsv4/ops/router.py.
|
||||
from dsv4.ops.router import warmup_router_compilation
|
||||
warmup_router_compilation(self)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
token_ids: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hidden_states : Tensor [N, hidden_size] bfloat16
|
||||
Required for dense mode. Ignored for hash mode (kept in the
|
||||
signature so the call site is mode-agnostic).
|
||||
token_ids : Tensor [N] int32, optional
|
||||
Required for hash mode. Ignored for dense mode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
topk_weights : Tensor [N, top_k] float32
|
||||
topk_ids : Tensor [N, top_k] int32
|
||||
|
||||
Notes
|
||||
-----
|
||||
Both outputs are views into pre-allocated buffers — do not retain
|
||||
them across router calls. Nvfp4MoE consumes them immediately,
|
||||
which matches its existing contract.
|
||||
"""
|
||||
if self._topk_weights_buf is None:
|
||||
raise RuntimeError("Router.finalize_weights() not called")
|
||||
|
||||
if self.mode == "dense":
|
||||
if hidden_states is None:
|
||||
raise ValueError("dense router requires hidden_states")
|
||||
return self._run_dense(hidden_states)
|
||||
else:
|
||||
if token_ids is None:
|
||||
raise ValueError("hash router requires token_ids")
|
||||
return self._run_hash(token_ids)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mode-specific dispatch — each routes through a torch.library.custom_op
|
||||
# so Dynamo / torch.compile treats the kernel as opaque.
|
||||
# ------------------------------------------------------------------
|
||||
def _run_dense(self, hidden_states: torch.Tensor):
|
||||
if self._runner_id is None:
|
||||
self._runner_id = register_router(self)
|
||||
return dense_router_op(
|
||||
hidden_states,
|
||||
self._runner_id,
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
)
|
||||
|
||||
def _run_hash(self, token_ids: torch.Tensor):
|
||||
if self._runner_id is None:
|
||||
self._runner_id = register_router(self)
|
||||
return hash_router_op(
|
||||
token_ids,
|
||||
self._runner_id,
|
||||
self.top_k,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
|
||||
# ------------------------------------------------------------------
|
||||
def _run_dense_impl(self, hidden_states: torch.Tensor):
|
||||
"""Hot-path: fused NVFP4, 2-kernel NVFP4, or BF16 fallback.
|
||||
|
||||
Priority:
|
||||
1. Fused NVFP4 kernel (single-kernel GEMM + router epilogue)
|
||||
2. 2-kernel NVFP4 path (Nvfp4Linear + activation_topk)
|
||||
3. BF16 cuBLAS fallback
|
||||
"""
|
||||
N = hidden_states.shape[0]
|
||||
out_w = self._topk_weights_buf[:N]
|
||||
out_ids = self._topk_ids_buf[:N]
|
||||
if self.gate_lin is not None:
|
||||
# NVFP4 production GEMM path (proven Nvfp4Linear)
|
||||
from dsv4.kernels.router import dense_router_dispatch_nvfp4
|
||||
dense_router_dispatch_nvfp4(
|
||||
hidden_states=hidden_states,
|
||||
gate_lin=self.gate_lin,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
elif self.gate_weight is not None:
|
||||
# Fused NVFP4 path (gate_lin was not created)
|
||||
# Fall back to BF16
|
||||
from dsv4.kernels.router import dense_router_dispatch
|
||||
dense_router_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
W_gate=self.W_gate,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
else:
|
||||
from dsv4.kernels.router import dense_router_dispatch
|
||||
dense_router_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
W_gate=self.W_gate,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
return out_w, out_ids
|
||||
|
||||
def _run_hash_impl(self, token_ids: torch.Tensor):
|
||||
"""Hot-path entry into the hash gather kernel.
|
||||
|
||||
Implementation lives in dsv4/kernels/cuda/hash_router.cu via the
|
||||
wrapper in dsv4/ops/router.py.
|
||||
"""
|
||||
from dsv4.kernels.router import hash_router_dispatch
|
||||
N = token_ids.shape[0]
|
||||
out_w = self._topk_weights_buf[:N]
|
||||
out_ids = self._topk_ids_buf[:N]
|
||||
hash_router_dispatch(
|
||||
token_ids=token_ids,
|
||||
hash_lut=self.hash_lut,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w, # filled with 1/k
|
||||
out_ids=out_ids,
|
||||
)
|
||||
return out_w, out_ids
|
||||
409
dsv4/_archive/layers/shared_expert.py
Normal file
409
dsv4/_archive/layers/shared_expert.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""CuTeDSL Shared Expert Pipeline
|
||||
|
||||
NVFP4 inference for DeepSeek V4 shared experts.
|
||||
Uses ScaledGroupedGemmKernel with num_groups=1.
|
||||
|
||||
Pipeline:
|
||||
1. Quantize activation: BF16 → NVFP4 (using warmup gs)
|
||||
2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16
|
||||
3. SiLU(gate) * up → BF16
|
||||
4. Re-quantize: BF16 → NVFP4 (using warmup gs)
|
||||
5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16
|
||||
|
||||
Unlike MoE, there's no routing, no scatter, no expert offsets.
|
||||
All tokens go through the same expert (the shared expert).
|
||||
Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
run_fused_swiglu_grouped_gemm,
|
||||
)
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
|
||||
|
||||
class _SharedExpertApply(torch.autograd.Function):
|
||||
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
|
||||
@staticmethod
|
||||
def forward(ctx, runner, hidden_states):
|
||||
return runner._run_impl(hidden_states)
|
||||
|
||||
|
||||
class Nvfp4SharedExpert:
|
||||
"""NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
max_num_tokens: int = 8192,
|
||||
device: str = "cuda",
|
||||
swiglu_limit: float = 10.0,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.swiglu_limit = swiglu_limit
|
||||
self._fused_swiglu = False # Set via set_fused_swiglu()
|
||||
|
||||
# Weights (set after construction, then call finalize_weights)
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
# weight_scale_2 per layer (scalar, folded into global_scale_b in finalize_weights)
|
||||
self.l1_ws2 = None
|
||||
self.l2_ws2 = None
|
||||
|
||||
# Processed weights (set by finalize_weights)
|
||||
self._l1_mat_b = None
|
||||
self._l2_mat_b = None
|
||||
self._l1_scale_b = None
|
||||
self._l2_scale_b = None
|
||||
self._l1_gsb = None
|
||||
self._l2_gsb = None
|
||||
|
||||
# Activation global scales (set by compute_activation_global_scales)
|
||||
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
|
||||
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
|
||||
self._padded_x_fp4_buf_l1 = None
|
||||
self._padded_x_sf_buf_l1 = None
|
||||
self._padded_x_fp4_buf_l2 = None
|
||||
self._padded_x_sf_buf_l2 = None
|
||||
self._l1_gsa_buf = None
|
||||
self._l2_gsa_buf = None
|
||||
self._expert_offsets_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def set_swiglu_limit(self, limit: float):
|
||||
self.swiglu_limit = limit
|
||||
|
||||
def set_fused_swiglu(self, enabled: bool):
|
||||
"""Enable fused L1 GEMM + SwiGLU kernel (1-group variant of MoE fused kernel)."""
|
||||
self._fused_swiglu = enabled
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights."""
|
||||
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
||||
l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4]
|
||||
l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4]
|
||||
# Checkpoint weight is (N_packed, K_packed), make_b_k_major expects (E, K_packed, N_packed)
|
||||
l1_stacked = torch.stack(l1_view).permute(0, 2, 1).contiguous()
|
||||
l2_stacked = torch.stack(l2_view).permute(0, 2, 1).contiguous()
|
||||
# P1: Interleave L1 gate/up weights for fused SwiGLU kernel compatibility.
|
||||
# The fused kernel's SwiGLU epilogue expects granularity-8 interleaved gate/up.
|
||||
# The unfused path (if _fused_swiglu=False) deinterleaves the GEMM output before splitting.
|
||||
if self._fused_swiglu:
|
||||
l1_stacked = interleave_l1_weights(l1_stacked, granularity_bf16=8)
|
||||
# Stack weights and convert to K-major
|
||||
self._l1_mat_b = make_b_k_major(l1_stacked) # (1, K_packed, N_packed)
|
||||
self._l2_mat_b = make_b_k_major(l2_stacked)
|
||||
# Checkpoint scale is (N_packed, K_sf) — use assemble_raw_scales_2d3d_3d_side
|
||||
from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side
|
||||
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(self.l1_sf)
|
||||
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(self.l2_sf)
|
||||
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Fold weight_scale_2 into global_scale_b
|
||||
# gsb = input_scale * weight_scale_2
|
||||
if self.l1_ws2 is not None:
|
||||
for i, ws2 in enumerate(self.l1_ws2):
|
||||
if ws2 is not None:
|
||||
self._l1_gsb[i] *= ws2.float().item()
|
||||
if self.l2_ws2 is not None:
|
||||
for i, ws2 in enumerate(self.l2_ws2):
|
||||
if ws2 is not None:
|
||||
self._l2_gsb[i] *= ws2.float().item()
|
||||
|
||||
# Free raw weights
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
self.l1_ws2 = None
|
||||
self.l2_ws2 = None
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Pre-allocate all buffers at max size for cudagraph compatibility."""
|
||||
max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128 # pad to 128
|
||||
|
||||
# L1: hidden_size packed, L2: intermediate_size packed
|
||||
self._padded_x_fp4_buf_l1 = torch.zeros(
|
||||
max_rows, self.hidden_size // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
self._padded_x_fp4_buf_l2 = torch.zeros(
|
||||
max_rows, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
|
||||
# Padded scale buffers (need same padded dimensions as pad_and_swizzle_single produces)
|
||||
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
|
||||
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
|
||||
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
|
||||
self._padded_x_sf_buf_l1 = torch.zeros(
|
||||
max_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
self._padded_x_sf_buf_l2 = torch.zeros(
|
||||
max_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
# Global scale buffers
|
||||
self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Expert offsets for num_groups=1: just [num_tokens_padded]
|
||||
# The GEMM expects expert_offsets as (num_experts,) cumulative offsets
|
||||
# For 1 expert: offsets = [num_tokens] (just one element)
|
||||
self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_initialized(self):
|
||||
"""Lazily initialize stacked weights and buffers."""
|
||||
if self._l1_mat_b is None:
|
||||
self.finalize_weights()
|
||||
if not self._buffers_allocated:
|
||||
self._allocate_buffers()
|
||||
|
||||
def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf):
|
||||
"""Assemble 2D-side activation scales for num_groups=1.
|
||||
|
||||
For a single group, scale assembly is just:
|
||||
1. Copy x_sf into a correctly-sized buffer (padded to 128 rows, 4 cols)
|
||||
2. Apply pad_and_swizzle_single (Blackwell swizzle)
|
||||
3. Reshape back to 2D (kernel expects 2D scale_a)
|
||||
|
||||
The padded buffer must be sized exactly for 128-aligned num_tokens,
|
||||
NOT the max_num_tokens buffer (which would be way too large).
|
||||
"""
|
||||
num_rows, num_cols = x_sf.shape
|
||||
padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
|
||||
padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
|
||||
|
||||
# Use a temp buffer sized for this exact token count
|
||||
buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn)
|
||||
buf[:num_rows, :num_cols] = x_sf
|
||||
swizzled_flat = pad_and_swizzle_single(buf)
|
||||
return swizzled_flat.reshape(padded_rows, padded_cols)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample):
|
||||
"""Compute activation global scales from a warmup forward pass.
|
||||
|
||||
Called BEFORE cudagraph capture. Uses quantize_to_nvfp4 to get
|
||||
the exact global_scale from the data, then runs L1 to compute
|
||||
L2 gs from actual SiLU(gate)*up output.
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
with torch.no_grad():
|
||||
# L1: exact gs from quantize_to_nvfp4
|
||||
_, _, l1_gs = quantize_to_nvfp4(hidden_states_sample)
|
||||
self._l1_activation_global_scale = l1_gs
|
||||
|
||||
# Run L1 GEMM to get intermediate for L2 gs
|
||||
num_tokens = hidden_states_sample.shape[0]
|
||||
l1_out = self._run_l1(hidden_states_sample)
|
||||
if l1_out is not None and not torch.isnan(l1_out).any():
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
if self.swiglu_limit is not None:
|
||||
gate = gate.clamp(max=self.swiglu_limit)
|
||||
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
_, _, l2_gs = quantize_to_nvfp4(activated)
|
||||
self._l2_activation_global_scale = l2_gs
|
||||
|
||||
|
||||
|
||||
def _run_l1_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Fused L1 GEMM + SwiGLU + clamp — single kernel launch (1-group variant of MoE fused kernel)."""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
x_bf16 = hidden_states.reshape(num_tokens, self.hidden_size)
|
||||
|
||||
# Quantize activation to NVFP4 (fused amax + quantize)
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(x_bf16)
|
||||
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU
|
||||
else:
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, self._l1_activation_global_scale)
|
||||
|
||||
# Padded buffer setup for 1-group GEMM
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
padded_x_fp4 = self._padded_x_fp4_buf_l1
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales
|
||||
scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l1)
|
||||
|
||||
# Expert offsets: [padded_rows] for 1 group (int32, pre-allocated)
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync)
|
||||
gsa = self._l1_gsa_buf
|
||||
|
||||
# Run fused GEMM + SwiGLU
|
||||
l1_out = run_fused_swiglu_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._l1_mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l1_gsb,
|
||||
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
|
||||
)
|
||||
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
|
||||
# Deinterleave to separate gate and up, then take up half (SwiGLU result)
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved
|
||||
intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up
|
||||
return intermediate # (num_tokens, intermediate_size) BF16
|
||||
|
||||
def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""L1 GEMM: activation × gate_up_weight → BF16."""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states)
|
||||
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._l1_activation_global_scale
|
||||
)
|
||||
|
||||
# Scatter x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf_l1
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales
|
||||
scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l1)
|
||||
|
||||
# Expert offsets: [padded_rows] for 1 group
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync)
|
||||
gsa = self._l1_gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._l1_mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# Extract real token outputs
|
||||
return out[:num_tokens]
|
||||
|
||||
def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
|
||||
"""L2 GEMM: intermediate × down_weight → BF16."""
|
||||
num_tokens = intermediate.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
intermediate, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
# Scatter into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf_l2
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales
|
||||
scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l2)
|
||||
|
||||
# Expert offsets
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — GPU-computed gsa already in _l2_gsa_buf (no CPU sync)
|
||||
gsa = self._l2_gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_x_fp4,
|
||||
mat_b=self._l2_mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=self._l2_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Full shared expert forward: L1 → SiLU → L2 → output."""
|
||||
return _SharedExpertApply.apply(self, hidden_states)
|
||||
|
||||
def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Actual implementation — called via custom autograd to be torch.compile-safe."""
|
||||
self._ensure_initialized()
|
||||
|
||||
if self._fused_swiglu:
|
||||
# P1: Fused L1 GEMM + SwiGLU + clamp in one kernel launch
|
||||
intermediate = self._run_l1_fused(hidden_states)
|
||||
else:
|
||||
l1_out = self._run_l1(hidden_states)
|
||||
if l1_out.shape[1] < 2 * self.intermediate_size:
|
||||
print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True)
|
||||
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
if torch.isnan(l1_out).any():
|
||||
print(f" SE L1 NaN: l1_out nan at {torch.isnan(l1_out).sum().item()} / {l1_out.numel()} positions, shape={l1_out.shape}", flush=True)
|
||||
if torch.isnan(gate).any() or torch.isnan(up).any():
|
||||
print(f" SE gate nan={torch.isnan(gate).any().item()} up nan={torch.isnan(up).any().item()}", flush=True)
|
||||
if self.swiglu_limit is not None:
|
||||
gate = gate.clamp(max=self.swiglu_limit)
|
||||
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
|
||||
intermediate = torch.nn.functional.silu(gate) * up
|
||||
|
||||
output = self._run_l2(intermediate)
|
||||
return output
|
||||
138
dsv4/_archive/ops/custom_ops.py
Normal file
138
dsv4/_archive/ops/custom_ops.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""torch.library.custom_op wrappers for CuTeDSL NVFP4 kernels.
|
||||
|
||||
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
|
||||
(JIT compilation, cute.compile, etc.). By wrapping the runner calls in
|
||||
torch.library.custom_op, Dynamo treats them as opaque black boxes.
|
||||
|
||||
This is the correct approach per PyTorch's extensibility model:
|
||||
- custom_op is the supported way to make Dynamo skip tracing
|
||||
- autograd.Function does NOT work reliably with fullgraph mode
|
||||
- The runner's _run_impl is already cudagraph-safe
|
||||
|
||||
The registry pattern: custom ops can only take tensor/scalar arguments.
|
||||
We store runners in a global dict keyed by integer ID, and pass the ID
|
||||
as an int parameter. During Dynamo tracing, the fake impl returns a
|
||||
correctly-shaped tensor without touching the runner. During execution,
|
||||
the real impl looks up the runner and calls _run_impl.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runner registry — maps integer IDs to runner objects
|
||||
# ---------------------------------------------------------------------------
|
||||
_next_runner_id = 0
|
||||
_runner_registry: dict[int, object] = {}
|
||||
|
||||
|
||||
def register_runner(runner) -> int:
|
||||
"""Register a CuTeDSL runner and return its integer ID."""
|
||||
global _next_runner_id
|
||||
rid = _next_runner_id
|
||||
_next_runner_id += 1
|
||||
_runner_registry[rid] = runner
|
||||
return rid
|
||||
|
||||
|
||||
def get_runner(rid: int):
|
||||
"""Look up a runner by ID."""
|
||||
return _runner_registry[rid]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NVFP4 Linear GEMM custom op (single linear layer)
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.library.custom_op("nvfp4::linear_gemm", mutates_args=())
|
||||
def nvfp4_linear_gemm(
|
||||
x: torch.Tensor,
|
||||
runner_id: int,
|
||||
out_features: int,
|
||||
) -> torch.Tensor:
|
||||
"""Opaque NVFP4 linear GEMM for torch.compile.
|
||||
|
||||
Args:
|
||||
x: (M, K) BF16 input
|
||||
runner_id: integer key into the runner registry
|
||||
out_features: output dimension (for shape inference)
|
||||
Returns:
|
||||
(M, out_features) BF16 output
|
||||
"""
|
||||
runner = get_runner(runner_id)
|
||||
return runner._run_impl(x)
|
||||
|
||||
|
||||
@nvfp4_linear_gemm.register_fake
|
||||
def _(x, runner_id, out_features):
|
||||
return torch.empty(x.shape[0], out_features, dtype=torch.bfloat16, device=x.device)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NVFP4 MoE custom op (L1 + SiLU + L2 grouped GEMM)
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.library.custom_op("nvfp4::moe_gemm", mutates_args=())
|
||||
def nvfp4_moe_gemm(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
runner_id: int,
|
||||
hidden_size: int,
|
||||
) -> torch.Tensor:
|
||||
"""Opaque NVFP4 MoE GEMM for torch.compile.
|
||||
|
||||
Args:
|
||||
hidden_states: (M, K) BF16 input
|
||||
topk_weights: (M, top_k) float32 routing weights
|
||||
topk_ids: (M, top_k) int32 expert IDs
|
||||
runner_id: integer key into the runner registry
|
||||
hidden_size: output dimension (for shape inference)
|
||||
Returns:
|
||||
(M, hidden_size) BF16 output
|
||||
"""
|
||||
runner = get_runner(runner_id)
|
||||
return runner._run_impl(hidden_states, topk_weights, topk_ids)
|
||||
|
||||
|
||||
@nvfp4_moe_gemm.register_fake
|
||||
def _(hidden_states, topk_weights, topk_ids, runner_id, hidden_size):
|
||||
return torch.empty(
|
||||
hidden_states.shape[0], hidden_size,
|
||||
dtype=torch.bfloat16, device=hidden_states.device,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DSV4 Sparse FMHA custom op (attention with SWA + sink bias)
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=())
|
||||
def dsv4_sparse_fmha(
|
||||
q: torch.Tensor, # (n_q_heads, T, hd) BF16
|
||||
k: torch.Tensor, # (n_kv_heads, N, hd) or (N, hd) BF16
|
||||
v: torch.Tensor, # same as k
|
||||
sink_bias: torch.Tensor, # (n_q_heads,) FP32 — can be zeros if unused
|
||||
scale: float,
|
||||
swa_len: int,
|
||||
is_causal: bool,
|
||||
n_comp: int,
|
||||
) -> torch.Tensor:
|
||||
"""Opaque DSV4 attention for torch.compile.
|
||||
|
||||
Delegates to dsv4_attention with the appropriate flags.
|
||||
sink_bias is always passed (use zeros when unused) to keep the
|
||||
custom_op signature tensor-only for Dynamo compatibility.
|
||||
"""
|
||||
from dsv4.kernels.attention.production import dsv4_attention as _dsv4_attention
|
||||
|
||||
# If sink_bias is all zeros and n_comp == 0, skip sink bias
|
||||
has_sink = n_comp > 0 and sink_bias.abs().sum().item() > 0
|
||||
return _dsv4_attention(
|
||||
q, k, v, scale=scale,
|
||||
swa_len=swa_len if swa_len > 0 else None,
|
||||
is_causal=is_causal,
|
||||
n_comp=n_comp,
|
||||
sink_bias=sink_bias if has_sink else None,
|
||||
)
|
||||
|
||||
|
||||
@dsv4_sparse_fmha.register_fake
|
||||
def _(q, k, v, sink_bias, scale, swa_len, is_causal, n_comp):
|
||||
return torch.empty_like(q)
|
||||
93
dsv4/_archive/ops/router.py
Normal file
93
dsv4/_archive/ops/router.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""torch.library.custom_op wrappers and dispatch for the Router kernels.
|
||||
|
||||
Mirrors the pattern in dsv4/ops/custom_ops.py:
|
||||
- Routers are registered into an integer-keyed table.
|
||||
- The custom_op takes the integer ID and tensor args only.
|
||||
- Dynamo can't trace through the kernel; the op is opaque.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from dsv4.kernels.router import (
|
||||
dense_router_dispatch, # picks decode vs prefill internally
|
||||
hash_router_dispatch,
|
||||
)
|
||||
|
||||
_next_router_id = 0
|
||||
_router_registry: dict[int, object] = {}
|
||||
|
||||
|
||||
def register_router(router) -> int:
|
||||
global _next_router_id
|
||||
rid = _next_router_id
|
||||
_next_router_id += 1
|
||||
_router_registry[rid] = router
|
||||
return rid
|
||||
|
||||
|
||||
def get_router(rid: int):
|
||||
return _router_registry[rid]
|
||||
|
||||
|
||||
def warmup_router_compilation(router) -> None:
|
||||
"""Trigger eager JIT compilation for the router's kernel path.
|
||||
|
||||
Runs a dummy forward at max_num_tokens to compile the kernel for the
|
||||
expected shape range. Caller already has the buffers allocated.
|
||||
"""
|
||||
if router.mode == "dense":
|
||||
# Dummy forward at small N triggers decode-path compile.
|
||||
# CuTeDSL fused kernel is WIP — falls through to prefill path.
|
||||
dummy = torch.zeros(
|
||||
1, router.hidden_size,
|
||||
dtype=torch.bfloat16, device=router.device,
|
||||
)
|
||||
try:
|
||||
router._run_dense_impl(dummy)
|
||||
except Exception:
|
||||
pass # CuTeDSL kernel not yet working; prefill path is fine
|
||||
else:
|
||||
dummy = torch.zeros(1, dtype=torch.int32, device=router.device)
|
||||
router._run_hash_impl(dummy)
|
||||
|
||||
|
||||
# ----- Dense router custom op -----
|
||||
@torch.library.custom_op("dsv4::dense_router", mutates_args=())
|
||||
def dense_router_op(
|
||||
hidden_states: torch.Tensor,
|
||||
router_id: int,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
router = get_router(router_id)
|
||||
return router._run_dense_impl(hidden_states)
|
||||
|
||||
|
||||
@dense_router_op.register_fake
|
||||
def _(hidden_states, router_id, num_experts, top_k):
|
||||
N = hidden_states.shape[0]
|
||||
device = hidden_states.device
|
||||
return (
|
||||
torch.empty(N, top_k, dtype=torch.float32, device=device),
|
||||
torch.empty(N, top_k, dtype=torch.int32, device=device),
|
||||
)
|
||||
|
||||
|
||||
# ----- Hash router custom op -----
|
||||
@torch.library.custom_op("dsv4::hash_router", mutates_args=())
|
||||
def hash_router_op(
|
||||
token_ids: torch.Tensor,
|
||||
router_id: int,
|
||||
top_k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
router = get_router(router_id)
|
||||
return router._run_hash_impl(token_ids)
|
||||
|
||||
|
||||
@hash_router_op.register_fake
|
||||
def _(token_ids, router_id, top_k):
|
||||
N = token_ids.shape[0]
|
||||
device = token_ids.device
|
||||
return (
|
||||
torch.empty(N, top_k, dtype=torch.float32, device=device),
|
||||
torch.empty(N, top_k, dtype=torch.int32, device=device),
|
||||
)
|
||||
@@ -1,180 +1,6 @@
|
||||
"""DSV4 Attention kernels — public integration API.
|
||||
|
||||
====================================================================
|
||||
STATUS: SKELETON — not yet connected to model
|
||||
====================================================================
|
||||
These functions define the API that AttentionSubBlock will call.
|
||||
They're correct in structure but depend on:
|
||||
1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
|
||||
2. The production FMHA wrapper supporting sink_bias and n_comp
|
||||
3. Custom op registration for torch.compile compatibility
|
||||
|
||||
See ROADMAP.md Priority 5 for the full Stage E checklist.
|
||||
====================================================================
|
||||
|
||||
These functions bridge the model's AttentionSubBlock to the production
|
||||
FMHA kernel wrapper. Each function handles the cache → dense-tensor
|
||||
materialization that the kernel requires.
|
||||
|
||||
The model's attention layer calls these after:
|
||||
1. Projection (q_down, q_up, kv_down)
|
||||
2. RoPE application
|
||||
3. Compression + cache writes
|
||||
4. Indexer + top-k (CSA only)
|
||||
|
||||
These functions handle:
|
||||
- Gathering sparse/dense KV from cache into dense tensors
|
||||
- Calling the production FMHA wrapper
|
||||
- Returning attention output for inverse RoPE + wo_a/wo_b
|
||||
The live inference path uses dsv4.kernels.attention.production directly.
|
||||
See production.py for the dsv4_attention function used by single_shot_inference.py.
|
||||
"""
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
import torch
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
|
||||
def sparse_fmha_with_swa(
|
||||
q: torch.Tensor, # (T, n_h * hd) BF16, post-RoPE
|
||||
cache: "LayerCacheHandle", # provides compressed + SWA KV
|
||||
selected_indices: torch.Tensor, # (T, top_k) int64 — which compressed blocks
|
||||
sink_logits: Optional[torch.Tensor] = None, # (n_h,) FP32
|
||||
sliding_window: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""CSA attention: sparse top-k compressed KV + sliding window, fused sink merge.
|
||||
|
||||
Gathers the top-k compressed KV blocks + SWA window into a contiguous
|
||||
tensor, then calls the production FMHA with sink bias.
|
||||
|
||||
Args:
|
||||
q: (T, n_h * hd) BF16 query (post-RoPE, pre-reshape)
|
||||
cache: LayerCacheHandle with CSA compressed entries + SWA window
|
||||
selected_indices: (T, top_k) int64 block indices from the indexer
|
||||
sink_logits: (n_h,) FP32 per-head sink bias
|
||||
sliding_window: SWA window length
|
||||
|
||||
Returns:
|
||||
(T, n_h * hd) BF16 attention output (pre inverse-RoPE)
|
||||
"""
|
||||
# Reshape q to (n_h, T, hd)
|
||||
n_h_and_hd = q.shape[-1]
|
||||
# n_h and hd come from the cache's config
|
||||
n_h = cache.num_query_heads
|
||||
hd = cache.head_dim
|
||||
T = q.shape[0]
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
|
||||
|
||||
# Gather compressed KV for the selected blocks
|
||||
# The cache handle provides the materialized dense KV from paged pool
|
||||
k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)
|
||||
# k_compressed: (1, n_comp_kv, hd) or (n_kv, n_comp_kv, hd)
|
||||
# v_compressed: same shape
|
||||
|
||||
# Gather SWA window KV
|
||||
k_swa, v_swa = cache.gather_swa_kv()
|
||||
# k_swa: (1, swa_len, hd), v_swa: same
|
||||
|
||||
# Concatenate: [compressed, SWA] — single softmax (D5c insight)
|
||||
k_full = torch.cat([k_compressed, k_swa], dim=-2) # (1, n_comp+swa_len, hd)
|
||||
v_full = torch.cat([v_compressed, v_swa], dim=-2)
|
||||
|
||||
# n_comp = compressed KV length (for sink bias offset)
|
||||
n_comp = k_compressed.shape[-2]
|
||||
|
||||
# Call production attention — MQA (n_kv=1 for DSV4)
|
||||
output = dsv4_attention(
|
||||
q_heads, k_full, v_full,
|
||||
swa_len=sliding_window,
|
||||
is_causal=True,
|
||||
n_comp=n_comp,
|
||||
sink_bias=sink_logits,
|
||||
) # (n_h, T, hd)
|
||||
|
||||
# Reshape back to (T, n_h * hd)
|
||||
return output.permute(1, 0, 2).reshape(T, n_h * hd)
|
||||
|
||||
|
||||
def dense_fmha_with_swa(
|
||||
q: torch.Tensor,
|
||||
cache: "LayerCacheHandle",
|
||||
sink_logits: Optional[torch.Tensor] = None,
|
||||
sliding_window: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""HCA attention: dense over all compressed KV + SWA window, fused sink merge.
|
||||
|
||||
No indexer — all compressed entries are attended (m'=128 compression
|
||||
means the sequence is very short).
|
||||
|
||||
Args:
|
||||
q: (T, n_h * hd) BF16 query
|
||||
cache: LayerCacheHandle with HCA compressed entries + SWA window
|
||||
sink_logits: (n_h,) FP32 per-head sink bias
|
||||
sliding_window: SWA window length
|
||||
|
||||
Returns:
|
||||
(T, n_h * hd) BF16 attention output
|
||||
"""
|
||||
n_h = cache.num_query_heads
|
||||
hd = cache.head_dim
|
||||
T = q.shape[0]
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)
|
||||
|
||||
# Dense: gather ALL compressed KV (no indexer needed)
|
||||
k_compressed, v_compressed = cache.gather_all_compressed_kv()
|
||||
|
||||
k_swa, v_swa = cache.gather_swa_kv()
|
||||
|
||||
k_full = torch.cat([k_compressed, k_swa], dim=-2)
|
||||
v_full = torch.cat([v_compressed, v_swa], dim=-2)
|
||||
|
||||
n_comp = k_compressed.shape[-2]
|
||||
|
||||
output = dsv4_attention(
|
||||
q_heads, k_full, v_full,
|
||||
swa_len=sliding_window,
|
||||
is_causal=True,
|
||||
n_comp=n_comp,
|
||||
sink_bias=sink_logits,
|
||||
)
|
||||
|
||||
return output.permute(1, 0, 2).reshape(T, n_h * hd)
|
||||
|
||||
|
||||
def swa_only_fmha(
|
||||
q: torch.Tensor,
|
||||
cache: "LayerCacheHandle",
|
||||
sink_logits: Optional[torch.Tensor] = None,
|
||||
sliding_window: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""SWA-only attention: pure local attention over the sliding window.
|
||||
|
||||
No compression branch, no indexer. Used for the first two layers
|
||||
of the Flash variant.
|
||||
|
||||
Args:
|
||||
q: (T, n_h * hd) BF16 query
|
||||
cache: LayerCacheHandle with SWA window
|
||||
sink_logits: (n_h,) FP32 per-head sink bias
|
||||
sliding_window: SWA window length
|
||||
|
||||
Returns:
|
||||
(T, n_h * hd) BF16 attention output
|
||||
"""
|
||||
n_h = cache.num_query_heads
|
||||
hd = cache.head_dim
|
||||
T = q.shape[0]
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)
|
||||
|
||||
k_swa, v_swa = cache.gather_swa_kv()
|
||||
|
||||
# No n_comp (no compressed branch), no sink bias offset
|
||||
output = dsv4_attention(
|
||||
q_heads, k_swa, v_swa,
|
||||
swa_len=sliding_window,
|
||||
is_causal=True,
|
||||
n_comp=0,
|
||||
sink_bias=sink_logits,
|
||||
)
|
||||
|
||||
return output.permute(1, 0, 2).reshape(T, n_h * hd)
|
||||
|
||||
@@ -1,56 +1,5 @@
|
||||
"""CSA/HCA compressor — Python API bridge.
|
||||
|
||||
Wraps the compression functions with the interface that
|
||||
AttentionSubBlock and flush.py expect.
|
||||
|
||||
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
|
||||
to produce compressed KV entries. The compressed entries are then written to the
|
||||
paged pool by the flush_write kernel.
|
||||
See dsv4/kernels/compressor/production_compress.py for the live path.
|
||||
See dsv4/kernels/cuda/compressor_reduce.cu for the CUDA kernel.
|
||||
"""
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
|
||||
|
||||
|
||||
def csa_compress_and_store(
|
||||
kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail)
|
||||
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
|
||||
) -> None:
|
||||
"""CSA: compress KV entries and store into the classical paged cache.
|
||||
|
||||
Steps:
|
||||
1. Check if tail has enough entries (tail_len >= m=4)
|
||||
2. If so, run compression (csa_compress_tail)
|
||||
3. Write compressed output to paged pool via flush_write
|
||||
4. Update tail buffer (a-stream becomes next b-stream)
|
||||
"""
|
||||
from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda
|
||||
# NOTE: This function is called from AttentionSubBlock.forward, which
|
||||
# writes the raw KV to the tail buffer first (via cache.write_swa).
|
||||
# The actual compression + flush happens when tail_len >= m.
|
||||
# For now, the write_swa call handles the tail buffer write.
|
||||
# The flush is triggered separately by the flush pipeline.
|
||||
# See dsv4/cache/flush.py for the flush orchestration.
|
||||
pass # Compression is handled by flush.py, not directly here
|
||||
|
||||
|
||||
def hca_compress_and_store(
|
||||
kv_raw: torch.Tensor, # (T, head_dim) BF16
|
||||
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
|
||||
) -> None:
|
||||
"""HCA: compress KV entries and store into the classical paged cache.
|
||||
|
||||
Same structure as CSA but no b-stream, no overlap, m'=128.
|
||||
"""
|
||||
pass # See flush.py
|
||||
|
||||
|
||||
# Make compress_tail functions importable from this package
|
||||
__all__ = [
|
||||
'csa_compress_and_store', 'hca_compress_and_store',
|
||||
'csa_compress_tail', 'hca_compress_tail',
|
||||
]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
"""CUDA kernel loader — re-exports from loader.py for convenience."""
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module, preload_all
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
|
||||
@@ -7,7 +7,7 @@ being called on every kernel invocation (was ~100ms per call, called ~500x per t
|
||||
Usage:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
result = mod.fused_amax_quantize_nvfp4(x, divisor)
|
||||
result = mod.quantize_nvfp4_from_buffer(x, divisor)
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
@@ -65,17 +65,4 @@ def get_cuda_module(name, sources, extra_cuda_cflags=None):
|
||||
return mod
|
||||
|
||||
|
||||
def preload_all():
|
||||
"""Preload all CUDA kernels at startup (before the hot path)."""
|
||||
# amax_gsa — computes gsa on GPU (no .item())
|
||||
get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
# quantize-from-buffer — reads gsa from GPU buffer (no .item())
|
||||
get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
# Standalone quantize (for when gsa is known, not hot path)
|
||||
get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
|
||||
# Sampler
|
||||
get_cuda_module("sampler", ["sampler.cu"])
|
||||
# Dequant NVFP4
|
||||
get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
# Fused compress + quantize
|
||||
get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
|
||||
|
||||
@@ -1,63 +1,5 @@
|
||||
"""CSA indexer — Python API bridge.
|
||||
|
||||
Wraps the CUDA indexer score+topk kernel with the interface that
|
||||
AttentionSubBlock expects.
|
||||
|
||||
The indexer (paper §2.3.5, eq. 16) scores each query against
|
||||
compressed blocks via weighted ReLU MQA logits, then selects
|
||||
top-k blocks for sparse attention.
|
||||
|
||||
Currently uses scalar FP32 CUDA cores after FP4 dequant.
|
||||
The FP4 tensor-core path (Stage F / E7) is a future optimization.
|
||||
See dsv4/kernels/cuda/indexer_score_topk.cu for the live CUDA kernel.
|
||||
The live inference path uses the inline indexer in single_shot_inference.py.
|
||||
"""
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
|
||||
def compute_index_scores_topk(
|
||||
q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query
|
||||
w_indexer: torch.Tensor, # (T, n_I_h) FP32 — per-head weights
|
||||
cache: "LayerCacheHandle", # provides FP4 indexer keys
|
||||
top_k: int = 512, # number of blocks to select
|
||||
) -> torch.Tensor: # (T, top_k) int64 — selected block indices
|
||||
"""CSA: score compressed entries and select top-k blocks.
|
||||
|
||||
Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
|
||||
score + min-heap top-k). Returns entry indices for gather_compressed_kv.
|
||||
"""
|
||||
from dsv4.kernels.indexer.score_topk import run_indexer_score_topk
|
||||
|
||||
# Read the indexer view from the cache
|
||||
indexer_view = cache.read_indexer_view()
|
||||
|
||||
# c_I is the indexer head dimension from schema
|
||||
n_I_h = cache.schema.indexer_entries_per_block # This is entries, not heads
|
||||
c_I = cache.schema.indexer_head_dim # 128
|
||||
|
||||
# n_I_h (number of indexer heads) comes from the config, not the schema.
|
||||
# We need to pass it through the handle or compute it.
|
||||
# For DSV4: n_I_h = 64 (same for Flash and Pro)
|
||||
# TODO: add indexer_num_heads to schema or handle
|
||||
n_I_h = 64 # config.indexer_num_heads, hardcoded for now
|
||||
|
||||
# Reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h * c_I) — already flat
|
||||
# The kernel expects q_I: [T, n_I_h * c_I] BF16
|
||||
# and w_h: [T, n_I_h] FP32
|
||||
|
||||
entries_per_block = cache.schema.entries_per_block
|
||||
|
||||
indices = run_indexer_score_topk(
|
||||
q_I=q_indexer,
|
||||
w_h=w_indexer.float() if w_indexer.dtype != torch.float32 else w_indexer,
|
||||
indexer_view=indexer_view,
|
||||
num_heads=n_I_h,
|
||||
head_dim=c_I,
|
||||
top_k=top_k,
|
||||
entries_per_block=entries_per_block,
|
||||
)
|
||||
|
||||
# indices: (T, top_k) int32 → convert to int64 for gather_compressed_kv
|
||||
return indices.to(torch.int64)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# helpers/import_closure.py — list dsv4 modules NOT reachable from the entry points.
|
||||
# Usage: python helpers/import_closure.py (run from repo root, PYTHONPATH=repo root)
|
||||
# Usage: python3 helpers/import_closure.py (run from repo root)
|
||||
# NOTE: handles lazy imports inside functions (single_shot uses these heavily)
|
||||
import ast, pathlib, sys
|
||||
ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
ENTRYPOINTS = ["single_shot_inference.py"] # vLLM has 0 imports of dsv4 (Step 0 confirmed)
|
||||
@@ -11,6 +12,7 @@ def module_to_path(mod):
|
||||
return p if p.exists() else None
|
||||
|
||||
def imports_of(path):
|
||||
"""Parse ALL imports including lazy ones inside functions."""
|
||||
tree = ast.parse(path.read_text())
|
||||
out = set()
|
||||
for n in ast.walk(tree):
|
||||
|
||||
Reference in New Issue
Block a user