class AttentionSubBlock:
    """Attention sub-block of one DSV4 layer: CSA, HCA or SWA.

    The variant is read once from ``spec.attn`` in the constructor and
    decides which projections are created and which ``_forward_*`` path
    ``forward`` takes. Every path shares the same front end (low-rank
    query projection + partial RoPE) and back end (inverse RoPE + grouped
    output projection); they differ only in how K/V are written to the
    cache and which FMHA kernel reads them back.

    Not handled here: mHC pre/post and the pre-norm (done by the layer),
    and KV cache allocation / eviction (done by the cache handle).
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        """Create the projection layers required by ``spec.attn``.

        Raises:
            ValueError: if ``spec.attn`` is not one of CSA / HCA / SWA, or
                if ``head_dim * num_query_heads`` does not split evenly
                into ``num_output_groups`` (the grouped projection would
                otherwise be sized by a silently truncating ``//``).
        """
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx
        self.attn_type = spec.attn  # frozen for this layer's lifetime

        # Width of the kv_down output, in multiples of head_dim.
        kv_streams = {
            AttentionType.CSA: 4,  # (Ca, Cb) + compression weights (Za, Zb)
            AttentionType.HCA: 2,  # C + Z
            AttentionType.SWA: 1,  # raw KV, no compression
        }
        if self.attn_type not in kv_streams:
            raise ValueError(
                f"layer {self.layer_idx}: unsupported attention type "
                f"{self.attn_type!r}"
            )

        attn_width = config.head_dim * config.num_query_heads
        if attn_width % config.num_output_groups:
            raise ValueError(
                f"head_dim * num_query_heads ({attn_width}) is not divisible "
                f"by num_output_groups ({config.num_output_groups})"
            )

        # ---- Projections (all NVFP4) ----

        # Query path is low-rank: hidden -> dc -> head_dim * num_heads.
        self.q_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.query_compression_dim,
        )
        self.q_up = Nvfp4Linear(
            in_features=config.query_compression_dim,
            out_features=attn_width,
        )

        # KV projection — single head shared across queries (MQA).
        self.kv_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=kv_streams[self.attn_type] * config.head_dim,
        )

        # ---- Indexer (CSA only) ----
        if self.attn_type == AttentionType.CSA:
            # Indexer queries share the q_down latent c_Q with main queries.
            # Up-projection: c_Q -> indexer_num_heads * indexer_head_dim.
            self.indexer_q_up = Nvfp4Linear(
                in_features=config.query_compression_dim,
                out_features=config.indexer_num_heads * config.indexer_head_dim,
            )
            # Per-head weight vector w_h: hidden -> indexer_num_heads.
            self.indexer_head_weights = Nvfp4Linear(
                in_features=config.hidden_size,
                out_features=config.indexer_num_heads,
            )

        # ---- Output projection ----
        # wo_a: grouped, splits the n_heads outputs into n_groups and projects
        # each group from (head_dim * n_heads / n_groups) to output_group_dim.
        self.wo_a = Nvfp4GroupedLinear(
            num_groups=config.num_output_groups,
            in_features=attn_width // config.num_output_groups,
            out_features=config.output_group_dim,
        )
        # wo_b: dense, concatenated group outputs back to hidden_size.
        self.wo_b = Nvfp4Linear(
            in_features=config.output_group_dim * config.num_output_groups,
            out_features=config.hidden_size,
        )

        # Attention sink: per-head learnable logit added to the softmax
        # denominator (paper §2.3.3). One scalar per query head.
        # Stays None until load_weights assigns it.
        self.sink_logits: torch.Tensor | None = None

    # ------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,             # (T, hidden_size) BF16, post-RMSNorm
        cache: "LayerCacheHandle",   # per-layer cache slice
    ) -> torch.Tensor:               # (T, hidden_size) BF16
        """Run the attention path selected at construction time."""
        # attn_type never changes after __init__, so this branch is a
        # constant in any traced graph.
        if self.attn_type == AttentionType.CSA:
            return self._forward_csa(x, cache)
        if self.attn_type == AttentionType.HCA:
            return self._forward_hca(x, cache)
        return self._forward_swa(x, cache)

    # ------------------------------------------------------------------
    # Shared front / back ends
    # ------------------------------------------------------------------
    def _project_queries(self, x, cache):
        """Return ``(c_Q, q_roped)``: the query latent and the RoPE'd queries."""
        c_Q = self.q_down(x)                 # (T, dc)
        q = self.q_up(c_Q)                   # (T, n_h * head_dim)
        return c_Q, self._apply_rope(q, positions=cache.positions)

    def _project_output(self, attn_out, cache):
        """Inverse RoPE (paper §2.3.3), then wo_a (grouped) and wo_b (dense)."""
        from dsv4.ops.rope import inverse_rope_bf16

        attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)
        grouped = self.wo_a(attn_out)        # (T, n_groups * d_g)
        return self.wo_b(grouped)            # (T, hidden_size)

    # ------------------------------------------------------------------
    # CSA: compressed + sparse attention, with indexer + top-k.
    # ------------------------------------------------------------------
    def _forward_csa(self, x, cache):
        """Compress KV into the cache, pick top-k entries, run sparse FMHA."""
        from dsv4.kernels.attention import sparse_fmha_with_swa
        from dsv4.kernels.compressor import csa_compress_and_store
        from dsv4.kernels.indexer import compute_index_scores_topk

        # 1) Low-rank query projection + partial RoPE. c_Q is reused by the
        #    indexer below. Only q is roped here; the compressor is handed
        #    the positions for the key side.
        c_Q, q_roped = self._project_queries(x, cache)

        # 2) KV down-projection — produces (Ca, Cb, Za, Zb) interleaved.
        kv_raw = self.kv_down(x)             # (T, 4 * head_dim)

        # 3) Compress + write to cache. Returns nothing — the cache holds it.
        csa_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,                     # writes C_comp + Z + indexer keys
            positions=cache.positions,
            compression_ratio=self.config.csa_compression_ratio,
        )

        # 4) Indexer: compute index scores, select top-k compressed entries.
        q_indexer = self.indexer_q_up(c_Q)           # (T, n_I_h * c_I)
        w_indexer = self.indexer_head_weights(x)     # (T, n_I_h)
        topk_indices = compute_index_scores_topk(
            q_indexer=q_indexer,
            w_indexer=w_indexer,
            cache=cache,                     # provides compressed indexer keys
            top_k=self.config.csa_top_k,
        )

        # 5) FMHA over selected sparse entries + SWA branch, fused sink merge.
        attn_out = sparse_fmha_with_swa(
            q=q_roped,
            cache=cache,
            selected_indices=topk_indices,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )                                    # (T, n_h * head_dim)

        # 6) Inverse RoPE + grouped output projection.
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # HCA: heavy compression, dense attention over compressed entries.
    # ------------------------------------------------------------------
    def _forward_hca(self, x, cache):
        """Compress KV into the cache, then dense FMHA over all of it + SWA."""
        from dsv4.kernels.attention import dense_fmha_with_swa
        from dsv4.kernels.compressor import hca_compress_and_store

        _, q_roped = self._project_queries(x, cache)
        kv_raw = self.kv_down(x)             # (T, 2 * head_dim)

        hca_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,
            positions=cache.positions,
            compression_ratio=self.config.hca_compression_ratio,
        )

        # Dense FMHA over ALL compressed entries (not top-k) plus SWA.
        attn_out = dense_fmha_with_swa(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # SWA: sliding window only, no compression branch.
    # First 2 layers of Flash. Acts as the "local context" foundation.
    # ------------------------------------------------------------------
    def _forward_swa(self, x, cache):
        """Write raw KV to the sliding window and attend over that window only."""
        from dsv4.kernels.attention import swa_only_fmha

        _, q_roped = self._project_queries(x, cache)
        kv_raw = self.kv_down(x)             # (T, head_dim)

        # Write raw KV to the SWA window in the cache. No compressor.
        cache.write_swa(kv_raw, positions=cache.positions)

        # Dense FMHA over the sliding window only.
        attn_out = swa_only_fmha(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    def _apply_rope(self, q: torch.Tensor, positions: torch.Tensor):
        """Apply RoPE to the last `rope_dim` dimensions of each head."""
        from dsv4.ops.rope import forward_rope_partial

        return forward_rope_partial(
            q, positions=positions, rope_dim=self.config.rope_dim,
        )
class FFNSubBlock:
    """FFN half of a DSV4 layer: router -> routed MoE, plus a shared expert.

    Whether routing is hash-based or dense is decided once, from
    ``spec.router_mode``, when the Router is built; ``forward`` always
    makes the same three calls and sums the routed and shared outputs.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec

        hidden = config.hidden_size
        n_experts = config.num_routed_experts
        experts_per_token = config.num_experts_per_tok
        expert_width = config.moe_intermediate_size
        hash_routed = spec.router_mode == RouterMode.HASH

        # vocab_size is only handed to the router in hash mode.
        self.router = Router(
            hidden_size=hidden,
            num_experts=n_experts,
            top_k=experts_per_token,
            routed_scaling_factor=config.routed_scaling_factor,
            mode="hash" if hash_routed else "dense",
            vocab_size=config.vocab_size if hash_routed else None,
        )
        self.moe = Nvfp4MoE(
            num_experts=n_experts,
            hidden_size=hidden,
            intermediate_size=expert_width,
            top_k=experts_per_token,
        )
        self.shared = Nvfp4SharedExpert(
            hidden_size=hidden,
            intermediate_size=expert_width,
        )

    def forward(
        self,
        x: torch.Tensor,           # (T, hidden_size) BF16, post-RMSNorm
        token_ids: torch.Tensor,   # (T,) int32 — needed only for hash routing
    ) -> torch.Tensor:
        """Return routed-MoE output plus shared-expert output for ``x``."""
        weights, expert_ids = self.router(x, token_ids=token_ids)
        from_routed = self.moe.run(x, weights, expert_ids)
        from_shared = self.shared.run(x)
        return from_routed + from_shared
@dataclass(frozen=True)
class DSV4Config:
    """Immutable architecture hyper-parameters for DSV4 (Flash and Pro).

    Both variants share this structure and differ only in numbers; use
    ``DSV4Config.flash()`` / ``DSV4Config.pro()`` for the published presets.

    Raises:
        ValueError: on construction, if ``variant`` is not "flash"/"pro",
            ``rope_dim`` exceeds ``head_dim``, the concatenated head width
            does not split evenly into ``num_output_groups``, or
            ``num_experts_per_tok`` exceeds ``num_routed_experts``.
    """

    # Identity
    variant: str                    # "flash" | "pro"
    vocab_size: int

    # Transformer
    num_layers: int                 # Flash: 43, Pro: 61
    hidden_size: int                # Flash: 4096, Pro: 7168

    # Attention (shared across CSA/HCA/SWA)
    num_query_heads: int            # Flash: 64, Pro: 128
    head_dim: int                   # Both: 512
    rope_dim: int                   # Both: 64 (partial RoPE on last 64 dims)
    query_compression_dim: int      # Flash: 1024, Pro: 1536

    # CSA-specific
    csa_compression_ratio: int      # m, both: 4
    csa_top_k: int                  # Flash: 512, Pro: 1024
    indexer_num_heads: int          # Both: 64
    indexer_head_dim: int           # Both: 128

    # HCA-specific
    hca_compression_ratio: int      # m', both: 128

    # SWA branch
    sliding_window: int             # n_win, both: 128

    # Output projection grouping
    num_output_groups: int          # Flash: 8, Pro: 16
    output_group_dim: int           # Both: 1024

    # MoE
    num_routed_experts: int         # Flash: 256, Pro: 384
    num_shared_experts: int         # Both: 1
    num_experts_per_tok: int        # Both: 6
    moe_intermediate_size: int      # Flash: 2048, Pro: 3072
    num_hash_routing_layers: int    # Both: 3
    routed_scaling_factor: float    # Verify per checkpoint

    # mHC
    n_hc: int                       # Both: 4
    sinkhorn_iters: int             # Both: 20

    # Architecture schedule — first N layers use a non-standard attention
    # pattern (paper §4.2.1).
    #   Flash: layers 0-1 = SWA only
    #   Pro:   layers 0-1 = HCA only
    num_special_first_layers: int = 2

    # Norm
    rms_norm_eps: float = 1e-6

    def __post_init__(self) -> None:
        # The layer schedule treats every variant other than "flash" as
        # "pro", so a typo here would silently select the wrong attention
        # pattern for the first layers.
        if self.variant not in ("flash", "pro"):
            raise ValueError(
                f"variant must be 'flash' or 'pro', got {self.variant!r}"
            )
        if self.rope_dim > self.head_dim:
            raise ValueError(
                f"rope_dim ({self.rope_dim}) exceeds head_dim ({self.head_dim})"
            )
        # The grouped output projection is sized with an integer division of
        # the head width by the group count; a remainder would be dropped.
        head_width = self.head_dim * self.num_query_heads
        if head_width % self.num_output_groups:
            raise ValueError(
                f"head_dim * num_query_heads ({head_width}) is not divisible "
                f"by num_output_groups ({self.num_output_groups})"
            )
        if self.num_experts_per_tok > self.num_routed_experts:
            raise ValueError(
                f"num_experts_per_tok ({self.num_experts_per_tok}) exceeds "
                f"num_routed_experts ({self.num_routed_experts})"
            )

    @classmethod
    def _preset(cls, **variant_specific) -> "DSV4Config":
        """Build a preset from the values Flash and Pro have in common."""
        return cls(
            vocab_size=128_000,
            head_dim=512,
            rope_dim=64,
            csa_compression_ratio=4,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            output_group_dim=1024,
            num_shared_experts=1,
            num_experts_per_tok=6,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
            **variant_specific,
        )

    @classmethod
    def flash(cls) -> "DSV4Config":
        """Return the Flash preset (43 layers, hidden size 4096)."""
        return cls._preset(
            variant="flash",
            num_layers=43,
            hidden_size=4096,
            num_query_heads=64,
            query_compression_dim=1024,
            csa_top_k=512,
            num_output_groups=8,
            num_routed_experts=256,
            moe_intermediate_size=2048,
        )

    @classmethod
    def pro(cls) -> "DSV4Config":
        """Return the Pro preset (61 layers, hidden size 7168)."""
        return cls._preset(
            variant="pro",
            num_layers=61,
            hidden_size=7168,
            num_query_heads=128,
            query_compression_dim=1536,
            csa_top_k=1024,
            num_output_groups=16,
            num_routed_experts=384,
            moe_intermediate_size=3072,
        )
class TransformerLayer:
    """One DSV4 layer: an attention stage then an FFN stage.

    Each stage is wrapped the same way: ``mHC.pre_block`` produces the
    sub-block input and a context object, RMSNorm is applied, the
    sub-block runs, and ``mHC.post_block`` folds the result back into the
    residual streams. The layer itself only wires components together.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx

        # One mHC wrapper per sub-block, both built from the same settings.
        # mHCLayer holds its own projection weights (W_pre, W_res, W_post)
        # and static biases.
        mhc_settings = dict(
            hidden_size=config.hidden_size,
            n_hc=config.n_hc,
            sinkhorn_iters=config.sinkhorn_iters,
        )
        self.mhc_attn = mHCLayer(**mhc_settings)
        self.mhc_ffn = mHCLayer(**mhc_settings)

        # One pre-block norm per sub-block.
        self.norm_attn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm_ffn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Sub-blocks; their variant is fixed by `spec` at construction.
        self.attn = AttentionSubBlock(config, spec)
        self.ffn = FFNSubBlock(config, spec)

    def forward(
        self,
        X: torch.Tensor,            # (T, n_hc, hidden_size) BF16 — residual streams
        token_ids: torch.Tensor,    # (T,) int32 — for hash routing
        cache: "LayerCacheHandle",
    ) -> torch.Tensor:
        """Run attention then FFN over the residual streams and return them."""
        # ---- Attention stage ----
        attn_in, attn_ctx = self.mhc_attn.pre_block(X)
        attn_out = self.attn.forward(self.norm_attn(attn_in), cache)
        X = self.mhc_attn.post_block(X, attn_out, attn_ctx)

        # ---- FFN stage ----
        ffn_in, ffn_ctx = self.mhc_ffn.pre_block(X)
        ffn_out = self.ffn.forward(self.norm_ffn(ffn_in), token_ids)
        return self.mhc_ffn.post_block(X, ffn_out, ffn_ctx)
class AttentionType(Enum):
    SWA = "swa"   # Sliding window only (Flash layers 0-1)
    HCA = "hca"   # Heavily Compressed Attention (Pro layers 0-1, alternating later)
    CSA = "csa"   # Compressed Sparse Attention (alternating from layer 2)


class FFNType(Enum):
    MOE = "moe"   # DeepSeekMoE — both variants use MoE in every layer


class RouterMode(Enum):
    HASH = "hash"     # Layers 0-2: deterministic by token ID
    DENSE = "dense"   # Layers 3+: sqrt(softplus) + top-k


@dataclass(frozen=True)
class LayerSpec:
    """The complete architectural type of a single transformer layer."""
    layer_idx: int
    attn: AttentionType
    ffn: FFNType
    router_mode: RouterMode


def _first_layers_attention(config: "DSV4Config") -> AttentionType:
    """Return the attention type of the special first layers for this variant.

    Raises:
        ValueError: if ``config.variant`` is neither "flash" nor "pro".
    """
    if config.variant == "flash":
        return AttentionType.SWA
    if config.variant == "pro":
        return AttentionType.HCA
    raise ValueError(
        f"unknown DSV4 variant {config.variant!r}; expected 'flash' or 'pro'"
    )


def _alternating_attention(layer_idx: int, first_special: int) -> AttentionType:
    """CSA on the first layer after the special ones, then HCA, CSA, ..."""
    return AttentionType.CSA if (layer_idx - first_special) % 2 == 0 else AttentionType.HCA


def _require(condition: bool, message: str) -> None:
    """Raise AssertionError(message) unless ``condition`` holds.

    Raised explicitly rather than with an ``assert`` statement so the
    check is still performed when Python runs with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def build_schedule(config: "DSV4Config") -> List[LayerSpec]:
    """Produce the full per-layer schedule for a given config.

    Paper §4.2.1:
      Flash: first 2 layers = pure SWA, then alternating CSA / HCA.
      Pro:   first 2 layers = HCA, then alternating CSA / HCA.
      Both:  first 3 MoE layers use hash routing.

    'Alternating CSA / HCA' starts with CSA at layer 2 in both variants.

    Raises:
        ValueError: if ``config.variant`` is neither "flash" nor "pro".
    """
    first_special = config.num_special_first_layers  # 2
    special_attn = _first_layers_attention(config)

    schedule: List[LayerSpec] = []
    for i in range(config.num_layers):
        if i < first_special:
            attn = special_attn
        else:
            # Layer 2 is CSA, layer 3 is HCA, layer 4 is CSA, ...
            attn = _alternating_attention(i, first_special)

        # Router mode — hash for the first N MoE layers.
        router = RouterMode.HASH if i < config.num_hash_routing_layers else RouterMode.DENSE

        # FFN type — DSV4 is MoE in every layer.
        schedule.append(
            LayerSpec(layer_idx=i, attn=attn, ffn=FFNType.MOE, router_mode=router)
        )

    return schedule


def validate_schedule(schedule: List[LayerSpec], config: "DSV4Config") -> None:
    """Sanity checks. Wrong schedule = silent garbage, so be loud here.

    Checks, in order: length, hash/dense routing split, attention type of
    the special first layers, CSA/HCA alternation afterwards, MoE in every
    layer.

    Raises:
        AssertionError: on the first check that fails (also under ``-O``).
        ValueError: if ``config.variant`` is neither "flash" nor "pro".
    """
    n_layers = config.num_layers
    _require(
        len(schedule) == n_layers,
        f"schedule has {len(schedule)} layers, config says {n_layers}",
    )

    # Clamp to n_layers so a config with fewer layers than hash/special
    # layers is checked instead of failing with IndexError.
    n_hash = min(config.num_hash_routing_layers, n_layers)
    n_special = min(config.num_special_first_layers, n_layers)

    # First N hash routing layers
    for i in range(n_hash):
        _require(schedule[i].router_mode == RouterMode.HASH,
                 f"layer {i} should be HASH-routed")
    for i in range(n_hash, n_layers):
        _require(schedule[i].router_mode == RouterMode.DENSE,
                 f"layer {i} should be DENSE-routed")

    # First layers are special
    expected_first = _first_layers_attention(config)
    for i in range(n_special):
        _require(schedule[i].attn == expected_first,
                 f"layer {i} should be {expected_first} for {config.variant}")

    # Alternation after the special layers
    for i in range(config.num_special_first_layers, n_layers):
        expected = _alternating_attention(i, config.num_special_first_layers)
        _require(schedule[i].attn == expected,
                 f"layer {i} should be {expected}, got {schedule[i].attn}")

    # Every layer is MoE
    for i, spec in enumerate(schedule):
        _require(spec.ffn == FFNType.MOE, f"layer {i} should be MOE")
def pf(ok):
    """Render a boolean check result as 'PASS' or 'FAIL'."""
    return 'PASS' if ok else 'FAIL'


_NVCC_FLAGS = ['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]']
_DEFAULT_KERNEL_ROOT = '/root/dsv4-nvfp4-workspace/kernel'


def _load_kernel(name, source):
    """JIT-compile one CUDA source file and return the loaded extension."""
    # Imported lazily so importing this module does not pull in the
    # extension build machinery.
    from torch.utils.cpp_extension import load
    return load(name=name, sources=[source],
                extra_cuda_cflags=_NVCC_FLAGS, verbose=False)


def reference_activation_topk(logits, bias, k, scaling):
    """PyTorch oracle for the fused sqrt(softplus) + biased top-k router.

    Selection uses ``sqrt(softplus(logits)) + bias``; the returned weights
    are the un-biased activations of the selected experts, normalised to
    sum to ``scaling`` per row.

    Returns:
        ``(weights, indices)``, both shaped like ``logits`` with the last
        dimension replaced by ``k``.
    """
    # Numerically stable softplus: log(1 + e^x) = log1p(e^-|x|) + max(x, 0).
    # exp() must see -|x|; feeding it x counts positive logits twice.
    sp = torch.log1p(torch.exp(-logits.abs())) + torch.clamp(logits, min=0)
    act = sp.sqrt()
    ids = (act + bias).topk(k, dim=-1).indices
    w = torch.gather(act, 1, ids)
    w = w / w.sum(dim=-1, keepdim=True) * scaling
    return w, ids


def check_hash_router():
    """Compare the hash_router kernel with a LUT gather; True if all cases pass."""
    hr = _load_kernel('hr2', 'dsv4/kernels/cuda/hash_router.cu')
    all_ok = True
    for N in [1, 4, 64, 128, 512]:
        vocab, k, E = 128000, 6, 256
        lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
        tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
        ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
        oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
        hr.hash_router(tids, lut, k, ow, oi)
        torch.cuda.synchronize()
        exp_ids = lut[tids]
        exp_w = torch.full((N, k), 1.0/k, dtype=torch.float32, device='cuda')
        ids_ok = (oi == exp_ids).all().item()
        w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
        ok = ids_ok and w_ok
        all_ok = all_ok and ok
        print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
    return all_ok


def check_topk_select():
    """Compare the topk_select kernel with torch.topk; True if all cases pass."""
    tk = _load_kernel('tk2', 'dsv4/kernels/cuda/topk_select.cu')
    all_ok = True
    for N, E in [(1,256), (4,256), (64,256), (64,384), (128,256), (512,256)]:
        k = 6
        scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
        ov, oidx = tk.topk_select(scores, k)
        torch.cuda.synchronize()
        exp = scores.topk(k, dim=-1)
        ids_ok = (oidx == exp.indices).all().item()
        vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
        ok = ids_ok and vals_ok
        all_ok = all_ok and ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
    return all_ok


def check_activation_topk():
    """Compare the activation_topk kernel with the oracle; True if all pass."""
    atk = _load_kernel('atk2', 'dsv4/kernels/cuda/activation_topk.cu')
    all_ok = True
    for N, E in [(1,256), (4,256), (64,256), (64,384)]:
        k = 6
        logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
        bias = torch.randn(E, dtype=torch.float32, device='cuda')
        scaling = 2.5
        out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
        out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
        atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
        torch.cuda.synchronize()
        exp_w, exp_ids = reference_activation_topk(logits, bias, k, scaling)
        ids_ok = (out_ids == exp_ids).all().item()
        vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
        ok = ids_ok and vals_ok
        all_ok = all_ok and ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ids_ok:
            print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
        if not vals_ok:
            print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
    return all_ok


def _run_section(title, summary, check):
    """Run one check under a banner; a raised exception counts as a failure."""
    import traceback

    print(f'\n=== {title} ===', flush=True)
    try:
        ok = check()
    except Exception:
        # Report and keep going so the remaining kernels are still tested.
        traceback.print_exc()
        ok = False
    else:
        # Only claim ALL PASS when every case in the section passed.
        print(f'{summary}: ' + ('ALL PASS' if ok else 'FAILED'), flush=True)
    torch.cuda.empty_cache()
    return ok


def main():
    """Run every router kernel check; return 0 if all passed, 1 otherwise."""
    # Kernel sources are addressed relative to the kernel workspace root.
    os.chdir(os.environ.get('DSV4_KERNEL_ROOT', _DEFAULT_KERNEL_ROOT))
    sys.path.insert(0, '.')

    if not torch.cuda.is_available():
        print('CUDA: False', flush=True)
        return 1
    print(f'CUDA: {torch.cuda.is_available()}, Device: {torch.cuda.get_device_name(0)}', flush=True)
    print(f'GPU mem free: {torch.cuda.mem_get_info()[0]//1024//1024} MB', flush=True)

    results = [
        _run_section('Hash Router', 'Hash Router', check_hash_router),
        _run_section('Top-k Select', 'Top-k Select', check_topk_select),
        _run_section('Activation + Top-k', 'Activation+Topk', check_activation_topk),
    ]

    print('\n=== DONE ===', flush=True)
    return 0 if all(results) else 1


if __name__ == "__main__":
    sys.exit(main())
pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) 
def test_flash_schedule():
    """Flash: 43 layers, SWA for layers 0-1, then CSA/HCA; hash routing on 0-2."""
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    validate_schedule(schedule, config)

    assert len(schedule) == 43

    # (attention, router) for layers 0..3: two SWA layers, then CSA on the
    # last hash-routed layer and HCA on the first dense-routed one.
    head = [
        (AttentionType.SWA, RouterMode.HASH),
        (AttentionType.SWA, RouterMode.HASH),
        (AttentionType.CSA, RouterMode.HASH),
        (AttentionType.HCA, RouterMode.DENSE),
    ]
    for idx, (attn, router) in enumerate(head):
        assert schedule[idx].attn == attn
        assert schedule[idx].router_mode == router

    # Alternation continues
    assert schedule[4].attn == AttentionType.CSA
    assert schedule[5].attn == AttentionType.HCA

    # All layers are MoE
    assert all(spec.ffn == FFNType.MOE for spec in schedule)


def test_pro_schedule():
    """Pro: 61 layers, HCA for layers 0-1, then CSA/HCA; hash routing on 0-2."""
    config = DSV4Config.pro()
    schedule = build_schedule(config)
    validate_schedule(schedule, config)

    assert len(schedule) == 61

    first, second, third, fourth = schedule[:4]

    # First two layers: HCA + hash routing
    assert first.attn == AttentionType.HCA
    assert second.attn == AttentionType.HCA
    assert first.router_mode == RouterMode.HASH

    # Layer 2: CSA + hash routing
    assert third.attn == AttentionType.CSA
    assert third.router_mode == RouterMode.HASH

    # Layer 3: HCA + dense routing
    assert fourth.attn == AttentionType.HCA
    assert fourth.router_mode == RouterMode.DENSE


def test_layer_spec_frozen():
    """LayerSpec is frozen — mutation should raise."""
    spec = build_schedule(DSV4Config.flash())[0]
    try:
        spec.attn = AttentionType.HCA
    except AttributeError:
        return
    assert False, "should have raised"


def test_schedule_indices_match():
    """Each LayerSpec.layer_idx matches its position in the list."""
    schedule = build_schedule(DSV4Config.flash())
    for position, spec in enumerate(schedule):
        assert spec.layer_idx == position


if __name__ == "__main__":
    for run_test in (
        test_flash_schedule,
        test_pro_schedule,
        test_layer_spec_frozen,
        test_schedule_indices_match,
    ):
        run_test()
    print("All schedule tests passed")