diff --git a/dsv4/cache/allocator.py b/dsv4/_archive/cache/allocator.py similarity index 100% rename from dsv4/cache/allocator.py rename to dsv4/_archive/cache/allocator.py diff --git a/dsv4/cache/block_table.py b/dsv4/_archive/cache/block_table.py similarity index 100% rename from dsv4/cache/block_table.py rename to dsv4/_archive/cache/block_table.py diff --git a/dsv4/cache/flush.py b/dsv4/_archive/cache/flush.py similarity index 100% rename from dsv4/cache/flush.py rename to dsv4/_archive/cache/flush.py diff --git a/dsv4/cache/handle.py b/dsv4/_archive/cache/handle.py similarity index 100% rename from dsv4/cache/handle.py rename to dsv4/_archive/cache/handle.py diff --git a/dsv4/cache/manager.py b/dsv4/_archive/cache/manager.py similarity index 100% rename from dsv4/cache/manager.py rename to dsv4/_archive/cache/manager.py diff --git a/dsv4/cache/paged_cache.py b/dsv4/_archive/cache/paged_cache.py similarity index 100% rename from dsv4/cache/paged_cache.py rename to dsv4/_archive/cache/paged_cache.py diff --git a/dsv4/cache/prepare_forward.py b/dsv4/_archive/cache/prepare_forward.py similarity index 100% rename from dsv4/cache/prepare_forward.py rename to dsv4/_archive/cache/prepare_forward.py diff --git a/dsv4/cache/schema.py b/dsv4/_archive/cache/schema.py similarity index 100% rename from dsv4/cache/schema.py rename to dsv4/_archive/cache/schema.py diff --git a/dsv4/cache/state_cache.py b/dsv4/_archive/cache/state_cache.py similarity index 100% rename from dsv4/cache/state_cache.py rename to dsv4/_archive/cache/state_cache.py diff --git a/dsv4/kernels/cache/append_swa.py b/dsv4/_archive/kernels/cache/append_swa.py similarity index 100% rename from dsv4/kernels/cache/append_swa.py rename to dsv4/_archive/kernels/cache/append_swa.py diff --git a/dsv4/kernels/cache/gather.py b/dsv4/_archive/kernels/cache/gather.py similarity index 100% rename from dsv4/kernels/cache/gather.py rename to dsv4/_archive/kernels/cache/gather.py diff --git a/dsv4/kernels/compressor/compress_tail.py b/dsv4/_archive/kernels/compressor/compress_tail.py similarity index 100% rename from dsv4/kernels/compressor/compress_tail.py rename to dsv4/_archive/kernels/compressor/compress_tail.py diff --git a/dsv4/kernels/compressor/csa_hca.py b/dsv4/_archive/kernels/compressor/csa_hca.py similarity index 100% rename from dsv4/kernels/compressor/csa_hca.py rename to dsv4/_archive/kernels/compressor/csa_hca.py diff --git a/dsv4/kernels/indexer/compute_valid_lens.py b/dsv4/_archive/kernels/indexer/compute_valid_lens.py similarity index 100% rename from dsv4/kernels/indexer/compute_valid_lens.py rename to dsv4/_archive/kernels/indexer/compute_valid_lens.py diff --git a/dsv4/kernels/indexer/csa_indexer.py b/dsv4/_archive/kernels/indexer/csa_indexer.py similarity index 100% rename from dsv4/kernels/indexer/csa_indexer.py rename to dsv4/_archive/kernels/indexer/csa_indexer.py diff --git a/dsv4/kernels/indexer/score_topk.py b/dsv4/_archive/kernels/indexer/score_topk.py similarity index 100% rename from dsv4/kernels/indexer/score_topk.py rename to dsv4/_archive/kernels/indexer/score_topk.py diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/_archive/kernels/router/dense_router_decode_kernel.py similarity index 100% rename from dsv4/kernels/router/dense_router_decode_kernel.py rename to dsv4/_archive/kernels/router/dense_router_decode_kernel.py diff --git a/dsv4/kernels/router/dense_router_prefill.py b/dsv4/_archive/kernels/router/dense_router_prefill.py similarity index 100% rename from dsv4/kernels/router/dense_router_prefill.py rename to dsv4/_archive/kernels/router/dense_router_prefill.py diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/_archive/kernels/router/nvfp4_fused_router_kernel.py similarity index 100% rename from dsv4/kernels/router/nvfp4_fused_router_kernel.py rename to dsv4/_archive/kernels/router/nvfp4_fused_router_kernel.py diff --git a/dsv4/layers/attention.py b/dsv4/_archive/layers/attention.py similarity index 100% rename from dsv4/layers/attention.py rename to dsv4/_archive/layers/attention.py diff --git a/dsv4/layers/embedding.py b/dsv4/_archive/layers/embedding.py similarity index 100% rename from dsv4/layers/embedding.py rename to dsv4/_archive/layers/embedding.py diff --git a/dsv4/layers/ffn.py b/dsv4/_archive/layers/ffn.py similarity index 100% rename from dsv4/layers/ffn.py rename to dsv4/_archive/layers/ffn.py diff --git a/dsv4/_archive/layers/grouped_linear.py b/dsv4/_archive/layers/grouped_linear.py new file mode 100644 index 00000000..3281d73f --- /dev/null +++ b/dsv4/_archive/layers/grouped_linear.py @@ -0,0 +1,368 @@ +"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half). + +wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups. +Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank) + +The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd". +We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts, +where every token goes to every "expert" (group). + +wo_a is loaded as BF16 from our NVFP4 checkpoint, then quantized to NVFP4 here. + +CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. +""" + +import torch + +from dsv4.ops.quantize import ( + quantize_activation_nvfp4, + quantize_weight_to_nvfp4, + quantize_nvfp4_gpu_fused, +) +from dsv4.ops.layouts import ( + make_b_k_major, + assemble_scales_2d_side, + assemble_scales_3d_side, +) +from dsv4.ops.gemm_runner import ( + run_nvfp4_grouped_gemm, +) +from dsv4.ops.layouts import ( + ceil_div as cutedsl_ceil_div, + pad_and_swizzle_single, +) +from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm + + +class Nvfp4GroupedLinear: + """Grouped NVFP4 linear for wo_a (o-projection first half). + + Handles the "bhr,hdr->bhd" einsum pattern: + - o: (tokens, n_local_heads, head_dim) → reshape to (tokens, n_local_groups, heads_per_group * head_dim) + - wo_a: (n_local_groups, heads_per_group * head_dim, o_lora_rank) → NVFP4 per group + - z: (tokens, n_local_groups, o_lora_rank) + + Uses ScaledGroupedGemm with num_groups=n_local_groups. + Every token goes to every group (no routing). + + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. + """ + + def __init__( + self, + n_local_groups: int, + heads_per_group: int, + head_dim: int, + o_lora_rank: int, + max_num_tokens: int = 8192, + device: str = "cuda", + ): + self.n_local_groups = n_local_groups + self.heads_per_group = heads_per_group + self.head_dim = head_dim + self.o_lora_rank = o_lora_rank + self.max_num_tokens = max_num_tokens + self.device = device + + # Per-group dimensions + self.group_in_features = heads_per_group * head_dim # 8192 + self.group_out_features = o_lora_rank # 1536 + + # NVFP4 weight storage: lists of per-group tensors + self._weight_fp4 = None # list of (K//2, N) float4_e2m1fn_x2 + self._weight_sf = None # list of (K//16, N) float8_e4m3fn + self._weight_gs = None # list of float32 + + # Processed weights (set by finalize_weights) + self._mat_b = None + self._scale_b = None + self._gsb = None + + # Activation global scale + self._activation_global_scale = 1.0 / (6.0 * 448.0) + + # Pre-allocated buffers + self._padded_x_fp4_buf = None + self._gsa_buf = None + self._expert_offsets_buf = None + self._buffers_allocated = False + + def set_bf16_weight(self, wo_a_bf16: torch.Tensor): + """Set wo_a weight from BF16 and quantize to NVFP4. + + Args: + wo_a_bf16: (n_local_groups * o_lora_rank, heads_per_group * head_dim) BF16 + OR (n_local_groups, heads_per_group * head_dim, o_lora_rank) if from bmm + """ + # Quantize each group separately + fp4_list = [] + sf_list = [] + gs_list = [] + + if wo_a_bf16.ndim == 3: + # bmm format: (n_local_groups, heads_per_group * head_dim, o_lora_rank) + for g in range(self.n_local_groups): + w_g = wo_a_bf16[g] # (in_features, out_features) + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g) + # quantize_weight_to_nvfp4 returns (K//2, N) with K=in_features + # Our kernel expects (K_packed, N_packed) where K is the contraction dim + # For weight (in_features, out_features): K=in_features (contraction) + # quantize_weight_to_nvfp4 treats dim 0 as K, so result is (K//2, N) ✓ + fp4_list.append(w_fp4) + sf_list.append(w_sf) + gs_list.append(w_gs) + else: + # Dense format: (n_local_groups * o_lora_rank, heads_per_group * head_dim) + # Split into per-group blocks + for g in range(self.n_local_groups): + start = g * self.o_lora_rank + end = start + self.o_lora_rank + w_g = wo_a_bf16[start:end, :] # (o_lora_rank, in_features) + # NOTE: This is transposed — weight is (out, in) but quantize_weight_to_nvfp4 + # expects (K, N) where K is the packed/contraction dim. + # For matmul X @ W^T, the contraction dim of W is dim 1 (in_features). + # So we need to transpose before quantizing. + w_g_t = w_g.T # (in_features, o_lora_rank) = (K, N) + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g_t) + fp4_list.append(w_fp4) + sf_list.append(w_sf) + gs_list.append(w_gs) + + self._weight_fp4 = fp4_list + self._weight_sf = sf_list + self._weight_gs = gs_list + + def load_nvfp4_weight(self, weight, weight_scale, weight_scale_2=None, input_scale=None): + """Load NVFP4 weights directly from checkpoint — no dequant/re-quant. + + The checkpoint stores weights in (out_features, in_features) layout: + weight: (n_groups * o_rank, group_in_features // 2) uint8 + weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn + weight_scale_2: scalar or (n_groups * o_rank,) float + input_scale: scalar or (n_groups * o_rank,) float (unused for weight dequant) + + Each group's chunk is (o_rank, K_packed) = (N, K_packed) in row-major. + Our GEMM expects (K_packed, N) per group, so we transpose each group. + Block scales follow the same transpose. + + Args: + weight: (n_groups * o_rank, group_in_features // 2) uint8 + weight_scale: (n_groups * o_rank, group_in_features // 16) float8_e4m3fn + weight_scale_2: scalar or per-row scale tensor (optional) + input_scale: scalar or per-row (unused — for activation quantization) + """ + fp4_list = [] + sf_list = [] + gs_list = [] + + K_packed = self.group_in_features // 2 + N = self.o_lora_rank + K_sf = self.group_in_features // 16 # block scale dim along K + + for g in range(self.n_local_groups): + # Extract this group's weight: (o_rank, K_packed) = (N, K_packed) + start = g * N + end = start + N + w_g = weight[start:end] # (N, K_packed) uint8 + ws_g = weight_scale[start:end] # (N, K_sf) float8_e4m3fn + + # Transpose to (K_packed, N) — the layout quantize_weight_to_nvfp4 produces + w_g_t = w_g.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + ws_g_t = ws_g.permute(1, 0).contiguous() + + fp4_list.append(w_g_t) + sf_list.append(ws_g_t) + + # Global scale: weight_scale_2 + if weight_scale_2 is not None: + if weight_scale_2.numel() == 1: + gs_list.append(weight_scale_2.float().item()) + else: + # Per-row: take mean of this group's rows + gs_list.append(weight_scale_2[start:end].float().mean().item()) + else: + gs_list.append(1.0) + + self._weight_fp4 = fp4_list + self._weight_sf = sf_list + self._weight_gs = gs_list + + def finalize_weights(self): + """Process NVFP4 weights for CuTeDSL GEMM.""" + if self._weight_fp4 is None: + raise RuntimeError("Call set_bf16_weight() before finalize_weights()") + + self._mat_b = make_b_k_major(torch.stack(self._weight_fp4)) # (groups, K_packed, N_packed) + self._scale_b = assemble_scales_3d_side(self._weight_sf) + self._gsb = torch.tensor(self._weight_gs, dtype=torch.float32, device=self.device) + + # Free raw weights + self._weight_fp4 = None + self._weight_sf = None + self._weight_gs = None + + def _allocate_buffers(self): + """Pre-allocate buffers at max size for cudagraph compatibility.""" + max_rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128 + total_max_rows = max_rows_per_group * self.n_local_groups + + self._padded_x_fp4_buf = torch.zeros( + total_max_rows, self.group_in_features // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2) + + self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device) + self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device) + self._buffers_allocated = True + + def _ensure_initialized(self): + if self._mat_b is None: + self.finalize_weights() + if not self._buffers_allocated: + self._allocate_buffers() + + def _assemble_scales_single_group(self, x_sf): + """Assemble 2D-side activation scales for num_groups=1.""" + num_rows, num_cols = x_sf.shape + padded_rows = cutedsl_ceil_div(num_rows, 128) * 128 + padded_cols = cutedsl_ceil_div(num_cols, 4) * 4 + + buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn) + buf[:num_rows, :num_cols] = x_sf + swizzled_flat = pad_and_swizzle_single(buf) + return swizzled_flat.reshape(padded_rows, padded_cols) + + def compute_activation_global_scale(self, o_sample: torch.Tensor): + """Compute activation global scale from a warmup forward. + + Args: + o_sample: (tokens, n_local_heads, head_dim) BF16 attention output sample + """ + self._ensure_initialized() + # Reshape to grouped format, then flatten to 2D for quantization + o_grouped = o_sample.reshape(-1, self.n_local_groups, self.group_in_features) + # We need a single gs for all groups — use the overall amax + from dsv4.ops.quantize import ( + quantize_to_nvfp4, + ) + o_flat = o_sample.reshape(-1, o_sample.shape[-1]) # (tokens, n_local_heads * head_dim) — not right + # Actually, for grouped GEMM, each group's activation is (tokens, group_in_features) + # The global scale should be computed per-group, but for simplicity use one scale + # based on the overall amax. + with torch.no_grad(): + _, _, gs = quantize_to_nvfp4(o_grouped.reshape(-1, self.group_in_features)) + self._activation_global_scale = gs + + def run(self, o: torch.Tensor) -> torch.Tensor: + """Forward: BF16 attention output → NVFP4 grouped GEMM → BF16 z. + + Args: + o: (num_tokens, n_local_heads, head_dim) BF16 — attention output + AFTER inverse RoPE has been applied + + Returns: + z: (num_tokens, n_local_groups, o_lora_rank) BF16 + """ + if not hasattr(self, '_runner_id'): + self._runner_id = register_runner(self) + return nvfp4_linear_gemm( + o, self._runner_id, self.n_local_groups * self.o_lora_rank, + ) + + def _run_impl(self, o: torch.Tensor) -> torch.Tensor: + """Actual implementation. + + Input o is (tokens, n_local_heads, head_dim). + We reshape to (tokens, n_local_groups, heads_per_group * head_dim), + then treat each group's (tokens, group_in_features) as one "expert" + in our grouped GEMM. All tokens go to all groups. + + The grouped GEMM layout requires each group's tokens to be + contiguous at their correct offset: + - Group 0: rows [0, padded_T) + - Group 1: rows [padded_T, 2*padded_T) + - ... + - Group G: rows [(G-1)*padded_T, G*padded_T) + """ + self._ensure_initialized() + + num_tokens = o.shape[0] + padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128 + + # Reshape: (tokens, n_local_heads, head_dim) → (tokens, n_local_groups, group_in_features) + o_grouped = o.reshape(num_tokens, self.n_local_groups, self.group_in_features) + + # Permute to groups-first: (G, T, D) + o_grouped = o_grouped.permute(1, 0, 2) + + # Flatten all groups into (G*T, D) for batched fused quantize — single kernel launch + o_flat = o_grouped.reshape(self.n_local_groups * num_tokens, self.group_in_features) + + # Fused amax + quantize: zero CPU-GPU syncs. + # Computes gsa on GPU, quantizes to NVFP4, returns GPU tensor. + # Replaces the old path: .item() sync + Python quantize per group. + if getattr(self, '_use_runtime_gsa', False): + x_fp4_flat, x_sf_flat, gsa_gpu = quantize_nvfp4_gpu_fused(o_flat) + # gsa_gpu is (G*T,) — all rows share same amax (from max over full tensor) + # For the GEMM's global_scale_a, fill all group slots with the same gsa value + # Use GPU-only copy: no .item(), no CPU sync + self._gsa_buf[:1].copy_(gsa_gpu[:1]) # GPU→GPU scalar copy, no sync + # Broadcast to all groups (all get same gsa) + if self.n_local_groups > 1: + self._gsa_buf[1:].copy_(self._gsa_buf[:1].expand(self.n_local_groups - 1)) + else: + self._gsa_buf.fill_(self._activation_global_scale) + x_fp4_flat, x_sf_flat = quantize_activation_nvfp4( + o_flat, self._activation_global_scale + ) + + # Reshape FP4 back to (G, T, D//2) and scatter into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf + padded_x_fp4.view(torch.uint8).zero_() + + x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2) + + for g in range(self.n_local_groups): + offset = g * padded_rows_per_group + padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8) + + # Reshape scales back to (G, T, D//16) and assemble + x_sf_grouped = x_sf_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 16) + all_x_sf = [x_sf_grouped[g] for g in range(self.n_local_groups)] + + # Assemble A-side scales for all groups + from dsv4.ops.layouts import ( + assemble_scales_2d_side, + ) + scale_a = assemble_scales_2d_side(all_x_sf) + + # Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T] + expert_offsets = self._expert_offsets_buf + for g in range(self.n_local_groups): + expert_offsets[g] = (g + 1) * padded_rows_per_group + + # Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync) + gsa = self._gsa_buf + + # Run grouped GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._mat_b, + scale_a=scale_a, + scale_b=self._scale_b, + expert_offsets=expert_offsets, + global_scale_a=gsa, + global_scale_b=self._gsb, + ) + + # Extract real outputs and reshape + # GEMM output has the same layout as mat_a: groups-first with padding + z = torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank, + dtype=torch.bfloat16, device=o.device) + for g in range(self.n_local_groups): + offset = g * padded_rows_per_group + z[:, g, :] = out[offset:offset + num_tokens, :] + + return z + + def __call__(self, o: torch.Tensor) -> torch.Tensor: + return self.run(o) diff --git a/dsv4/_archive/layers/linear.py b/dsv4/_archive/layers/linear.py new file mode 100644 index 00000000..30ba86f7 --- /dev/null +++ b/dsv4/_archive/layers/linear.py @@ -0,0 +1,267 @@ +"""CuTeDSL NVFP4 Linear (single GEMM) + +Generic NVFP4 GEMM runner for attention projections and any single +linear layer. Uses ScaledGroupedGemmKernel with num_groups=1. + +CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. +""" + +import torch + +from dsv4.ops.quantize import ( + quantize_activation_nvfp4, + quantize_to_nvfp4, +) +from dsv4.ops.layouts import ( + make_b_k_major, +) +from dsv4.ops.gemm_runner import ( + run_nvfp4_grouped_gemm, +) +from dsv4.kernels.gemm.grouped import ( + ceil_div as cutedsl_ceil_div, + pad_and_swizzle_single, +) +from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm + + +class Nvfp4Linear: + """Single NVFP4 GEMM using CuTeDSL (num_groups=1). + + Handles any (K, N) weight matrix in NVFP4 format. + Simple: quantize activation → GEMM → BF16 output. + No SiLU, no fusion, no routing. + + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. + """ + + def __init__( + self, + in_features: int, + out_features: int, + max_num_tokens: int = 8192, + device: str = "cuda", + ): + self.in_features = in_features + self.out_features = out_features + self.max_num_tokens = max_num_tokens + self.device = device + + # Weights (set after construction, then call finalize_weights) + self.fp4 = None # list of 1 tensor + self.sf = None # list of 1 tensor + self.gs = None # list of 1 float + self.ws2 = None # list of 1 tensor — weight_scale_2 (scalar, folded into global_scale_b) + + # Processed weights + self._mat_b = None + self._scale_b = None + self._gsb = None + + # Activation global scale + self._activation_global_scale = 1.0 / (6.0 * 448.0) + + # Pre-allocated buffers + self._padded_x_fp4_buf = None + self._expert_offsets_buf = None + self._gsa_buf = None + self._buffers_allocated = False + + def finalize_weights(self): + """Process weights for CuTeDSL GEMM.""" + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + fp4_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.fp4] + # Checkpoint weight is (out_features//2, in_features//2) = (N_packed, K_packed) + # make_b_k_major expects (E, K_packed, N_packed), so we need to permute + stacked = torch.stack(fp4_view).permute(0, 2, 1).contiguous() # (1, K_packed, N_packed) + self._mat_b = make_b_k_major(stacked) + # Checkpoint scale is (N_packed, K_sf) — already in the right row order for the + # kernel's swizzle. Use assemble_raw_scales_2d3d_3d_side (no transpose), + # NOT assemble_scales_3d_side (which transposes K_sf↔N). + from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side + self._scale_b = assemble_raw_scales_2d3d_3d_side(self.sf) + self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device) + + # Fold weight_scale_2 into global_scale_b + # Dequant formula: w = lut[w_packed] * weight_scale * weight_scale_2 + # Production GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb) + # So gsb = input_scale * weight_scale_2 + if self.ws2 is not None and len(self.ws2) > 0 and self.ws2[0] is not None: + ws2_val = self.ws2[0].float().item() + self._gsb = self._gsb * ws2_val + + # Free raw weights + self.fp4 = None + self.sf = None + self.gs = None + self.ws2 = None + + # Eagerly JIT-compile the GEMM kernel for this (K, N) shape. + # Uses num_groups=1 since this is a single linear layer. + K_packed = self.in_features // 2 + N_packed = self.out_features // 2 + # warmup_compilation(1, K_packed, N_packed, self.device) # Lazy compile on first real forward + + def _ensure_buffer_size(self, num_tokens: int): + """Ensure the padded buffer is large enough for num_tokens.""" + needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + if self._padded_x_fp4_buf is not None and self._padded_x_fp4_buf.shape[0] >= needed_rows: + return # Already big enough + + self._padded_x_fp4_buf = torch.zeros( + needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2) + + self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device) + self._gsa_buf = torch.full((1,), self._activation_global_scale, dtype=torch.float32, device=self.device) + + def _ensure_initialized(self): + if self._mat_b is None: + self.finalize_weights() + + def _assemble_scales_single_group(self, x_sf): + """Assemble 2D-side activation scales for num_groups=1.""" + num_rows, num_cols = x_sf.shape + padded_rows = cutedsl_ceil_div(num_rows, 128) * 128 + padded_cols = cutedsl_ceil_div(num_cols, 4) * 4 + + buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn) + buf[:num_rows, :num_cols] = x_sf + swizzled_flat = pad_and_swizzle_single(buf) + return swizzled_flat.reshape(padded_rows, padded_cols) + + def compute_activation_global_scale(self, hidden_states_sample): + """Compute activation global scale from a warmup forward.""" + self._ensure_initialized() + with torch.no_grad(): + _, _, gs = quantize_to_nvfp4(hidden_states_sample) + self._activation_global_scale = gs + + + def run(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward: BF16 input → NVFP4 GEMM → BF16 output. + + Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile + treats this as an opaque op. The custom op calls _run_impl internally. + """ + if not hasattr(self, '_runner_id'): + self._runner_id = register_runner(self) + return nvfp4_linear_gemm( + hidden_states, self._runner_id, self.out_features, + ) + + def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Actual implementation — called via custom autograd to be torch.compile-safe.""" + self._ensure_initialized() + + num_tokens = hidden_states.shape[0] + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + + # Ensure buffer is large enough + self._ensure_buffer_size(num_tokens) + + # Fused amax + quantize: single kernel launch, zero CPU-GPU syncs. + # Computes amax on GPU → derives gsa → quantizes to NVFP4. + # gsa written to GPU buffer for downstream GEMM global_scale_a. + # + # This replaces the two-step path: + # compute_amax_gsa_gpu(hidden_states) → .item() sync + # quantize_nvfp4_gpu(hidden_states, gsa_float) → another kernel launch + # + # Old path: ~2 kernel launches + 1 .item() sync per projection. + # New path: 1 kernel launch + 0 .item() syncs per projection. + # Total across 61 layers: ~486 .item() syncs eliminated. + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states) + self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + # P2 FIX: No per-call fill_(). The _gsa_buf already has the correct + # value — set either during initialization (via _ensure_buffer_size) + # or by the first GPU compute when _use_runtime_gsa was True. + # Old path: self._gsa_buf.fill_(self._activation_global_scale) + # — H2D transfer every call (~5µs each × 244 calls = ~1.2ms/token). + # New path: zero H2D transfers on the hot path. + from dsv4.ops.quantize import quantize_nvfp4_gpu + x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale) + + # Scatter x_fp4 into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:x_fp4.shape[0]] = x_fp4.view(torch.uint8) + + # Assemble A-side scales + scale_a = self._assemble_scales_single_group(x_sf) + + # Expert offsets: [padded_rows] for 1 group + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync) + gsa = self._gsa_buf + + # Run GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._mat_b, + scale_a=scale_a, + scale_b=self._scale_b, + expert_offsets=expert_offsets, + global_scale_a=gsa, + global_scale_b=self._gsb, + ) + + return out[:num_tokens] + + def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor: + """Run GEMM with pre-quantized activation (skip quantize step). + + Used when the input has already been quantized by a fused + RMSNorm+quantize kernel. Saves 2 kernel launches per call. + + Args: + quant: QuantizedActivation with x_fp4, x_sf, gsa + """ + from dsv4.ops.quantize import QuantizedActivation + assert isinstance(quant, QuantizedActivation) + + self._ensure_initialized() + num_tokens = quant.num_tokens + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + self._ensure_buffer_size(num_tokens) + + # Scatter pre-quantized x_fp4 into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8) + + # Assemble A-side scales from pre-quantized sf + scale_a = self._assemble_scales_single_group(quant.x_sf) + + # Expert offsets + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — use the per-row gsa from the fused kernel + # Reshape to (1,) if scalar, or use per-row (M,) broadcast + gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens] + if gsa.shape != self._gsa_buf.shape: + self._gsa_buf = gsa.contiguous() + else: + self._gsa_buf.copy_(gsa) + + # Run GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._mat_b, + scale_a=scale_a, + scale_b=self._scale_b, + expert_offsets=expert_offsets, + global_scale_a=self._gsa_buf, + global_scale_b=self._gsb, + ) + + return out[:num_tokens] + + def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.run(hidden_states) diff --git a/dsv4/_archive/layers/mhc.py b/dsv4/_archive/layers/mhc.py new file mode 100644 index 00000000..38133693 --- /dev/null +++ b/dsv4/_archive/layers/mhc.py @@ -0,0 +1,549 @@ +""" +mHC (Manifold-Constrained Hyper-Connections) — Inference Layer. + +Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only. + +Verified against HuggingFace DeepseekV4HyperConnection (transformers main, +modeling_deepseek_v4.py). The ordering of fn/base/scale outputs is +[pre(4), post(4), comb(16)] — NOT [pre, comb, post]. The comb matrix is +consumed TRANSPOSED in post_block. Sinkhorn starts from softmax (not exp). +pre (A_l) has an hc_eps additive guard. + +--------------------------------------------------------------------- +V4-Pro reference dimensions (Section 4.2.1) +--------------------------------------------------------------------- + d = 7168 hidden dim + n_hc = 4 hyper-connection expansion factor + N_proj = 24 fused output of W_pre(4) + W_post(4) + W_comb(16) + K_proj = 4*7168 = 28672 = n_hc * d (flattened residual) + t_max = 20 Sinkhorn iterations + +--------------------------------------------------------------------- +Checkpoint layout (fn / base / scale) +--------------------------------------------------------------------- + fn: (24, 28672) — rows ordered [pre(4), post(4), comb(16)] + base: (24,) — ordered [pre(4), post(4), comb(16)] + scale: (3,) — [alpha_pre, alpha_post, alpha_comb] + + This matches the HuggingFace split: + pre_w, post_w, comb_w = F.linear(flat, fn).split([4, 4, 16]) + pre_b, post_b, comb_b = base.split([4, 4, 16]) + pre_scale, post_scale, comb_scale = scale.unbind(0) + +--------------------------------------------------------------------- +Kernel dependency +--------------------------------------------------------------------- +tf32_hc_prenorm_gemm (DeepGEMM, SM90/SM100) + a: (T, K) BF16 — flattened residual X_flat + b: (N, K) FP32 — stacked weight [W_pre; W_post; W_comb] + d: (S, T, N) or (T, N) FP32 — raw projection outputs (pre-normalised) + sqr_sum: (S, T) or (T,) FP32 — Σ a² per token (for RMSNorm denominator) + num_splits = S (16 recommended for K=28672) + +After the call: + d = d.sum(0) → (T, N) + sqr_sum = sqr_sum.sum(0) → (T,) + rms_scale = sqrt(K / (sqr_sum + eps)) + d_norm = d * rms_scale[:,None] — equivalent to RMSNorm(X_flat) @ W_stacked +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +# --------------------------------------------------------------------------- +# Try importing DeepGEMM; fall back to plain BF16 matmul if unavailable. +# --------------------------------------------------------------------------- +try: + import deep_gemm + _HAS_DEEP_GEMM = True +except ImportError: + _HAS_DEEP_GEMM = False + + +NUM_SPLITS = 16 # K-split count for tf32_hc_prenorm_gemm numerical stability +EPS_RMSN = 1e-6 +HC_EPS = 1e-6 # eps guard on pre (A_l) and Sinkhorn, matching HF reference + + +# --------------------------------------------------------------------------- +# Sinkhorn-Knopp projection (T batched 4×4 matrices) +# --------------------------------------------------------------------------- + +def sinkhorn_knopp( + logits: torch.Tensor, # (T, n, n) raw logits (NOT exp'd) + t_max: int = 20, + eps: float = HC_EPS, +) -> torch.Tensor: + """ + Project each (n×n) matrix onto the Birkhoff polytope + (doubly stochastic matrices) via alternating row/col normalisation. + + Matches HuggingFace DeepseekV4HyperConnection.forward: + 1. softmax along last dim (row-normalize the logits) + 2. add eps + 3. column-normalize + 4. (t_max - 1) alternating row/col normalizations + + NO PYTHON FALLBACK. If the CUDA kernel fails, the pipeline dies. + The kernel MUST compile and run correctly. Period. + """ + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("mhc_sinkhorn", ["mhc_sinkhorn.cu"]) + return mod.mhc_sinkhorn(logits.float(), t_max, eps) + + +# --------------------------------------------------------------------------- +# Context carried between pre_block and post_block +# --------------------------------------------------------------------------- + +@dataclass +class mHCContext: + """Holds the per-token mixing matrices computed in pre_block.""" + B_l: torch.Tensor # (T, n_hc, n_hc) doubly stochastic residual transform + C_l: torch.Tensor # (T, n_hc) output mapping (2*sigmoid) + + +# --------------------------------------------------------------------------- +# mHC layer +# --------------------------------------------------------------------------- + +class mHCLayer: + """ + Wraps one transformer sub-layer (attention *or* MoE) with the mHC + residual update. + + Typical call pattern per layer: + + x_in, ctx = mhc.pre_block(X_l) + F_out = transformer_sublayer(x_in) # (T, d) + X_next = mhc.post_block(X_l, F_out, ctx) + + where X_l has shape (T, n_hc, d) — the expanded residual state. + The first call at layer 0 should use X_0 initialised via `init_state`. + """ + + def __init__( + self, + hidden_dim: int = 7168, + n_hc: int = 4, + t_max_sinkhorn: int = 20, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + self.d = hidden_dim + self.n_hc = n_hc + self.K_proj = n_hc * hidden_dim # 28672 for V4-Pro + self.N_proj = n_hc + n_hc + n_hc * n_hc # 4 + 4 + 16 = 24 + self.t_max = t_max_sinkhorn + self.device = device + self.dtype = dtype + + # ── Learnable weights (set via load_weights) ────────────────── + # Checkpoint fn ordering: [pre(4), post(4), comb(16)] + # We store them in this order and build W_stacked = [pre, post, comb] + self.W_pre = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K) + self.W_post = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K) + self.W_comb = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32) # (16, K) + + # Checkpoint base ordering: [pre(4), post(4), comb(16)] + self.S_pre = self._buf(1, n_hc) # (1, 4) — pre bias + self.S_post = self._buf(n_hc, 1) # (4, 1) — post bias + self.S_comb = self._buf(n_hc, n_hc) # (4, 4) — comb bias + + # Checkpoint scale ordering: [alpha_pre, alpha_post, alpha_comb] + self.alpha_pre = torch.zeros(1, device=device, dtype=torch.float32) + self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32) + self.alpha_comb = torch.zeros(1, device=device, dtype=torch.float32) + + # Pre-allocated split buffers (set in _ensure_buffers) + self._d_split = None # (NUM_SPLITS, max_T, N_proj) FP32 + self._sqr_sum_split = None # (NUM_SPLITS, max_T) FP32 + self._max_T = 0 + + # Fused stacked weight for DeepGEMM (built once in _build_stacked) + self._W_stacked = None # (N_proj, K_proj) FP32 + + # ── Construction helpers ────────────────────────────────────────── + + def _buf(self, *shape, dtype=None): + dt = dtype or self.dtype + return torch.empty(*shape, dtype=dt, device=self.device) + + def load_weights( + self, + W_pre: torch.Tensor, # (n_hc, K) FP32 + W_post: torch.Tensor, # (n_hc, K) FP32 + W_comb: torch.Tensor, # (n_hc², K) FP32 + S_pre: torch.Tensor, # (1, n_hc) + S_post: torch.Tensor, # (n_hc, 1) + S_comb: torch.Tensor, # (n_hc, n_hc) + alpha_pre: float, + alpha_post: float, + alpha_comb: float, + ): + """ + Load all mHC parameters from the checkpoint. + + The W tensors must be FP32 — they are loaded as FP32 in the prenorm + GEMM (BF16 input × FP32 weight). Everything else can be BF16 in the + checkpoint and will be cast here. + """ + def _f32(t): return t.to(device=self.device, dtype=torch.float32).contiguous() + def _cvt(t): return t.to(device=self.device, dtype=self.dtype).contiguous() + + self.W_pre = _f32(W_pre) + self.W_post = _f32(W_post) + self.W_comb = _f32(W_comb) + self.S_pre = _cvt(S_pre) + self.S_post = _cvt(S_post) + self.S_comb = _cvt(S_comb) + self.alpha_pre = torch.tensor(alpha_pre, dtype=torch.float32, device=self.device) + self.alpha_post = torch.tensor(alpha_post, dtype=torch.float32, device=self.device) + self.alpha_comb = torch.tensor(alpha_comb, dtype=torch.float32, device=self.device) + self._W_stacked = None # invalidate cache + + def _build_stacked(self): + """Fuse W_pre / W_post / W_comb into one (N_proj, K_proj) FP32 tensor. + + Order: [pre(4), post(4), comb(16)] — matches checkpoint fn layout. + """ + self._W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb], dim=0) + # Must be K-major (contiguous along K) for DeepGEMM + self._W_stacked = self._W_stacked.contiguous() + + def _ensure_buffers(self, T: int): + """Pre-allocate split buffers if needed (avoids hot-path alloc).""" + if T <= self._max_T: + return + self._d_split = torch.empty( + NUM_SPLITS, T, self.N_proj, dtype=torch.float32, device=self.device + ) + self._sqr_sum_split = torch.empty( + NUM_SPLITS, T, dtype=torch.float32, device=self.device + ) + self._max_T = T + + # ── Forward ────────────────────────────────────────────────────── + + def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor: + """ + Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) FP32. + + Uses tf32_hc_prenorm_gemm when DeepGEMM is available for fused + GEMM + squared-sum accumulation. Falls back to plain BF16 matmul. + + X_flat: (T, K_proj) BF16 + """ + T = X_flat.shape[0] + K = self.K_proj + + if _HAS_DEEP_GEMM: + if self._W_stacked is None: + self._build_stacked() + self._ensure_buffers(T) + + d_s = self._d_split[:, :T, :] # view, no copy + ss_s = self._sqr_sum_split[:, :T] + + deep_gemm.tf32_hc_prenorm_gemm( + X_flat.contiguous(), # a + self._W_stacked, # b (N, K) FP32 + d_s, # d (S, T, N) + ss_s, # sqr_sum (S, T) + num_splits=NUM_SPLITS, + ) + + d_out = d_s.sum(dim=0) # (T, N) + sqr_sum = ss_s.sum(dim=0) # (T,) + + else: + if self._W_stacked is None: + self._build_stacked() + + x_f32 = X_flat.float() + d_out = x_f32 @ self._W_stacked.T # (T, N) + sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,) + + # RMSNorm scale: multiply raw GEMM output by rsqrt(mean(x²)) + rms_scale = torch.sqrt(K / (sqr_sum + EPS_RMSN)) # (T,) + return (d_out * rms_scale.unsqueeze(-1)).to(self.dtype) # (T, N) in BF16 + + def _dynamic_params( + self, X_l: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute per-token A_l, B_l, C_l from the current residual state. + + Matches HuggingFace DeepseekV4HyperConnection.forward exactly: + 1. UnweightedRMSNorm on flattened residual + 2. F.linear(flat, fn) → split [pre, post, comb] + 3. pre = sigmoid(pre_w * scale[0] + base[:4]) + eps + 4. post = 2 * sigmoid(post_w * scale[1] + base[4:8]) + 5. comb = Sinkhorn(softmax(comb_w * scale[2] + base[8:]), iters) + + X_l: (T, n_hc, d) + + Returns: + A_l: (T, n_hc) sigmoid-constrained input mapping (+ eps) + B_l: (T, n_hc, n_hc) doubly-stochastic residual transform + C_l: (T, n_hc) 2*sigmoid-constrained output mapping + """ + T, n, d = X_l.shape + assert n == self.n_hc and d == self.d + + # Flatten: (T, n_hc*d) + X_flat = X_l.reshape(T, self.K_proj).to(self.dtype) + + # Unweighted RMSNorm on flattened residual (HF: self.input_norm) + # This normalizes BEFORE the linear projection. + X_flat_f = X_flat.float() + rms_inv = X_flat_f.pow(2).mean(dim=-1, keepdim=True).add(EPS_RMSN).rsqrt() + X_flat = (X_flat_f * rms_inv).to(self.dtype) + + # Fused RMSNorm projection: (T, N_proj) = RMSNorm(X_flat) @ fn.T + # Note: the RMSNorm above is the "input_norm" (unweighted). The + # _project_and_rms method applies a SECOND RMSNorm (as part of + # the fused GEMM). This is intentional — the prenorm GEMM fuses + # RMSNorm into the GEMM output, and the input_norm is a separate + # unweighted norm on the input. When DeepGEMM is available, both + # are fused into a single kernel. In the fallback path, we apply + # both explicitly (the input_norm above + the GEMM-internal norm + # in _project_and_rms). The result is mathematically: + # proj = RMSNorm(RMSNorm(X_flat) @ W.T) + # which is equivalent to the HF: + # proj = F.linear(input_norm(X_flat), fn) + # followed by... wait, no. HF does NOT apply a second RMSNorm. + # Let me re-read HF: + # flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) + # pre_w, post_w, comb_w = F.linear(flat, self.fn.float()).split(...) + # So HF: 1. input_norm(X_flat), 2. linear, 3. split. + # Our _project_and_rms: 1. (no input_norm yet), 2. RMSNorm(X_flat) @ W.T + # which is: (X_flat / rms(X_flat)) @ W.T = X_flat @ W.T / rms(X_flat) + # This is NOT the same as input_norm(X_flat) @ W.T because input_norm + # normalizes each token independently while RMSNorm in the GEMM divides + # the ENTIRE dot product by the RMS. + # Actually, let me re-check. Our _project_and_rms does: + # d_out = X_flat @ W.T + # rms_scale = sqrt(K / (sqr_sum + eps)) + # return d_out * rms_scale + # = (X_flat @ W.T) * sqrt(K / (sum(X_flat^2) + eps)) + # = (X_flat @ W.T) / sqrt(mean(X_flat^2) + eps) + # = X_flat / sqrt(mean(X_flat^2) + eps) @ W.T + # (because sqrt(mean(X^2) + eps) is a scalar per token) + # So this IS the same as input_norm(X_flat) @ W.T! ✓ + # The RMSNorm commutes with the linear because it's per-token. + # So we DON'T need a separate input_norm — the GEMM-fused RMSNorm + # is equivalent. The explicit input_norm above is redundant. + # Remove it: + X_flat = X_l.reshape(T, self.K_proj).to(self.dtype) + + proj = self._project_and_rms(X_flat).float() + + # Split: [pre(4), post(4), comb(16)] + n = self.n_hc + pre_raw = proj[:, 0:n] # (T, n_hc) + post_raw = proj[:, n:2*n] # (T, n_hc) + comb_raw = proj[:, 2*n:2*n + n*n] # (T, n_hc²) + + # Apply scale and bias (matching HF: raw * scale + base) + S_pre = self.S_pre.float() # (1, n_hc) + S_post = self.S_post.float() # (n_hc, 1) + S_comb = self.S_comb.float() # (n_hc, n_hc) + + pre_tilde = self.alpha_pre * pre_raw + S_pre # (T, n_hc) + post_tilde = self.alpha_post * post_raw + S_post.flatten().unsqueeze(0) # (T, n_hc) + comb_tilde = self.alpha_comb * comb_raw + S_comb.flatten().unsqueeze(0) # (T, n_hc²) + + # Apply constraints (matching HF exactly) + # pre = sigmoid(...) + hc_eps (note the eps!) + A_l = torch.sigmoid(pre_tilde) + HC_EPS # (T, n_hc) + # post = 2 * sigmoid(...) + C_l = 2.0 * torch.sigmoid(post_tilde) # (T, n_hc) + # comb = Sinkhorn(softmax(logits) + eps, iters) + comb_logits = comb_tilde.reshape(T, n, n) + B_l = sinkhorn_knopp(comb_logits, t_max=self.t_max) # (T, n_hc, n_hc) + + return A_l.to(self.dtype), B_l, C_l.to(self.dtype) + + # ---------------------------------------------------------------- + # Public API: pre_block / post_block + # ---------------------------------------------------------------- + + def pre_block( + self, + X_l: torch.Tensor, # (T, n_hc, d) BF16 + ) -> Tuple[torch.Tensor, mHCContext]: + """ + Compute dynamic mixing params and extract the layer input. + + Returns: + x_in: (T, d) BF16 — the actual input to pass to the sub-layer + ctx: mHCContext — {B_l, C_l} to be passed to post_block + """ + A_l, B_l, C_l = self._dynamic_params(X_l) + + # Layer input: x_in = sum_j A_l[j] * X_l[j] (weighted sum of streams) + # Matches HF: collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2) + # A_l: (T, n_hc) X_l: (T, n_hc, d) + x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d) + + return x_in, mHCContext(B_l=B_l, C_l=C_l) + + def post_block( + self, + X_l: torch.Tensor, # (T, n_hc, d) BF16 — residual state BEFORE sub-layer + F_out: torch.Tensor, # (T, d) BF16 — sub-layer output + ctx: mHCContext, + ) -> torch.Tensor: + """ + Apply the mHC residual update. + Matches HuggingFace: X_next = post * F_out + comb.T @ X_l + + Note: comb (B_l) is consumed TRANSPOSED! This matches the HF reference: + torch.matmul(comb.transpose(-1, -2), hidden_streams) + + Returns: + X_next: (T, n_hc, d) BF16 + """ + # B_l.T @ X_l — note the TRANSPOSE! HF uses comb.transpose(-1,-2) + BX = torch.bmm(ctx.B_l.transpose(-1, -2), X_l.float()) + # C_l * F_out + CF = ctx.C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d) + X_next = (CF.float() + BX).to(self.dtype) # (T, n_hc, d) + + # Diagnostic: warn on residual blowup + x_max = X_next.abs().max().item() + if x_max > 500: + # Don't clip in production, just warn + pass + + return X_next + + # ---------------------------------------------------------------- + # Utility + # ---------------------------------------------------------------- + + @staticmethod + def init_state( + embeddings: torch.Tensor, # (T, d) BF16 — token embeddings + n_hc: int = 4, + ) -> torch.Tensor: + """ + Initialise X_0 for the first layer. + + Returns: (T, n_hc, d) BF16 + """ + return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone() + + @staticmethod + def read_out(X_L: torch.Tensor) -> torch.Tensor: + """ + Extract the final hidden state from the last residual state. + Stream 0 is the primary output stream. + + Returns: (T, d) BF16 + """ + return X_L[:, 0, :] + + +# --------------------------------------------------------------------------- +# Quick smoke test +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import sys + + torch.manual_seed(0) + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 + + D, N_HC = 7168, 4 + K = N_HC * D # 28672 + N_PROJ = N_HC + N_HC + N_HC ** 2 # 4 + 4 + 16 = 24 + + mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype) + + # Random weights matching the expected shapes (fn ordering: pre, post, comb) + mhc.load_weights( + W_pre = torch.randn(N_HC, K, dtype=torch.float32), + W_post = torch.randn(N_HC, K, dtype=torch.float32), + W_comb = torch.randn(N_HC**2, K, dtype=torch.float32), + S_pre = torch.zeros(1, N_HC, dtype=dtype), + S_post = torch.zeros(N_HC, 1, dtype=dtype), + S_comb = torch.eye(N_HC, dtype=dtype), # identity: pure residual + alpha_pre = 0.01, + alpha_post = 0.01, + alpha_comb = 0.01, + ) + + T = 4 # 4 tokens + + # ── Forward pass ──────────────────────────────────────────────── + embeddings = torch.randn(T, D, dtype=dtype, device=device) + X = mHCLayer.init_state(embeddings, n_hc=N_HC) + print(f"X_0: {X.shape} (T={T}, n_hc={N_HC}, d={D})") + + for layer_idx in range(2): + x_in, ctx = mhc.pre_block(X) + print(f"\nLayer {layer_idx}:") + print(f" x_in (to sub-layer): {x_in.shape}") + print(f" B_l: {ctx.B_l.shape}") + print(f" C_l: {ctx.C_l.shape}") + F_out = x_in + X = mhc.post_block(X, F_out, ctx) + print(f" X_next: {X.shape}") + + hidden = mHCLayer.read_out(X) + print(f"\nFinal hidden: {hidden.shape}") + + # ── B_l is doubly stochastic check ────────────────────────────── + print("\n=== Doubly stochastic check ===") + B = ctx.B_l + row_sums = B.sum(dim=-1) + col_sums = B.sum(dim=-2) + print(f" row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}] (want ≈ 1.0)") + print(f" col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}] (want ≈ 1.0)") + assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1" + assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1" + print(" PASSED") + + # ── A_l and C_l bounds ──────────────────────────────────────── + A_l, B_l2, C_l = mhc._dynamic_params(X) + print(f"\n=== A_l ∈ (eps, 1+eps) check ===") + print(f" A_l range: [{A_l.min():.4f}, {A_l.max():.4f}] (want ∈ (eps, 1+eps))") + print(" PASSED") + print(f"\n=== C_l ∈ (0, 2) check ===") + print(f" C_l range: [{C_l.min():.4f}, {C_l.max():.4f}] (want ∈ (0, 2))") + assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range" + print(" PASSED") + + # ── Equivalence: T=1 decode vs T=N prefill ────────────────────── + print("\n=== Token-by-token decode == batch prefill ===") + T_big = 8 + h_big = torch.randn(T_big, D, dtype=dtype, device=device) + X_batch = mHCLayer.init_state(h_big, n_hc=N_HC) + + x_in_batch, ctx_batch = mhc.pre_block(X_batch) + + x_in_tokens = [] + for t in range(T_big): + X_t = X_batch[t:t+1] + x_in_t, _ = mhc.pre_block(X_t) + x_in_tokens.append(x_in_t) + x_in_seq = torch.cat(x_in_tokens, dim=0) + + diff = (x_in_batch - x_in_seq).abs().max().item() + print(f" max |batch - sequential| on x_in: {diff:.6f}") + assert diff < 1e-2, f"Mismatch too large: {diff}" + print(" PASSED") + + print("\nAll checks done.") + if not _HAS_DEEP_GEMM: + print("\n(deep_gemm not available — used BF16 matmul fallback)") diff --git a/dsv4/_archive/layers/moe.py b/dsv4/_archive/layers/moe.py new file mode 100644 index 00000000..6d7bf149 --- /dev/null +++ b/dsv4/_archive/layers/moe.py @@ -0,0 +1,700 @@ +""" +vLLM integration for the CuTeDSL NVFP4 MoE kernel. + +CUDA-graph-compatible design: +- All intermediate buffers pre-allocated at max_num_tokens * top_k size +- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs +- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers +- Extra slots (beyond real tokens) are zero and contribute nothing to output +- Fixed-shape tensors throughout the forward pass + +vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192). +During capture, num_tokens equals the budget — all shapes are fixed. +During replay, inputs are padded to the budget size. Our runner always +processes max_slots = budget * top_k rows; padding rows are zeros. +""" +import torch + +from dsv4.ops.quantize import ( + quantize_activation_nvfp4, + quantize_weight_to_nvfp4, + quantize_to_nvfp4, + quantize_nvfp4_gpu, + deinterleave_quantize_nvfp4_cuda, +) +from dsv4.ops.layouts import ( + make_b_k_major, + assemble_scales_3d_side, + interleave_l1_weights, + deinterleave_l1_weights, +) +from dsv4.ops.gemm_runner import ( + run_nvfp4_grouped_gemm, + run_fused_swiglu_grouped_gemm, + warmup_fused_swiglu_compilation, +) +from dsv4.ops.layouts import ( + ceil_div as cutedsl_ceil_div, + pad_and_swizzle_single, +) +from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm + + +class Nvfp4MoE: + """Manages NVFP4 MoE execution via the CuTeDSL kernel. + + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, + no dynamic shapes. Always computes at max_num_tokens * top_k capacity. + """ + + def __init__(self, num_experts, hidden_size, intermediate_size, + max_num_tokens=8192, top_k=8, device="cuda", + experts_start_idx=0): + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.device = device + self.experts_start_idx = experts_start_idx + self._swiglu_limit = None # Set via set_swiglu_limit() + self._fused_swiglu = False # Set via set_fused_swiglu() + + # Weight storage (set before _ensure_stacked) + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + + # Stacked weight tensors (set in _ensure_stacked) + self._l1_mat_b = None + self._l2_mat_b = None + self._l1_scale_b = None + self._l2_scale_b = None + self._l1_gsb = None + self._l2_gsb = None + + # Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688) + # Overridden in finalize_weights with checkpoint input_scale or warmup value + self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) + self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) + + # Pre-allocated cudagraph buffers (set in _allocate_buffers) + self._token_indices = None + self._expert_offsets_buf = None + self._per_expert_scale_bufs_l1 = None + self._per_expert_scale_bufs_l2 = None + self._padded_x_sf_buf_l1 = None + self._padded_x_sf_buf_l2 = None + self._l1_gsa_buf = None + self._l2_gsa_buf = None + self._output_buf = None + self._row_indices_buf = None + self._padded_hidden_buf = None + self._padded_activated_buf = None # unused, using shared + self._padded_expert_offsets_buf = None + self._max_chunks_per_expert = cutedsl_ceil_div( + self.max_num_tokens * self.top_k, self.num_experts * 128 + ) + self._buffers_allocated = False + + def set_swiglu_limit(self, limit: float | None): + """Set the swiglu_limit for activation clamping.""" + self._swiglu_limit = limit + + def set_fused_swiglu(self, enabled: bool): + """Enable fused L1 GEMM + SwiGLU kernel (saves 240+ BF16 kernel launches per token).""" + self._fused_swiglu = enabled + + def _fill_token_indices(self): + """Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times). + + Builds on CPU first, then copies to GPU, to ensure correctness + regardless of CuTeDSL JIT GPU memory corruption. + """ + src = torch.arange(self.max_num_tokens, dtype=torch.int32) + cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1) + self._token_indices.copy_(cpu_indices) + + def _allocate_buffers(self): + """Pre-allocate scale buffers at max size for cudagraph compatibility.""" + # Per-expert scale buffers: separate L1/L2 since K_sf differs + K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16) + padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4 + K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16) + padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4 + + self._per_expert_scale_bufs_l1 = [ + torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) + for _ in range(self.num_experts) + ] + self._per_expert_scale_bufs_l2 = [ + torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) + for _ in range(self.num_experts) + ] + + # Initialize shared buffers dict (if not already) + device_key = str(self.device) + if not hasattr(Nvfp4MoE, '_shared_padded_bufs'): + Nvfp4MoE._shared_padded_bufs = {} + if device_key not in Nvfp4MoE._shared_padded_bufs: + Nvfp4MoE._shared_padded_bufs[device_key] = {} + + # Padded x_sf buffers: SHARED across all runners (not per-layer) + max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128 + if 'xsf_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]: + Nvfp4MoE._shared_padded_bufs[device_key].update({ + 'xsf_l1': torch.zeros( + max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'xsf_l2': torch.zeros( + max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'output': torch.zeros( + self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device + ), + }) + self._padded_x_sf_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1'] + self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2'] + self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output'] + + # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) + self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) + self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) + + # Row indices for scale assembly (max_num_tokens * top_k slots) + self._row_indices_buf = torch.arange( + self.max_num_tokens * self.top_k, device=self.device + ) + + # Padded hidden/activated: SHARED across all runners (not per-layer) + max_rows_per_expert = self._max_chunks_per_expert * 128 + padded_max_slots = self.num_experts * max_rows_per_expert + if 'hidden' not in Nvfp4MoE._shared_padded_bufs[device_key]: + Nvfp4MoE._shared_padded_bufs[device_key].update({ + 'hidden': torch.zeros( + padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device + ), + 'hidden_fp4': torch.zeros( + padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2), + 'activated': torch.zeros( + padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device + ), + 'activated_fp4': torch.zeros( + padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2), + }) + self._shared_bufs = Nvfp4MoE._shared_padded_bufs[device_key] + + # Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed) + self._padded_expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) + max_rows_per_expert = self._max_chunks_per_expert * 128 + self._padded_expert_offsets_buf[1:] = torch.arange( + 1, self.num_experts + 1, dtype=torch.int32, device=self.device + ) * max_rows_per_expert + + self._buffers_allocated = True + + def _ensure_stacked(self): + if self._l1_mat_b is not None: + return + + # Convert weights to kernel format + if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None: + # Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K) + # Permute to (E, K, N) then make K-major + l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous() + l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous() + # Interleave L1 gate/up weights at granularity 4 BF16. + # This pairs gate/up within the MMA accumulator, enabling + # fused SwiGLU without runtime conditionals. + l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn) + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + if l1_fp4_ekn.dtype == torch.uint8: + l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2) + if l2_fp4_ekn.dtype == torch.uint8: + l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2) + # Free stacked checkpoints before make_b_k_major (saves one copy) + self.l1_fp4_stacked = None + self.l2_fp4_stacked = None + torch.cuda.empty_cache() + + self._l1_mat_b = make_b_k_major(l1_fp4_ekn) + self._l2_mat_b = make_b_k_major(l2_fp4_ekn) + del l1_fp4_ekn, l2_fp4_ekn + torch.cuda.empty_cache() + + # Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf) + # per expert for swizzle. Split into views (no copy), then assemble. + l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)] + l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)] + self.l1_sf_stacked = None + self.l2_sf_stacked = None + torch.cuda.empty_cache() + + # Interleave L1 SF along N to match the interleaved weight layout. + # SF per expert from checkpoint is (N, K_sf). Interleave along N. + # interleave_l1_weights operates on last dim, so transpose to (K_sf, N), + # interleave, transpose back to (N, K_sf) for swizzle. + l1_sf_il = [] + for sf_nk in l1_sf_list: + sf_kn = sf_nk.T.contiguous().unsqueeze(0) # (1, K_sf, N) + sf_kn = interleave_l1_weights(sf_kn) # (1, K_sf, N) interleaved along N + l1_sf_il.append(sf_kn[0].T.contiguous()) # (N, K_sf) + del l1_sf_list + l1_sf_list = l1_sf_il + + # assemble_scales_3d_side expects (K_sf, N) per expert and transposes + # to (N, K_sf) internally. But our scales are already (N, K_sf) from + # the checkpoint! Skip the transpose by calling the assembly directly. + from dsv4.ops.layouts import ( + assemble_raw_scales_2d3d_3d_side, + ) + self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list) + self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list) + del l1_sf_list, l2_sf_list + else: + # Legacy path: per-expert lists + l1_stacked = torch.stack(self.l1_fp4) # (E, K, N) + l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up + if l1_stacked.dtype == torch.uint8: + l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2) + l2_stacked = torch.stack(self.l2_fp4) + if l2_stacked.dtype == torch.uint8: + l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2) + self._l1_mat_b = make_b_k_major(l1_stacked) + self._l2_mat_b = make_b_k_major(l2_stacked) + # Interleave L1 SF to match weight interleave + # SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N, + # then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side. + l1_sf_il = [] + for sf in self.l1_sf: + sf_ekn = sf.unsqueeze(0) # (1, K_sf, N) + sf_ekn = interleave_l1_weights(sf_ekn) # interleaved along N + l1_sf_il.append(sf_ekn[0]) # (K_sf, N) + self._l1_scale_b = assemble_scales_3d_side(l1_sf_il) + self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) + del l1_stacked, l1_sf_il + self.l1_fp4 = None + self.l1_sf = None + self.l2_fp4 = None + self.l2_sf = None + + self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) + self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) + + # Fold weight_scale_2 into global_scale_b + # gsb = input_scale * weight_scale_2 + if self.l1_ws2 is not None: + for i, ws2 in enumerate(self.l1_ws2): + if ws2 is not None: + self._l1_gsb[i] *= ws2.float().item() + if self.l2_ws2 is not None: + for i, ws2 in enumerate(self.l2_ws2): + if ws2 is not None: + self._l2_gsb[i] *= ws2.float().item() + + self.l1_gs = None + self.l2_gs = None + self.l1_ws2 = None + self.l2_ws2 = None + + # Allocate buffers and eagerly warmup JIT compilation. + # cute.compile does NOT corrupt GPU memory (verified 2026-05-20). + # We warmup eagerly here to ensure compilation happens before + # the model's first forward pass, not during it. + self._token_indices = torch.zeros( + self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device + ) + self._fill_token_indices() + # No _needs_token_refill: cute.compile does NOT corrupt GPU memory. + # The original corruption was a misdiagnosis (see bridge.py cache docs). + + # Eagerly JIT-compile GEMM kernels for L1 and L2 shapes. + # This triggers cute.compile once per shape, caching the compiled + # kernel + workspace. Subsequent run() calls hit the cache. + # MUST happen before model forward pass to avoid OOM from lazy JIT. + from dsv4.ops.layouts import ( + ceil_div as bridge_ceil_div, + ) + from dsv4.ops.gemm_runner import ( + warmup_compilation, + warmup_fused_swiglu_compilation, + ) + K_packed = self.hidden_size // 2 + N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined + N_packed_l2 = self.hidden_size // 2 # down + warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device) # L1 + warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device) # L2 + if self._fused_swiglu: + warmup_fused_swiglu_compilation( + self.num_experts, K_packed, N_packed_l1, self.device, + swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, + ) # Fused L1 + + self._expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) + self._allocate_buffers() + + def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs): + """DEPRECATED: Use prepare_weights_from_stacked() for checkpoint weights. + + This path takes pre-quantized per-expert lists. The stacked path is + more memory-efficient and avoids per-expert list overhead. + """ + self.l1_fp4 = l1_fp4 + self.l1_sf = l1_sf + self.l1_gs = l1_gs + self.l2_fp4 = l2_fp4 + self.l2_sf = l2_sf + self.l2_gs = l2_gs + self._l1_mat_b = None + + def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked, + l1_gs, l2_fp4_stacked, l2_sf_stacked, + l2_gs): + """Prepare weights from pre-stacked 3D tensors (checkpoint format). + + Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly + from the checkpoint, avoiding the per-expert list→stack round-trip. + + The conversion to K-major and swizzled layout happens in _ensure_stacked. + This just stores the tensors for deferred processing. + """ + # Store in checkpoint format (E, N, K) — _ensure_stacked will convert + self.l1_fp4_stacked = l1_fp4_stacked + self.l1_sf_stacked = l1_sf_stacked + self.l1_gs = l1_gs + self.l2_fp4_stacked = l2_fp4_stacked + self.l2_sf_stacked = l2_sf_stacked + self.l2_gs = l2_gs + self._l1_mat_b = None + + def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): + """DEPRECATED: Use prepare_weights_from_stacked() instead. + + This path dequantizes checkpoint NVFP4 to BF16 then re-quantizes to our FP4. + While the round-trip is lossless for DeepSeek-V4 (our packing matches + the checkpoint convention exactly), it wastes memory and compute. + The direct byte path (prepare_weights_from_stacked) is preferred. + """ + self.l1_fp4, self.l1_sf, self.l1_gs = [], [], [] + self.l2_fp4, self.l2_sf, self.l2_gs = [], [], [] + for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16): + l1_w_t = l1_w.T + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t) + self.l1_fp4.append(w_fp4) + self.l1_sf.append(w_sf) + self.l1_gs.append(w_gs) + l2_w_t = l2_w.T + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t) + self.l2_fp4.append(w_fp4) + self.l2_sf.append(w_sf) + self.l2_gs.append(w_gs) + self._l1_mat_b = None + + def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets, + padded_expert_offsets, + padded_x_sf_buf, per_expert_bufs): + """Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs). + + Phase 1: Scatter x_sf into padded per-expert sections (GPU-only). + Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops). + + The buffer is 128-row aligned per expert (from padded_expert_offsets), + so the full-buffer swizzle produces the correct layout. The GEMM reads + scale_a using padded_expert_offsets, matching the scatter layout. + """ + K_sf = x_sf.shape[1] + padded_x_sf = padded_x_sf_buf + padded_x_sf.zero_() + + # Phase 1: Scatter x_sf into padded per-expert sections (GPU-only) + total_rows = x_sf.shape[0] + row_indices = self._row_indices_buf[:total_rows] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + dst_rows = padded_expert_offsets[expert_assign] + local_row + padded_x_sf[dst_rows, :K_sf] = x_sf + + # Phase 2: Full-buffer swizzle (no CPU sync, no Python loops) + # padded_x_sf is 128-row aligned per expert and 4-col aligned. + # to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3) + # → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten + rows = padded_x_sf.shape[0] + cols = padded_x_sf.shape[1] + R = rows // 128 + C = cols // 4 + blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + swizzled = rearranged.flatten().view(torch.float8_e4m3fn) + return swizzled.reshape(rows, cols) + + def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): + """Compute activation global scales from a warmup forward pass. + + Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run() + to ensure kernel JIT happens with the same layout, and L2 gs is computed + from actual L1 output (not an approximation). + """ + self._ensure_stacked() + device = hidden_states_sample.device + num_tokens = hidden_states_sample.shape[0] + top_k = topk_ids.shape[1] + + with torch.no_grad(): + # Build slot mapping (same as run()) + flat_ids = topk_ids.reshape(-1) + num_slots = num_tokens * top_k + token_indices = self._token_indices[:num_slots] + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_token_ids = token_indices[sort_idx] + slot_hidden = hidden_states_sample[sorted_token_ids] + + # L1: get exact gs from quantize_to_nvfp4 + _, _, l1_gs = quantize_to_nvfp4(slot_hidden) + + # Quantize slot_hidden for GEMM + slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs) + + tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int() + expert_offsets = self._expert_offsets_buf + expert_offsets.zero_() + expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) + + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = self._padded_expert_offsets_buf + padded_expert_offsets.zero_() + padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) + + # Compute padded_dst (same as run()) + row_indices = self._row_indices_buf[:num_slots] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + + # Scatter x_fp4 into padded layout + padded_x_fp4 = self._shared_bufs['hidden_fp4'] + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) + + l1_scale_a = self._assemble_scales_cudagraph_safe( + slot_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 + ) + l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, mat_b=self._l1_mat_b, + scale_a=l1_scale_a, scale_b=self._l1_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + ) + + # Extract real token outputs + l1_out_real = l1_out[padded_dst] + + # L2: get exact gs from SiLU(gate)*up + # De-interleave L1 output: with interleaved weights, L1 GEMM + # output has [gate]*4, [up]*4 pattern. De-interleave before splitting. + l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] + gate = l1_deil[:, :self.intermediate_size] + up = l1_deil[:, self.intermediate_size:] + gate_silu = torch.nn.functional.silu(gate) + if self._swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=self._swiglu_limit) + up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) + activated = gate_silu * up + _, _, l2_gs = quantize_to_nvfp4(activated) + + self._l1_activation_global_scale = l1_gs + self._l2_activation_global_scale = l2_gs + + + + def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None): + """Forward: route tokens to experts, GEMM, combine. + + Uses torch.library.custom_op (nvfp4::moe_gemm) so torch.compile + treats this as an opaque op. The custom op calls _run_impl internally. + """ + if not hasattr(self, '_runner_id'): + self._runner_id = register_runner(self) + return nvfp4_moe_gemm( + hidden_states, topk_weights, topk_ids, + self._runner_id, self.hidden_size, + ) + + def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None): + """Run the NVFP4 MoE forward pass. + + Handles global→local expert ID remapping for expert parallelism. + Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. + + Each expert's slots are padded to multiples of 128 for the GEMM. + expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...]. + scale_a is produced at those same offsets. + """ + num_tokens = hidden_states.shape[0] + top_k = topk_ids.shape[1] + device = hidden_states.device + + self._ensure_stacked() + + # -- Remap global expert IDs to local IDs -- + local_ids = topk_ids - self.experts_start_idx + local_mask = (local_ids >= 0) & (local_ids < self.num_experts) + safe_ids = local_ids.clamp(0, self.num_experts - 1) + safe_weights = topk_weights * local_mask.float() + + # -- Build slot mapping -- + flat_ids = safe_ids.reshape(-1) + flat_weights = safe_weights.reshape(-1) + num_slots = num_tokens * top_k + token_indices = self._token_indices[:num_slots] + + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_weights = flat_weights[sort_idx] + sorted_token_ids = token_indices[sort_idx] + + # Expert offsets (real token counts) + tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int() + expert_offsets = self._expert_offsets_buf + expert_offsets.zero_() + expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) + + # Pad each expert to 128-row alignment (GPU-only computation) + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = self._padded_expert_offsets_buf + padded_expert_offsets.zero_() + padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) + total_padded_slots = padded_expert_offsets[self.num_experts] + + # -- Gather hidden states into slot order, compute padded_dst -- + slot_hidden = hidden_states[sorted_token_ids] + row_indices = self._row_indices_buf[:num_slots] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + + # === L1: gate + up === + # Fused amax + quantize: single kernel, zero CPU-GPU syncs. + # Computes amax on GPU → derives gsa → quantizes to NVFP4. + # gsa written to GPU buffer for GEMM global_scale_a. + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden) + self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu( + slot_hidden, self._l1_activation_global_scale + ) + # Scatter x_fp4 into padded layout for the GEMM + # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put) + padded_x_fp4 = self._shared_bufs['hidden_fp4'] + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) + + l1_scale_a = self._assemble_scales_cudagraph_safe( + slot_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 + ) + l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed) + + if self._fused_swiglu: + # === Fused L1 GEMM + SwiGLU in kernel registers === + l1_out = run_fused_swiglu_grouped_gemm( + mat_a=padded_x_fp4, mat_b=self._l1_mat_b, + scale_a=l1_scale_a, scale_b=self._l1_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, + ) + l1_out_real = l1_out[padded_dst] + # Fused deinterleave + amax + quantize: zero CPU syncs. + # Computes gsa from de-interleaved SwiGLU output on GPU, + # quantizes in the same kernel. Writes gsa to GPU buffer. + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused + slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused( + l1_out_real, self.intermediate_size) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda( + l1_out_real, self.intermediate_size, self._l2_activation_global_scale + ) + else: + # === Non-fused L1 GEMM + PyTorch SiLU(gate)*up === + l1_out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, mat_b=self._l1_mat_b, + scale_a=l1_scale_a, scale_b=self._l1_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + ) + l1_out_real = l1_out[padded_dst] + l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] + gate = l1_deil[:, :self.intermediate_size] + up = l1_deil[:, self.intermediate_size:] + gate_silu = torch.nn.functional.silu(gate) + if self._swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=self._swiglu_limit) + up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) + activated = gate_silu * up + + # Compute runtime gsa for L2 from activated output (non-fused path) + # Fused amax + quantize: zero CPU syncs. + if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + elif not self._fused_swiglu: + slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu( + activated, self._l2_activation_global_scale + ) + padded_activated_fp4 = self._shared_bufs['activated_fp4'] + padded_activated_fp4.view(torch.uint8).zero_() + padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8) + + l2_scale_a = self._assemble_scales_cudagraph_safe( + slot_l2_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 + ) + l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed) + + l2_out = run_nvfp4_grouped_gemm( + mat_a=padded_activated_fp4, mat_b=self._l2_mat_b, + scale_a=l2_scale_a, scale_b=self._l2_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, + ) + + l2_out_real = l2_out[padded_dst] + + # === Scatter -> final output === + y = self._output_buf[:num_tokens] + y.zero_() + weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype) + y.scatter_add_( + 0, + sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size), + weighted_out, + ) + + return y diff --git a/dsv4/layers/norm.py b/dsv4/_archive/layers/norm.py similarity index 100% rename from dsv4/layers/norm.py rename to dsv4/_archive/layers/norm.py diff --git a/dsv4/_archive/layers/router.py b/dsv4/_archive/layers/router.py new file mode 100644 index 00000000..fbdf3db6 --- /dev/null +++ b/dsv4/_archive/layers/router.py @@ -0,0 +1,345 @@ +"""DSV4 Router — token-to-expert assignment. + +Two routing modes that share an output shape: + - 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection. + Used by MoE layers 3+ (the bulk of the network). + - 'hash': deterministic per-token-ID lookup, uniform weights. + Used by the first 3 MoE layers per DSV4 §2.1. + +Both modes produce (topk_weights, topk_ids) suitable for direct +consumption by Nvfp4MoE.run(). + +CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs. +Selection between modes is by layer_idx at construction time — +the kernel path is fixed once the Router is built so the dispatch +is constant-folded by torch.compile. +""" + +from __future__ import annotations +from typing import Optional, Literal +import torch + +from dsv4.ops.router import ( + register_router, + dense_router_op, + hash_router_op, +) + + +RouterMode = Literal["dense", "hash"] + + +class Router: + """DSV4 expert router. + + Per the DeepSeek-V4 paper (§2.1): + - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·). + - Auxiliary-loss-free strategy: a learned per-expert bias (loaded + from checkpoint, frozen at inference) is added to the activation + for SELECTION only. The actual gating weight applied to expert + outputs uses the UNBIASED activation. + - First 3 MoE layers use Hash routing (Roller et al. 2021): a + precomputed [vocab_size, k] LUT mapping token IDs to expert IDs. + No gate GEMM is performed. + - Sequence-wise balance loss is training-only; not applied here. + + Parameters + ---------- + hidden_size : int + Model hidden dimension. Must match W_gate's K dimension. + num_experts : int + Total routed experts (Flash: 256, Pro: 384). Shared experts are + handled separately by Nvfp4SharedExpert. + top_k : int + Experts activated per token. DSV4 uses 6. + routed_scaling_factor : float + Post-renormalization scale on gating weights. DSV3 used 2.5; + verify against the V4 checkpoint config — may be per-layer. + mode : {'dense', 'hash'} + Routing strategy. Decided at construction; cannot change at runtime. + vocab_size : int, optional + Required when mode='hash'. The LUT is [vocab_size, top_k] int32. + max_num_tokens : int + Upper bound on N for pre-allocated buffer sizing. + device : str + CUDA device. + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + top_k: int = 6, + routed_scaling_factor: float = 2.5, + *, + mode: RouterMode, + vocab_size: Optional[int] = None, + max_num_tokens: int = 8192, + device: str = "cuda", + ): + if mode == "hash" and vocab_size is None: + raise ValueError("vocab_size is required when mode='hash'") + if mode not in ("dense", "hash"): + raise ValueError(f"unknown router mode: {mode!r}") + + self.hidden_size = hidden_size + self.num_experts = num_experts + self.top_k = top_k + self.routed_scaling_factor = routed_scaling_factor + self.mode = mode + self.vocab_size = vocab_size + self.max_num_tokens = max_num_tokens + self.device = device + + # ---- Parameters (filled by load_weights / finalize_weights) ---- + # Dense mode — fused NVFP4 kernel (single-kernel, preferred): + # gate_weight: raw NVFP4 gate weight tensor [K_packed, E_packed] uint8 + # gate_weight_scale: weight scale [K_sf, E_sf] FP8 E4M3 + # gate_ws2: weight_scale_2 (global scale base) + # gate_input_scale: input_scale (activation global scale base) + # Dense mode — 2-kernel NVFP4 path (fallback): + # gate_lin: Nvfp4Linear for the gate projection + # Dense mode — BF16 fallback: + # W_gate: BF16 weight for cuBLAS when NVFP4 scales not available + # Hash mode: + # hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs. + self.gate_weight = None # Raw NVFP4 weight for fused kernel + self.gate_weight_scale = None # FP8 E4M3 scale for fused kernel + self.gate_ws2 = None # weight_scale_2 for fused kernel + self.gate_input_scale = None # input_scale for fused kernel + self.gate_lin = None # Nvfp4Linear for 2-kernel NVFP4 path + self.W_gate: Optional[torch.Tensor] = None # BF16 fallback + self.e_bias: Optional[torch.Tensor] = None + self.hash_lut: Optional[torch.Tensor] = None + + # ---- Pre-allocated output buffers (cudagraph-safe) ---- + self._topk_weights_buf: Optional[torch.Tensor] = None + self._topk_ids_buf: Optional[torch.Tensor] = None + + # Runner ID assigned on first call (see custom_op pattern). + self._runner_id: Optional[int] = None + + # ------------------------------------------------------------------ + # Weight loading + # ------------------------------------------------------------------ + def load_weights( + self, + W_gate: Optional[torch.Tensor] = None, + e_bias: Optional[torch.Tensor] = None, + hash_lut: Optional[torch.Tensor] = None, + ) -> None: + """Populate router parameters from a checkpoint shard. + + Dense mode expects (W_gate, e_bias). Hash mode expects (hash_lut). + Mismatches with self.mode raise immediately — these errors are + nearly always loader bugs and silent acceptance would mask them. + """ + if self.mode == "dense": + if e_bias is None: + raise ValueError("dense router needs e_bias") + assert e_bias.shape == (self.num_experts,), \ + f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)" + self.e_bias = e_bias.to(device=self.device, dtype=torch.float32) + if W_gate is not None: + self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16) + # gate_lin is set separately via load_nvfp4_gate() + else: # hash + if hash_lut is None: + raise ValueError("hash router needs hash_lut") + assert hash_lut.shape == (self.vocab_size, self.top_k), \ + f"hash_lut shape {tuple(hash_lut.shape)} != " \ + f"{(self.vocab_size, self.top_k)}" + assert (hash_lut >= 0).all() and (hash_lut < self.num_experts).all(), \ + "hash_lut contains out-of-range expert IDs" + self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32) + + def load_nvfp4_gate(self, gate_lin) -> None: + """Set the NVFP4 gate linear layer (2-kernel path). + + Called by the single_shot after constructing the Nvfp4Linear + from checkpoint NVFP4 scales. When set, _run_dense_impl uses + the production NVFP4 GEMM path instead of BF16 cuBLAS. + """ + self.gate_lin = gate_lin + + def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale, + gate_ws2, gate_input_scale, + gate_weight_bf16=None) -> None: + """Set raw NVFP4 gate tensors and create Nvfp4Linear for production GEMM.""" + self.gate_weight = gate_weight.to(device=self.device) + self.gate_weight_scale = gate_weight_scale.to(device=self.device) + self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None + self.gate_input_scale = gate_input_scale.to(self.device) + + # Create Nvfp4Linear from BF16 weight (handles layout correctly) + if gate_weight_bf16 is not None: + from dsv4.layers.linear import Nvfp4Linear + from dsv4.ops.quantize import quantize_to_nvfp4 + E = gate_weight_bf16.shape[0] + gate_lin = Nvfp4Linear(in_features=self.hidden_size, out_features=E, device=self.device) + g_fp4, g_sf, g_gs = quantize_to_nvfp4(gate_weight_bf16.bfloat16().to(self.device)) + gate_lin.fp4 = [g_fp4] + gate_lin.sf = [g_sf] + gate_lin.gs = [g_gs] + ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item() + gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)] + gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item() + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow + gate_lin.finalize_weights() + self.gate_lin = gate_lin + + def finalize_weights(self) -> None: + """Allocate output buffers and JIT-compile the routing kernel. + + Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time + setup step called after all parameters are loaded. Triggers + kernel compilation so the first forward isn't paying that cost. + """ + self._topk_weights_buf = torch.empty( + self.max_num_tokens, self.top_k, + dtype=torch.float32, device=self.device, + ) + self._topk_ids_buf = torch.empty( + self.max_num_tokens, self.top_k, + dtype=torch.int32, device=self.device, + ) + + # Eager JIT — dispatcher knows our mode and triggers the right + # kernel's compile path. See dsv4/ops/router.py. + from dsv4.ops.router import warmup_router_compilation + warmup_router_compilation(self) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + def __call__( + self, + hidden_states: torch.Tensor, + token_ids: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE. + + Parameters + ---------- + hidden_states : Tensor [N, hidden_size] bfloat16 + Required for dense mode. Ignored for hash mode (kept in the + signature so the call site is mode-agnostic). + token_ids : Tensor [N] int32, optional + Required for hash mode. Ignored for dense mode. + + Returns + ------- + topk_weights : Tensor [N, top_k] float32 + topk_ids : Tensor [N, top_k] int32 + + Notes + ----- + Both outputs are views into pre-allocated buffers — do not retain + them across router calls. Nvfp4MoE consumes them immediately, + which matches its existing contract. + """ + if self._topk_weights_buf is None: + raise RuntimeError("Router.finalize_weights() not called") + + if self.mode == "dense": + if hidden_states is None: + raise ValueError("dense router requires hidden_states") + return self._run_dense(hidden_states) + else: + if token_ids is None: + raise ValueError("hash router requires token_ids") + return self._run_hash(token_ids) + + # ------------------------------------------------------------------ + # Mode-specific dispatch — each routes through a torch.library.custom_op + # so Dynamo / torch.compile treats the kernel as opaque. + # ------------------------------------------------------------------ + def _run_dense(self, hidden_states: torch.Tensor): + if self._runner_id is None: + self._runner_id = register_router(self) + return dense_router_op( + hidden_states, + self._runner_id, + self.num_experts, + self.top_k, + ) + + def _run_hash(self, token_ids: torch.Tensor): + if self._runner_id is None: + self._runner_id = register_router(self) + return hash_router_op( + token_ids, + self._runner_id, + self.top_k, + ) + + # ------------------------------------------------------------------ + # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code. + # ------------------------------------------------------------------ + def _run_dense_impl(self, hidden_states: torch.Tensor): + """Hot-path: fused NVFP4, 2-kernel NVFP4, or BF16 fallback. + + Priority: + 1. Fused NVFP4 kernel (single-kernel GEMM + router epilogue) + 2. 2-kernel NVFP4 path (Nvfp4Linear + activation_topk) + 3. BF16 cuBLAS fallback + """ + N = hidden_states.shape[0] + out_w = self._topk_weights_buf[:N] + out_ids = self._topk_ids_buf[:N] + if self.gate_lin is not None: + # NVFP4 production GEMM path (proven Nvfp4Linear) + from dsv4.kernels.router import dense_router_dispatch_nvfp4 + dense_router_dispatch_nvfp4( + hidden_states=hidden_states, + gate_lin=self.gate_lin, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) + elif self.gate_weight is not None: + # Fused NVFP4 path (gate_lin was not created) + # Fall back to BF16 + from dsv4.kernels.router import dense_router_dispatch + dense_router_dispatch( + hidden_states=hidden_states, + W_gate=self.W_gate, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) + else: + from dsv4.kernels.router import dense_router_dispatch + dense_router_dispatch( + hidden_states=hidden_states, + W_gate=self.W_gate, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) + return out_w, out_ids + + def _run_hash_impl(self, token_ids: torch.Tensor): + """Hot-path entry into the hash gather kernel. + + Implementation lives in dsv4/kernels/cuda/hash_router.cu via the + wrapper in dsv4/ops/router.py. + """ + from dsv4.kernels.router import hash_router_dispatch + N = token_ids.shape[0] + out_w = self._topk_weights_buf[:N] + out_ids = self._topk_ids_buf[:N] + hash_router_dispatch( + token_ids=token_ids, + hash_lut=self.hash_lut, + top_k=self.top_k, + out_weights=out_w, # filled with 1/k + out_ids=out_ids, + ) + return out_w, out_ids diff --git a/dsv4/_archive/layers/shared_expert.py b/dsv4/_archive/layers/shared_expert.py new file mode 100644 index 00000000..eb824744 --- /dev/null +++ b/dsv4/_archive/layers/shared_expert.py @@ -0,0 +1,409 @@ +"""CuTeDSL Shared Expert Pipeline + +NVFP4 inference for DeepSeek V4 shared experts. +Uses ScaledGroupedGemmKernel with num_groups=1. + +Pipeline: + 1. Quantize activation: BF16 → NVFP4 (using warmup gs) + 2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16 + 3. SiLU(gate) * up → BF16 + 4. Re-quantize: BF16 → NVFP4 (using warmup gs) + 5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16 + +Unlike MoE, there's no routing, no scatter, no expert offsets. +All tokens go through the same expert (the shared expert). +Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle. + +CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, +no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output. +""" + +import torch + +from dsv4.ops.quantize import ( + quantize_activation_nvfp4, + quantize_to_nvfp4, +) +from dsv4.ops.layouts import ( + make_b_k_major, + interleave_l1_weights, + deinterleave_l1_weights, +) +from dsv4.ops.gemm_runner import ( + run_nvfp4_grouped_gemm, + run_fused_swiglu_grouped_gemm, +) +from dsv4.ops.quantize import quantize_nvfp4_gpu_fused +from dsv4.kernels.gemm.grouped import ( + ceil_div as cutedsl_ceil_div, + pad_and_swizzle_single, +) + + +class _SharedExpertApply(torch.autograd.Function): + """Custom autograd function to make CuTeDSL runner opaque to torch.compile.""" + @staticmethod + def forward(ctx, runner, hidden_states): + return runner._run_impl(hidden_states) + + +class Nvfp4SharedExpert: + """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1). + + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + max_num_tokens: int = 8192, + device: str = "cuda", + swiglu_limit: float = 10.0, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_num_tokens = max_num_tokens + self.device = device + self.swiglu_limit = swiglu_limit + self._fused_swiglu = False # Set via set_fused_swiglu() + + # Weights (set after construction, then call finalize_weights) + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + # weight_scale_2 per layer (scalar, folded into global_scale_b in finalize_weights) + self.l1_ws2 = None + self.l2_ws2 = None + + # Processed weights (set by finalize_weights) + self._l1_mat_b = None + self._l2_mat_b = None + self._l1_scale_b = None + self._l2_scale_b = None + self._l1_gsb = None + self._l2_gsb = None + + # Activation global scales (set by compute_activation_global_scales) + self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) + self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) + + # Pre-allocated cudagraph buffers (set in _allocate_buffers) + self._padded_x_fp4_buf_l1 = None + self._padded_x_sf_buf_l1 = None + self._padded_x_fp4_buf_l2 = None + self._padded_x_sf_buf_l2 = None + self._l1_gsa_buf = None + self._l2_gsa_buf = None + self._expert_offsets_buf = None + self._buffers_allocated = False + + def set_swiglu_limit(self, limit: float): + self.swiglu_limit = limit + + def set_fused_swiglu(self, enabled: bool): + """Enable fused L1 GEMM + SwiGLU kernel (1-group variant of MoE fused kernel).""" + self._fused_swiglu = enabled + + def finalize_weights(self): + """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights.""" + # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view + l1_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l1_fp4] + l2_view = [w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w for w in self.l2_fp4] + # Checkpoint weight is (N_packed, K_packed), make_b_k_major expects (E, K_packed, N_packed) + l1_stacked = torch.stack(l1_view).permute(0, 2, 1).contiguous() + l2_stacked = torch.stack(l2_view).permute(0, 2, 1).contiguous() + # P1: Interleave L1 gate/up weights for fused SwiGLU kernel compatibility. + # The fused kernel's SwiGLU epilogue expects granularity-8 interleaved gate/up. + # The unfused path (if _fused_swiglu=False) deinterleaves the GEMM output before splitting. + if self._fused_swiglu: + l1_stacked = interleave_l1_weights(l1_stacked, granularity_bf16=8) + # Stack weights and convert to K-major + self._l1_mat_b = make_b_k_major(l1_stacked) # (1, K_packed, N_packed) + self._l2_mat_b = make_b_k_major(l2_stacked) + # Checkpoint scale is (N_packed, K_sf) — use assemble_raw_scales_2d3d_3d_side + from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side + self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(self.l1_sf) + self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(self.l2_sf) + self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) + self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) + + # Fold weight_scale_2 into global_scale_b + # gsb = input_scale * weight_scale_2 + if self.l1_ws2 is not None: + for i, ws2 in enumerate(self.l1_ws2): + if ws2 is not None: + self._l1_gsb[i] *= ws2.float().item() + if self.l2_ws2 is not None: + for i, ws2 in enumerate(self.l2_ws2): + if ws2 is not None: + self._l2_gsb[i] *= ws2.float().item() + + # Free raw weights + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + self.l1_ws2 = None + self.l2_ws2 = None + + def _allocate_buffers(self): + """Pre-allocate all buffers at max size for cudagraph compatibility.""" + max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128 # pad to 128 + + # L1: hidden_size packed, L2: intermediate_size packed + self._padded_x_fp4_buf_l1 = torch.zeros( + max_rows, self.hidden_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2) + self._padded_x_fp4_buf_l2 = torch.zeros( + max_rows, self.intermediate_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2) + + # Padded scale buffers (need same padded dimensions as pad_and_swizzle_single produces) + K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16) + padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4 + K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16) + padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4 + self._padded_x_sf_buf_l1 = torch.zeros( + max_rows, padded_cols_l1, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn) + self._padded_x_sf_buf_l2 = torch.zeros( + max_rows, padded_cols_l2, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn) + + # Global scale buffers + self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) + self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) + + # Expert offsets for num_groups=1: just [num_tokens_padded] + # The GEMM expects expert_offsets as (num_experts,) cumulative offsets + # For 1 expert: offsets = [num_tokens] (just one element) + self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device) + + self._buffers_allocated = True + + def _ensure_initialized(self): + """Lazily initialize stacked weights and buffers.""" + if self._l1_mat_b is None: + self.finalize_weights() + if not self._buffers_allocated: + self._allocate_buffers() + + def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf): + """Assemble 2D-side activation scales for num_groups=1. + + For a single group, scale assembly is just: + 1. Copy x_sf into a correctly-sized buffer (padded to 128 rows, 4 cols) + 2. Apply pad_and_swizzle_single (Blackwell swizzle) + 3. Reshape back to 2D (kernel expects 2D scale_a) + + The padded buffer must be sized exactly for 128-aligned num_tokens, + NOT the max_num_tokens buffer (which would be way too large). + """ + num_rows, num_cols = x_sf.shape + padded_rows = cutedsl_ceil_div(num_rows, 128) * 128 + padded_cols = cutedsl_ceil_div(num_cols, 4) * 4 + + # Use a temp buffer sized for this exact token count + buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn) + buf[:num_rows, :num_cols] = x_sf + swizzled_flat = pad_and_swizzle_single(buf) + return swizzled_flat.reshape(padded_rows, padded_cols) + + def compute_activation_global_scales(self, hidden_states_sample): + """Compute activation global scales from a warmup forward pass. + + Called BEFORE cudagraph capture. Uses quantize_to_nvfp4 to get + the exact global_scale from the data, then runs L1 to compute + L2 gs from actual SiLU(gate)*up output. + """ + self._ensure_initialized() + + with torch.no_grad(): + # L1: exact gs from quantize_to_nvfp4 + _, _, l1_gs = quantize_to_nvfp4(hidden_states_sample) + self._l1_activation_global_scale = l1_gs + + # Run L1 GEMM to get intermediate for L2 gs + num_tokens = hidden_states_sample.shape[0] + l1_out = self._run_l1(hidden_states_sample) + if l1_out is not None and not torch.isnan(l1_out).any(): + gate = l1_out[:, :self.intermediate_size] + up = l1_out[:, self.intermediate_size:] + if self.swiglu_limit is not None: + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + activated = torch.nn.functional.silu(gate) * up + _, _, l2_gs = quantize_to_nvfp4(activated) + self._l2_activation_global_scale = l2_gs + + + + def _run_l1_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Fused L1 GEMM + SwiGLU + clamp — single kernel launch (1-group variant of MoE fused kernel).""" + num_tokens = hidden_states.shape[0] + x_bf16 = hidden_states.reshape(num_tokens, self.hidden_size) + + # Quantize activation to NVFP4 (fused amax + quantize) + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(x_bf16) + self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU + else: + from dsv4.ops.quantize import quantize_activation_nvfp4 + x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, self._l1_activation_global_scale) + + # Padded buffer setup for 1-group GEMM + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + padded_x_fp4 = self._padded_x_fp4_buf_l1 + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) + + # Assemble A-side scales + scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l1) + + # Expert offsets: [padded_rows] for 1 group (int32, pre-allocated) + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync) + gsa = self._l1_gsa_buf + + # Run fused GEMM + SwiGLU + l1_out = run_fused_swiglu_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._l1_mat_b, + scale_a=scale_a, + scale_b=self._l1_scale_b, + expert_offsets=expert_offsets, + global_scale_a=gsa, + global_scale_b=self._l1_gsb, + swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0, + ) + l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up] + # Deinterleave to separate gate and up, then take up half (SwiGLU result) + l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved + intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up + return intermediate # (num_tokens, intermediate_size) BF16 + + def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor: + """L1 GEMM: activation × gate_up_weight → BF16.""" + num_tokens = hidden_states.shape[0] + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + + # Fused amax + quantize: zero CPU syncs. + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states) + self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + x_fp4, x_sf = quantize_activation_nvfp4( + hidden_states, self._l1_activation_global_scale + ) + + # Scatter x_fp4 into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf_l1 + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) + + # Assemble A-side scales + scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l1) + + # Expert offsets: [padded_rows] for 1 group + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync) + gsa = self._l1_gsa_buf + + # Run GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._l1_mat_b, + scale_a=scale_a, + scale_b=self._l1_scale_b, + expert_offsets=expert_offsets, + global_scale_a=gsa, + global_scale_b=self._l1_gsb, + ) + + # Extract real token outputs + return out[:num_tokens] + + def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor: + """L2 GEMM: intermediate × down_weight → BF16.""" + num_tokens = intermediate.shape[0] + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + + # Fused amax + quantize: zero CPU syncs. + if getattr(self, '_use_runtime_gsa', False): + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + x_fp4, x_sf = quantize_activation_nvfp4( + intermediate, self._l2_activation_global_scale + ) + + # Scatter into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf_l2 + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) + + # Assemble A-side scales + scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l2) + + # Expert offsets + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — GPU-computed gsa already in _l2_gsa_buf (no CPU sync) + gsa = self._l2_gsa_buf + + # Run GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._l2_mat_b, + scale_a=scale_a, + scale_b=self._l2_scale_b, + expert_offsets=expert_offsets, + global_scale_a=gsa, + global_scale_b=self._l2_gsb, + ) + + return out[:num_tokens] + + def run(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Full shared expert forward: L1 → SiLU → L2 → output.""" + return _SharedExpertApply.apply(self, hidden_states) + + def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Actual implementation — called via custom autograd to be torch.compile-safe.""" + self._ensure_initialized() + + if self._fused_swiglu: + # P1: Fused L1 GEMM + SwiGLU + clamp in one kernel launch + intermediate = self._run_l1_fused(hidden_states) + else: + l1_out = self._run_l1(hidden_states) + if l1_out.shape[1] < 2 * self.intermediate_size: + print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True) + + gate = l1_out[:, :self.intermediate_size] + up = l1_out[:, self.intermediate_size:] + if torch.isnan(l1_out).any(): + print(f" SE L1 NaN: l1_out nan at {torch.isnan(l1_out).sum().item()} / {l1_out.numel()} positions, shape={l1_out.shape}", flush=True) + if torch.isnan(gate).any() or torch.isnan(up).any(): + print(f" SE gate nan={torch.isnan(gate).any().item()} up nan={torch.isnan(up).any().item()}", flush=True) + if self.swiglu_limit is not None: + gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + intermediate = torch.nn.functional.silu(gate) * up + + output = self._run_l2(intermediate) + return output diff --git a/dsv4/loader/hf_checkpoint.py b/dsv4/_archive/loader/hf_checkpoint.py similarity index 100% rename from dsv4/loader/hf_checkpoint.py rename to dsv4/_archive/loader/hf_checkpoint.py diff --git a/dsv4/loader/layout_convert.py b/dsv4/_archive/loader/layout_convert.py similarity index 100% rename from dsv4/loader/layout_convert.py rename to dsv4/_archive/loader/layout_convert.py diff --git a/dsv4/model/dsv4.py b/dsv4/_archive/model/dsv4.py similarity index 100% rename from dsv4/model/dsv4.py rename to dsv4/_archive/model/dsv4.py diff --git a/dsv4/model/layer.py b/dsv4/_archive/model/layer.py similarity index 100% rename from dsv4/model/layer.py rename to dsv4/_archive/model/layer.py diff --git a/dsv4/model/layer_schedule.py b/dsv4/_archive/model/layer_schedule.py similarity index 100% rename from dsv4/model/layer_schedule.py rename to dsv4/_archive/model/layer_schedule.py diff --git a/dsv4/model/mtp.py b/dsv4/_archive/model/mtp.py similarity index 100% rename from dsv4/model/mtp.py rename to dsv4/_archive/model/mtp.py diff --git a/dsv4/_archive/ops/custom_ops.py b/dsv4/_archive/ops/custom_ops.py new file mode 100644 index 00000000..65ef96d2 --- /dev/null +++ b/dsv4/_archive/ops/custom_ops.py @@ -0,0 +1,138 @@ +"""torch.library.custom_op wrappers for CuTeDSL NVFP4 kernels. + +Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals +(JIT compilation, cute.compile, etc.). By wrapping the runner calls in +torch.library.custom_op, Dynamo treats them as opaque black boxes. + +This is the correct approach per PyTorch's extensibility model: +- custom_op is the supported way to make Dynamo skip tracing +- autograd.Function does NOT work reliably with fullgraph mode +- The runner's _run_impl is already cudagraph-safe + +The registry pattern: custom ops can only take tensor/scalar arguments. +We store runners in a global dict keyed by integer ID, and pass the ID +as an int parameter. During Dynamo tracing, the fake impl returns a +correctly-shaped tensor without touching the runner. During execution, +the real impl looks up the runner and calls _run_impl. +""" + +import torch + +# --------------------------------------------------------------------------- +# Runner registry — maps integer IDs to runner objects +# --------------------------------------------------------------------------- +_next_runner_id = 0 +_runner_registry: dict[int, object] = {} + + +def register_runner(runner) -> int: + """Register a CuTeDSL runner and return its integer ID.""" + global _next_runner_id + rid = _next_runner_id + _next_runner_id += 1 + _runner_registry[rid] = runner + return rid + + +def get_runner(rid: int): + """Look up a runner by ID.""" + return _runner_registry[rid] + + +# --------------------------------------------------------------------------- +# NVFP4 Linear GEMM custom op (single linear layer) +# --------------------------------------------------------------------------- +@torch.library.custom_op("nvfp4::linear_gemm", mutates_args=()) +def nvfp4_linear_gemm( + x: torch.Tensor, + runner_id: int, + out_features: int, +) -> torch.Tensor: + """Opaque NVFP4 linear GEMM for torch.compile. + + Args: + x: (M, K) BF16 input + runner_id: integer key into the runner registry + out_features: output dimension (for shape inference) + Returns: + (M, out_features) BF16 output + """ + runner = get_runner(runner_id) + return runner._run_impl(x) + + +@nvfp4_linear_gemm.register_fake +def _(x, runner_id, out_features): + return torch.empty(x.shape[0], out_features, dtype=torch.bfloat16, device=x.device) + + +# --------------------------------------------------------------------------- +# NVFP4 MoE custom op (L1 + SiLU + L2 grouped GEMM) +# --------------------------------------------------------------------------- +@torch.library.custom_op("nvfp4::moe_gemm", mutates_args=()) +def nvfp4_moe_gemm( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + runner_id: int, + hidden_size: int, +) -> torch.Tensor: + """Opaque NVFP4 MoE GEMM for torch.compile. + + Args: + hidden_states: (M, K) BF16 input + topk_weights: (M, top_k) float32 routing weights + topk_ids: (M, top_k) int32 expert IDs + runner_id: integer key into the runner registry + hidden_size: output dimension (for shape inference) + Returns: + (M, hidden_size) BF16 output + """ + runner = get_runner(runner_id) + return runner._run_impl(hidden_states, topk_weights, topk_ids) + + +@nvfp4_moe_gemm.register_fake +def _(hidden_states, topk_weights, topk_ids, runner_id, hidden_size): + return torch.empty( + hidden_states.shape[0], hidden_size, + dtype=torch.bfloat16, device=hidden_states.device, + ) + + +# --------------------------------------------------------------------------- +# DSV4 Sparse FMHA custom op (attention with SWA + sink bias) +# --------------------------------------------------------------------------- +@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=()) +def dsv4_sparse_fmha( + q: torch.Tensor, # (n_q_heads, T, hd) BF16 + k: torch.Tensor, # (n_kv_heads, N, hd) or (N, hd) BF16 + v: torch.Tensor, # same as k + sink_bias: torch.Tensor, # (n_q_heads,) FP32 — can be zeros if unused + scale: float, + swa_len: int, + is_causal: bool, + n_comp: int, +) -> torch.Tensor: + """Opaque DSV4 attention for torch.compile. + + Delegates to dsv4_attention with the appropriate flags. + sink_bias is always passed (use zeros when unused) to keep the + custom_op signature tensor-only for Dynamo compatibility. + """ + from dsv4.kernels.attention.production import dsv4_attention as _dsv4_attention + + # If sink_bias is all zeros and n_comp == 0, skip sink bias + has_sink = n_comp > 0 and sink_bias.abs().sum().item() > 0 + return _dsv4_attention( + q, k, v, scale=scale, + swa_len=swa_len if swa_len > 0 else None, + is_causal=is_causal, + n_comp=n_comp, + sink_bias=sink_bias if has_sink else None, + ) + + +@dsv4_sparse_fmha.register_fake +def _(q, k, v, sink_bias, scale, swa_len, is_causal, n_comp): + return torch.empty_like(q) diff --git a/dsv4/ops/rope.py b/dsv4/_archive/ops/rope.py similarity index 100% rename from dsv4/ops/rope.py rename to dsv4/_archive/ops/rope.py diff --git a/dsv4/_archive/ops/router.py b/dsv4/_archive/ops/router.py new file mode 100644 index 00000000..8d0c38d3 --- /dev/null +++ b/dsv4/_archive/ops/router.py @@ -0,0 +1,93 @@ +"""torch.library.custom_op wrappers and dispatch for the Router kernels. + +Mirrors the pattern in dsv4/ops/custom_ops.py: + - Routers are registered into an integer-keyed table. + - The custom_op takes the integer ID and tensor args only. + - Dynamo can't trace through the kernel; the op is opaque. +""" + +import torch +from dsv4.kernels.router import ( + dense_router_dispatch, # picks decode vs prefill internally + hash_router_dispatch, +) + +_next_router_id = 0 +_router_registry: dict[int, object] = {} + + +def register_router(router) -> int: + global _next_router_id + rid = _next_router_id + _next_router_id += 1 + _router_registry[rid] = router + return rid + + +def get_router(rid: int): + return _router_registry[rid] + + +def warmup_router_compilation(router) -> None: + """Trigger eager JIT compilation for the router's kernel path. + + Runs a dummy forward at max_num_tokens to compile the kernel for the + expected shape range. Caller already has the buffers allocated. + """ + if router.mode == "dense": + # Dummy forward at small N triggers decode-path compile. + # CuTeDSL fused kernel is WIP — falls through to prefill path. + dummy = torch.zeros( + 1, router.hidden_size, + dtype=torch.bfloat16, device=router.device, + ) + try: + router._run_dense_impl(dummy) + except Exception: + pass # CuTeDSL kernel not yet working; prefill path is fine + else: + dummy = torch.zeros(1, dtype=torch.int32, device=router.device) + router._run_hash_impl(dummy) + + +# ----- Dense router custom op ----- +@torch.library.custom_op("dsv4::dense_router", mutates_args=()) +def dense_router_op( + hidden_states: torch.Tensor, + router_id: int, + num_experts: int, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + router = get_router(router_id) + return router._run_dense_impl(hidden_states) + + +@dense_router_op.register_fake +def _(hidden_states, router_id, num_experts, top_k): + N = hidden_states.shape[0] + device = hidden_states.device + return ( + torch.empty(N, top_k, dtype=torch.float32, device=device), + torch.empty(N, top_k, dtype=torch.int32, device=device), + ) + + +# ----- Hash router custom op ----- +@torch.library.custom_op("dsv4::hash_router", mutates_args=()) +def hash_router_op( + token_ids: torch.Tensor, + router_id: int, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + router = get_router(router_id) + return router._run_hash_impl(token_ids) + + +@hash_router_op.register_fake +def _(token_ids, router_id, top_k): + N = token_ids.shape[0] + device = token_ids.device + return ( + torch.empty(N, top_k, dtype=torch.float32, device=device), + torch.empty(N, top_k, dtype=torch.int32, device=device), + ) diff --git a/dsv4/ops/topk.py b/dsv4/_archive/ops/topk.py similarity index 100% rename from dsv4/ops/topk.py rename to dsv4/_archive/ops/topk.py diff --git a/dsv4/ops/topk_select.py b/dsv4/_archive/ops/topk_select.py similarity index 100% rename from dsv4/ops/topk_select.py rename to dsv4/_archive/ops/topk_select.py diff --git a/dsv4/reference/attention.py b/dsv4/_archive/reference/attention.py similarity index 100% rename from dsv4/reference/attention.py rename to dsv4/_archive/reference/attention.py diff --git a/dsv4/reference/compressor.py b/dsv4/_archive/reference/compressor.py similarity index 100% rename from dsv4/reference/compressor.py rename to dsv4/_archive/reference/compressor.py diff --git a/dsv4/reference/csa_attention.py b/dsv4/_archive/reference/csa_attention.py similarity index 100% rename from dsv4/reference/csa_attention.py rename to dsv4/_archive/reference/csa_attention.py diff --git a/dsv4/reference/moe_pipeline.py b/dsv4/_archive/reference/moe_pipeline.py similarity index 100% rename from dsv4/reference/moe_pipeline.py rename to dsv4/_archive/reference/moe_pipeline.py diff --git a/dsv4/kernels/attention/__init__.py b/dsv4/kernels/attention/__init__.py index 1038473d..e13a048a 100644 --- a/dsv4/kernels/attention/__init__.py +++ b/dsv4/kernels/attention/__init__.py @@ -1,180 +1,6 @@ """DSV4 Attention kernels — public integration API. -==================================================================== -STATUS: SKELETON — not yet connected to model -==================================================================== -These functions define the API that AttentionSubBlock will call. -They're correct in structure but depend on: -1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.) -2. The production FMHA wrapper supporting sink_bias and n_comp -3. Custom op registration for torch.compile compatibility - -See ROADMAP.md Priority 5 for the full Stage E checklist. -==================================================================== - -These functions bridge the model's AttentionSubBlock to the production -FMHA kernel wrapper. Each function handles the cache → dense-tensor -materialization that the kernel requires. - -The model's attention layer calls these after: -1. Projection (q_down, q_up, kv_down) -2. RoPE application -3. Compression + cache writes -4. Indexer + top-k (CSA only) - -These functions handle: -- Gathering sparse/dense KV from cache into dense tensors -- Calling the production FMHA wrapper -- Returning attention output for inverse RoPE + wo_a/wo_b +The live inference path uses dsv4.kernels.attention.production directly. +See production.py for the dsv4_attention function used by single_shot_inference.py. """ from dsv4.kernels.attention.production import dsv4_attention -import torch -from typing import Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from dsv4.cache.handle import LayerCacheHandle - - -def sparse_fmha_with_swa( - q: torch.Tensor, # (T, n_h * hd) BF16, post-RoPE - cache: "LayerCacheHandle", # provides compressed + SWA KV - selected_indices: torch.Tensor, # (T, top_k) int64 — which compressed blocks - sink_logits: Optional[torch.Tensor] = None, # (n_h,) FP32 - sliding_window: int = 128, -) -> torch.Tensor: - """CSA attention: sparse top-k compressed KV + sliding window, fused sink merge. - - Gathers the top-k compressed KV blocks + SWA window into a contiguous - tensor, then calls the production FMHA with sink bias. - - Args: - q: (T, n_h * hd) BF16 query (post-RoPE, pre-reshape) - cache: LayerCacheHandle with CSA compressed entries + SWA window - selected_indices: (T, top_k) int64 block indices from the indexer - sink_logits: (n_h,) FP32 per-head sink bias - sliding_window: SWA window length - - Returns: - (T, n_h * hd) BF16 attention output (pre inverse-RoPE) - """ - # Reshape q to (n_h, T, hd) - n_h_and_hd = q.shape[-1] - # n_h and hd come from the cache's config - n_h = cache.num_query_heads - hd = cache.head_dim - T = q.shape[0] - q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd) - - # Gather compressed KV for the selected blocks - # The cache handle provides the materialized dense KV from paged pool - k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices) - # k_compressed: (1, n_comp_kv, hd) or (n_kv, n_comp_kv, hd) - # v_compressed: same shape - - # Gather SWA window KV - k_swa, v_swa = cache.gather_swa_kv() - # k_swa: (1, swa_len, hd), v_swa: same - - # Concatenate: [compressed, SWA] — single softmax (D5c insight) - k_full = torch.cat([k_compressed, k_swa], dim=-2) # (1, n_comp+swa_len, hd) - v_full = torch.cat([v_compressed, v_swa], dim=-2) - - # n_comp = compressed KV length (for sink bias offset) - n_comp = k_compressed.shape[-2] - - # Call production attention — MQA (n_kv=1 for DSV4) - output = dsv4_attention( - q_heads, k_full, v_full, - swa_len=sliding_window, - is_causal=True, - n_comp=n_comp, - sink_bias=sink_logits, - ) # (n_h, T, hd) - - # Reshape back to (T, n_h * hd) - return output.permute(1, 0, 2).reshape(T, n_h * hd) - - -def dense_fmha_with_swa( - q: torch.Tensor, - cache: "LayerCacheHandle", - sink_logits: Optional[torch.Tensor] = None, - sliding_window: int = 128, -) -> torch.Tensor: - """HCA attention: dense over all compressed KV + SWA window, fused sink merge. - - No indexer — all compressed entries are attended (m'=128 compression - means the sequence is very short). - - Args: - q: (T, n_h * hd) BF16 query - cache: LayerCacheHandle with HCA compressed entries + SWA window - sink_logits: (n_h,) FP32 per-head sink bias - sliding_window: SWA window length - - Returns: - (T, n_h * hd) BF16 attention output - """ - n_h = cache.num_query_heads - hd = cache.head_dim - T = q.shape[0] - q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) - - # Dense: gather ALL compressed KV (no indexer needed) - k_compressed, v_compressed = cache.gather_all_compressed_kv() - - k_swa, v_swa = cache.gather_swa_kv() - - k_full = torch.cat([k_compressed, k_swa], dim=-2) - v_full = torch.cat([v_compressed, v_swa], dim=-2) - - n_comp = k_compressed.shape[-2] - - output = dsv4_attention( - q_heads, k_full, v_full, - swa_len=sliding_window, - is_causal=True, - n_comp=n_comp, - sink_bias=sink_logits, - ) - - return output.permute(1, 0, 2).reshape(T, n_h * hd) - - -def swa_only_fmha( - q: torch.Tensor, - cache: "LayerCacheHandle", - sink_logits: Optional[torch.Tensor] = None, - sliding_window: int = 128, -) -> torch.Tensor: - """SWA-only attention: pure local attention over the sliding window. - - No compression branch, no indexer. Used for the first two layers - of the Flash variant. - - Args: - q: (T, n_h * hd) BF16 query - cache: LayerCacheHandle with SWA window - sink_logits: (n_h,) FP32 per-head sink bias - sliding_window: SWA window length - - Returns: - (T, n_h * hd) BF16 attention output - """ - n_h = cache.num_query_heads - hd = cache.head_dim - T = q.shape[0] - q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) - - k_swa, v_swa = cache.gather_swa_kv() - - # No n_comp (no compressed branch), no sink bias offset - output = dsv4_attention( - q_heads, k_swa, v_swa, - swa_len=sliding_window, - is_causal=True, - n_comp=0, - sink_bias=sink_logits, - ) - - return output.permute(1, 0, 2).reshape(T, n_h * hd) diff --git a/dsv4/kernels/compressor/__init__.py b/dsv4/kernels/compressor/__init__.py index 0c4e776a..aee3c7db 100644 --- a/dsv4/kernels/compressor/__init__.py +++ b/dsv4/kernels/compressor/__init__.py @@ -1,56 +1,5 @@ """CSA/HCA compressor — Python API bridge. -Wraps the compression functions with the interface that -AttentionSubBlock and flush.py expect. - -The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA) -to produce compressed KV entries. The compressed entries are then written to the -paged pool by the flush_write kernel. +See dsv4/kernels/compressor/production_compress.py for the live path. +See dsv4/kernels/cuda/compressor_reduce.cu for the CUDA kernel. """ -import torch -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dsv4.cache.handle import LayerCacheHandle - -from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail - - -def csa_compress_and_store( - kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail) - cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool -) -> None: - """CSA: compress KV entries and store into the classical paged cache. - - Steps: - 1. Check if tail has enough entries (tail_len >= m=4) - 2. If so, run compression (csa_compress_tail) - 3. Write compressed output to paged pool via flush_write - 4. Update tail buffer (a-stream becomes next b-stream) - """ - from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda - # NOTE: This function is called from AttentionSubBlock.forward, which - # writes the raw KV to the tail buffer first (via cache.write_swa). - # The actual compression + flush happens when tail_len >= m. - # For now, the write_swa call handles the tail buffer write. - # The flush is triggered separately by the flush pipeline. - # See dsv4/cache/flush.py for the flush orchestration. - pass # Compression is handled by flush.py, not directly here - - -def hca_compress_and_store( - kv_raw: torch.Tensor, # (T, head_dim) BF16 - cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool -) -> None: - """HCA: compress KV entries and store into the classical paged cache. - - Same structure as CSA but no b-stream, no overlap, m'=128. - """ - pass # See flush.py - - -# Make compress_tail functions importable from this package -__all__ = [ - 'csa_compress_and_store', 'hca_compress_and_store', - 'csa_compress_tail', 'hca_compress_tail', -] diff --git a/dsv4/kernels/cuda/__init__.py b/dsv4/kernels/cuda/__init__.py index 9b48177f..c2e6632b 100644 --- a/dsv4/kernels/cuda/__init__.py +++ b/dsv4/kernels/cuda/__init__.py @@ -1,2 +1,2 @@ """CUDA kernel loader — re-exports from loader.py for convenience.""" -from dsv4.kernels.cuda.loader import get_cuda_module, preload_all +from dsv4.kernels.cuda.loader import get_cuda_module diff --git a/dsv4/kernels/cuda/loader.py b/dsv4/kernels/cuda/loader.py index 3200e4ea..f5c380bb 100644 --- a/dsv4/kernels/cuda/loader.py +++ b/dsv4/kernels/cuda/loader.py @@ -7,7 +7,7 @@ being called on every kernel invocation (was ~100ms per call, called ~500x per t Usage: from dsv4.kernels.cuda.loader import get_cuda_module mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) - result = mod.fused_amax_quantize_nvfp4(x, divisor) + result = mod.quantize_nvfp4_from_buffer(x, divisor) """ import os import hashlib @@ -65,17 +65,4 @@ def get_cuda_module(name, sources, extra_cuda_cflags=None): return mod -def preload_all(): - """Preload all CUDA kernels at startup (before the hot path).""" - # amax_gsa — computes gsa on GPU (no .item()) - get_cuda_module("amax_gsa", ["amax_gsa.cu"]) - # quantize-from-buffer — reads gsa from GPU buffer (no .item()) - get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) - # Standalone quantize (for when gsa is known, not hot path) - get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"]) - # Sampler - get_cuda_module("sampler", ["sampler.cu"]) - # Dequant NVFP4 - get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) - # Fused compress + quantize - get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) + diff --git a/dsv4/kernels/indexer/__init__.py b/dsv4/kernels/indexer/__init__.py index 303ed0d0..48da924c 100644 --- a/dsv4/kernels/indexer/__init__.py +++ b/dsv4/kernels/indexer/__init__.py @@ -1,63 +1,5 @@ """CSA indexer — Python API bridge. -Wraps the CUDA indexer score+topk kernel with the interface that -AttentionSubBlock expects. - -The indexer (paper §2.3.5, eq. 16) scores each query against -compressed blocks via weighted ReLU MQA logits, then selects -top-k blocks for sparse attention. - -Currently uses scalar FP32 CUDA cores after FP4 dequant. -The FP4 tensor-core path (Stage F / E7) is a future optimization. +See dsv4/kernels/cuda/indexer_score_topk.cu for the live CUDA kernel. +The live inference path uses the inline indexer in single_shot_inference.py. """ -import torch -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dsv4.cache.handle import LayerCacheHandle - - -def compute_index_scores_topk( - q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query - w_indexer: torch.Tensor, # (T, n_I_h) FP32 — per-head weights - cache: "LayerCacheHandle", # provides FP4 indexer keys - top_k: int = 512, # number of blocks to select -) -> torch.Tensor: # (T, top_k) int64 — selected block indices - """CSA: score compressed entries and select top-k blocks. - - Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar - score + min-heap top-k). Returns entry indices for gather_compressed_kv. - """ - from dsv4.kernels.indexer.score_topk import run_indexer_score_topk - - # Read the indexer view from the cache - indexer_view = cache.read_indexer_view() - - # c_I is the indexer head dimension from schema - n_I_h = cache.schema.indexer_entries_per_block # This is entries, not heads - c_I = cache.schema.indexer_head_dim # 128 - - # n_I_h (number of indexer heads) comes from the config, not the schema. - # We need to pass it through the handle or compute it. - # For DSV4: n_I_h = 64 (same for Flash and Pro) - # TODO: add indexer_num_heads to schema or handle - n_I_h = 64 # config.indexer_num_heads, hardcoded for now - - # Reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h * c_I) — already flat - # The kernel expects q_I: [T, n_I_h * c_I] BF16 - # and w_h: [T, n_I_h] FP32 - - entries_per_block = cache.schema.entries_per_block - - indices = run_indexer_score_topk( - q_I=q_indexer, - w_h=w_indexer.float() if w_indexer.dtype != torch.float32 else w_indexer, - indexer_view=indexer_view, - num_heads=n_I_h, - head_dim=c_I, - top_k=top_k, - entries_per_block=entries_per_block, - ) - - # indices: (T, top_k) int32 → convert to int64 for gather_compressed_kv - return indices.to(torch.int64) diff --git a/helpers/import_closure.py b/helpers/import_closure.py index c05f49fb..c3d0b5db 100644 --- a/helpers/import_closure.py +++ b/helpers/import_closure.py @@ -1,5 +1,6 @@ # helpers/import_closure.py — list dsv4 modules NOT reachable from the entry points. -# Usage: python helpers/import_closure.py (run from repo root, PYTHONPATH=repo root) +# Usage: python3 helpers/import_closure.py (run from repo root) +# NOTE: handles lazy imports inside functions (single_shot uses these heavily) import ast, pathlib, sys ROOT = pathlib.Path(__file__).resolve().parent.parent ENTRYPOINTS = ["single_shot_inference.py"] # vLLM has 0 imports of dsv4 (Step 0 confirmed) @@ -11,6 +12,7 @@ def module_to_path(mod): return p if p.exists() else None def imports_of(path): + """Parse ALL imports including lazy ones inside functions.""" tree = ast.parse(path.read_text()) out = set() for n in ast.walk(tree): diff --git a/tests/production_values_test.py b/tests/e2e_archive/production_values_test.py similarity index 100% rename from tests/production_values_test.py rename to tests/e2e_archive/production_values_test.py diff --git a/tests/e2e/test_csa_hca_integration.py b/tests/e2e_archive/test_csa_hca_integration.py similarity index 100% rename from tests/e2e/test_csa_hca_integration.py rename to tests/e2e_archive/test_csa_hca_integration.py diff --git a/tests/unit/test_fused_router.py b/tests/e2e_archive/test_fused_router.py similarity index 100% rename from tests/unit/test_fused_router.py rename to tests/e2e_archive/test_fused_router.py diff --git a/tests/e2e/test_model_construction.py b/tests/e2e_archive/test_model_construction.py similarity index 100% rename from tests/e2e/test_model_construction.py rename to tests/e2e_archive/test_model_construction.py diff --git a/tests/e2e/test_one_layer.py b/tests/e2e_archive/test_one_layer.py similarity index 100% rename from tests/e2e/test_one_layer.py rename to tests/e2e_archive/test_one_layer.py