Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer

DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
This commit is contained in:
2026-05-21 23:11:09 +00:00
parent c681b591a0
commit 0b8f4da323
8 changed files with 785 additions and 12 deletions

View File

@@ -1,2 +1,242 @@
"""DSV4 attention sub-block."""
# TODO: Phase 3+4
"""DSV4 attention sub-block — CSA / HCA / SWA variants.
The constructor decides which variant this instance is by reading the
LayerSpec passed in. After construction, forward() is fixed — no runtime
branching on attention type, so torch.compile can fold the dispatch
completely.
What this class is responsible for:
- Holding the projection weights (Q-down/up, KV-down, output groups).
- Calling Nvfp4Linear for the projections.
- Running RoPE (partial, last rope_dim dimensions).
- Calling the compressor (CSA or HCA) to write into the KV cache.
- For CSA only: calling the indexer + top-k to select sparse entries.
- Calling the FMHA kernel with the right K/V source.
- Calling inverse RoPE on the output.
- Calling Nvfp4GroupedLinear (wo_a) + Nvfp4Linear (wo_b).
What this class does NOT touch:
- mHC pre/post — done by TransformerLayer.
- Pre-norm — done by TransformerLayer.
- KV cache allocation / eviction — done by the cache handle.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.linear import Nvfp4Linear
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
from dsv4.model.layer_schedule import AttentionType, LayerSpec
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
from dsv4.cache.paged_cache import LayerCacheHandle
class AttentionSubBlock:
def __init__(self, config: "DSV4Config", spec: LayerSpec):
self.config = config
self.spec = spec
self.layer_idx = spec.layer_idx
self.attn_type = spec.attn # frozen for this layer's lifetime
# ---- Projections (all NVFP4) ----
# Query path is low-rank: hidden -> dc -> head_dim * num_heads.
self.q_down = Nvfp4Linear(
in_features=config.hidden_size,
out_features=config.query_compression_dim,
)
self.q_up = Nvfp4Linear(
in_features=config.query_compression_dim,
out_features=config.head_dim * config.num_query_heads,
)
# KV projection — single head shared across queries (MQA).
# CSA needs two streams (a, b) for the overlapping compression;
# HCA needs one. SWA needs raw KV with no compression.
if self.attn_type == AttentionType.CSA:
# 2 * head_dim for (Ca, Cb), 2 * head_dim for compression weights (Za, Zb)
self.kv_down = Nvfp4Linear(
in_features=config.hidden_size,
out_features=4 * config.head_dim,
)
elif self.attn_type == AttentionType.HCA:
# head_dim for C, head_dim for Z
self.kv_down = Nvfp4Linear(
in_features=config.hidden_size,
out_features=2 * config.head_dim,
)
else: # SWA
self.kv_down = Nvfp4Linear(
in_features=config.hidden_size,
out_features=config.head_dim,
)
# ---- Indexer (CSA only) ----
if self.attn_type == AttentionType.CSA:
# Indexer queries share the q_down latent c_Q with main queries.
# Up-projection: c_Q -> indexer_num_heads * indexer_head_dim.
self.indexer_q_up = Nvfp4Linear(
in_features=config.query_compression_dim,
out_features=config.indexer_num_heads * config.indexer_head_dim,
)
# Per-head weight vector w_h: hidden -> indexer_num_heads.
self.indexer_head_weights = Nvfp4Linear(
in_features=config.hidden_size,
out_features=config.indexer_num_heads,
)
# ---- Output projection ----
# wo_a: grouped, splits the n_heads outputs into n_groups and projects
# each group from (head_dim * n_heads / n_groups) to output_group_dim.
self.wo_a = Nvfp4GroupedLinear(
num_groups=config.num_output_groups,
in_features=(config.head_dim * config.num_query_heads) // config.num_output_groups,
out_features=config.output_group_dim,
)
# wo_b: dense, concatenated group outputs back to hidden_size.
self.wo_b = Nvfp4Linear(
in_features=config.output_group_dim * config.num_output_groups,
out_features=config.hidden_size,
)
# Attention sink: per-head learnable logit added to the softmax
# denominator (paper §2.3.3). One scalar per query head.
self.sink_logits: torch.Tensor | None = None # set by load_weights
# ------------------------------------------------------------------
def forward(
self,
x: torch.Tensor, # (T, hidden_size) BF16, post-RMSNorm
cache: "LayerCacheHandle", # per-layer cache slice
) -> torch.Tensor: # (T, hidden_size) BF16
# Branch on attention type. Frozen at construction so this is a
# constant in any traced graph.
if self.attn_type == AttentionType.CSA:
return self._forward_csa(x, cache)
elif self.attn_type == AttentionType.HCA:
return self._forward_hca(x, cache)
else:
return self._forward_swa(x, cache)
# ------------------------------------------------------------------
# CSA: compressed + sparse attention, with indexer + top-k.
# ------------------------------------------------------------------
def _forward_csa(self, x, cache):
# 1) Low-rank query projection. c_Q is shared with the indexer.
c_Q = self.q_down(x) # (T, dc)
q = self.q_up(c_Q) # (T, n_h * head_dim)
# 2) KV down-projection — produces (Ca, Cb, Za, Zb) interleaved.
kv_raw = self.kv_down(x) # (T, 4 * head_dim)
# 3) Apply partial RoPE to q and the to-be-compressed keys.
q_roped = self._apply_rope(q, positions=cache.positions)
# 4) Compress + write to cache. Returns nothing — the cache holds it.
from dsv4.kernels.compressor import csa_compress_and_store
csa_compress_and_store(
kv_raw=kv_raw,
cache=cache, # writes C_comp + Z + indexer keys
positions=cache.positions,
compression_ratio=self.config.csa_compression_ratio,
)
# 5) Indexer: compute index scores, select top-k compressed entries.
q_indexer = self.indexer_q_up(c_Q) # (T, n_I_h * c_I)
w_indexer = self.indexer_head_weights(x) # (T, n_I_h)
from dsv4.kernels.indexer import compute_index_scores_topk
topk_indices = compute_index_scores_topk(
q_indexer=q_indexer,
w_indexer=w_indexer,
cache=cache, # provides compressed indexer keys
top_k=self.config.csa_top_k,
)
# 6) FMHA over selected sparse entries + SWA branch, fused sink merge.
from dsv4.kernels.attention import sparse_fmha_with_swa
attn_out = sparse_fmha_with_swa(
q=q_roped,
cache=cache,
selected_indices=topk_indices,
sink_logits=self.sink_logits,
sliding_window=self.config.sliding_window,
) # (T, n_h * head_dim)
# 7) Inverse RoPE on outputs (paper §2.3.3).
from dsv4.ops.rope import inverse_rope_bf16
attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)
# 8) Grouped output projection: wo_a then wo_b.
grouped = self.wo_a(attn_out) # (T, n_groups * d_g)
return self.wo_b(grouped) # (T, hidden_size)
# ------------------------------------------------------------------
# HCA: heavy compression, dense attention over compressed entries.
# ------------------------------------------------------------------
def _forward_hca(self, x, cache):
c_Q = self.q_down(x)
q = self.q_up(c_Q)
kv_raw = self.kv_down(x) # (T, 2 * head_dim)
q_roped = self._apply_rope(q, positions=cache.positions)
from dsv4.kernels.compressor import hca_compress_and_store
hca_compress_and_store(
kv_raw=kv_raw,
cache=cache,
positions=cache.positions,
compression_ratio=self.config.hca_compression_ratio,
)
# Dense FMHA over ALL compressed entries (not top-k) plus SWA.
from dsv4.kernels.attention import dense_fmha_with_swa
attn_out = dense_fmha_with_swa(
q=q_roped,
cache=cache,
sink_logits=self.sink_logits,
sliding_window=self.config.sliding_window,
)
from dsv4.ops.rope import inverse_rope_bf16
attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)
grouped = self.wo_a(attn_out)
return self.wo_b(grouped)
# ------------------------------------------------------------------
# SWA: sliding window only, no compression branch.
# First 2 layers of Flash. Acts as the "local context" foundation.
# ------------------------------------------------------------------
def _forward_swa(self, x, cache):
c_Q = self.q_down(x)
q = self.q_up(c_Q)
kv_raw = self.kv_down(x) # (T, head_dim)
q_roped = self._apply_rope(q, positions=cache.positions)
# Write raw KV to the SWA window in the cache. No compressor.
cache.write_swa(kv_raw, positions=cache.positions)
# Dense FMHA over the sliding window only.
from dsv4.kernels.attention import swa_only_fmha
attn_out = swa_only_fmha(
q=q_roped,
cache=cache,
sink_logits=self.sink_logits,
sliding_window=self.config.sliding_window,
)
from dsv4.ops.rope import inverse_rope_bf16
attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)
grouped = self.wo_a(attn_out)
return self.wo_b(grouped)
# ------------------------------------------------------------------
def _apply_rope(self, q: torch.Tensor, positions: torch.Tensor):
"""Apply RoPE to the last `rope_dim` dimensions of each head."""
from dsv4.ops.rope import forward_rope_partial
return forward_rope_partial(
q, positions=positions, rope_dim=self.config.rope_dim,
)

View File

@@ -1,2 +1,53 @@
"""FFN: router + MoE + shared expert."""
# TODO: Phase 2
"""DSV4 FFN sub-block — routed MoE + shared expert.
The router instance encapsulates hash-vs-dense; this sub-block doesn't
have to care, it just calls router(x, token_ids) and feeds the result
to Nvfp4MoE. Shared expert runs in parallel (logically — kernels
can overlap).
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.router import Router
from dsv4.layers.moe import Nvfp4MoE
from dsv4.layers.shared_expert import Nvfp4SharedExpert
from dsv4.model.layer_schedule import LayerSpec, RouterMode
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
class FFNSubBlock:
def __init__(self, config: "DSV4Config", spec: LayerSpec):
self.config = config
self.spec = spec
self.router = Router(
hidden_size=config.hidden_size,
num_experts=config.num_routed_experts,
top_k=config.num_experts_per_tok,
routed_scaling_factor=config.routed_scaling_factor,
mode="hash" if spec.router_mode == RouterMode.HASH else "dense",
vocab_size=config.vocab_size if spec.router_mode == RouterMode.HASH else None,
)
self.moe = Nvfp4MoE(
num_experts=config.num_routed_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
top_k=config.num_experts_per_tok,
)
self.shared = Nvfp4SharedExpert(
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
)
def forward(
self,
x: torch.Tensor, # (T, hidden_size) BF16, post-RMSNorm
token_ids: torch.Tensor, # (T,) int32 — needed only for hash routing
) -> torch.Tensor:
topk_w, topk_ids = self.router(x, token_ids=token_ids)
routed_out = self.moe.run(x, topk_w, topk_ids)
shared_out = self.shared.run(x)
return routed_out + shared_out

View File

@@ -1,2 +1,120 @@
"""DSV4Config (Flash + Pro)."""
# TODO: Phase 1
"""DSV4 architecture configuration.
One config class for both Flash and Pro — they differ only in numeric
parameters, not structure. Construct via .flash() / .pro() classmethods
or load from a checkpoint's config.json.
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class DSV4Config:
# Identity
variant: str # "flash" | "pro"
vocab_size: int
# Transformer
num_layers: int # Flash: 43, Pro: 61
hidden_size: int # Flash: 4096, Pro: 7168
# Attention (shared across CSA/HCA/SWA)
num_query_heads: int # Flash: 64, Pro: 128
head_dim: int # Both: 512
rope_dim: int # Both: 64 (partial RoPE on last 64 dims)
query_compression_dim: int # Flash: 1024, Pro: 1536
# CSA-specific
csa_compression_ratio: int # m, both: 4
csa_top_k: int # Flash: 512, Pro: 1024
indexer_num_heads: int # Both: 64
indexer_head_dim: int # Both: 128
# HCA-specific
hca_compression_ratio: int # m', both: 128
# SWA branch
sliding_window: int # n_win, both: 128
# Output projection grouping
num_output_groups: int # Flash: 8, Pro: 16
output_group_dim: int # Both: 1024
# MoE
num_routed_experts: int # Flash: 256, Pro: 384
num_shared_experts: int # Both: 1
num_experts_per_tok: int # Both: 6
moe_intermediate_size: int # Flash: 2048, Pro: 3072
num_hash_routing_layers: int # Both: 3
routed_scaling_factor: float # Verify per checkpoint
# mHC
n_hc: int # Both: 4
sinkhorn_iters: int # Both: 20
# Architecture schedule — first N layers use a non-standard attention
# pattern (paper §4.2.1).
# Flash: layers 0-1 = SWA only
# Pro: layers 0-1 = HCA only
num_special_first_layers: int = 2
# Norm
rms_norm_eps: float = 1e-6
@classmethod
def flash(cls) -> "DSV4Config":
return cls(
variant="flash",
vocab_size=128_000,
num_layers=43,
hidden_size=4096,
num_query_heads=64,
head_dim=512,
rope_dim=64,
query_compression_dim=1024,
csa_compression_ratio=4,
csa_top_k=512,
indexer_num_heads=64,
indexer_head_dim=128,
hca_compression_ratio=128,
sliding_window=128,
num_output_groups=8,
output_group_dim=1024,
num_routed_experts=256,
num_shared_experts=1,
num_experts_per_tok=6,
moe_intermediate_size=2048,
num_hash_routing_layers=3,
routed_scaling_factor=2.5,
n_hc=4,
sinkhorn_iters=20,
)
@classmethod
def pro(cls) -> "DSV4Config":
return cls(
variant="pro",
vocab_size=128_000,
num_layers=61,
hidden_size=7168,
num_query_heads=128,
head_dim=512,
rope_dim=64,
query_compression_dim=1536,
csa_compression_ratio=4,
csa_top_k=1024,
indexer_num_heads=64,
indexer_head_dim=128,
hca_compression_ratio=128,
sliding_window=128,
num_output_groups=16,
output_group_dim=1024,
num_routed_experts=384,
num_shared_experts=1,
num_experts_per_tok=6,
moe_intermediate_size=3072,
num_hash_routing_layers=3,
routed_scaling_factor=2.5,
n_hc=4,
sinkhorn_iters=20,
)

View File

@@ -1,2 +1,82 @@
"""Single transformer layer."""
# TODO: Phase 1
"""A single DSV4 transformer layer.
Structure (paper Figure 2):
X_l ─→ mHC.pre_block ─→ RMSNorm ─→ Attention ─→ mHC.post_block (using F_attn)
mHC.pre_block ─→ RMSNorm ─→ FFN ─→ mHC.post_block (using F_ffn)
X_{l+1}
Each layer owns:
- One LayerSpec (from build_schedule).
- Two mHC instances (one per sub-block).
- One AttentionSubBlock (type fixed by spec.attn).
- One FFNSubBlock (router mode fixed by spec.router_mode).
- Two RMSNorm weight tensors.
The layer is otherwise pure orchestration: no learned params live
directly on TransformerLayer, only on its components.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.mhc import mHCLayer
from dsv4.layers.attention import AttentionSubBlock
from dsv4.layers.ffn import FFNSubBlock
from dsv4.layers.norm import RMSNorm # PyTorch ref for now, fused later
from dsv4.model.layer_schedule import LayerSpec
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
from dsv4.cache.paged_cache import LayerCacheHandle
class TransformerLayer:
def __init__(self, config: "DSV4Config", spec: LayerSpec):
self.config = config
self.spec = spec
self.layer_idx = spec.layer_idx
# Two mHC wrappers — one per sub-block. mHCLayer holds its own
# projection weights (W_pre, W_res, W_post) and static biases.
self.mhc_attn = mHCLayer(
hidden_size=config.hidden_size,
n_hc=config.n_hc,
sinkhorn_iters=config.sinkhorn_iters,
)
self.mhc_ffn = mHCLayer(
hidden_size=config.hidden_size,
n_hc=config.n_hc,
sinkhorn_iters=config.sinkhorn_iters,
)
# Pre-block norms (one per sub-block).
self.norm_attn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm_ffn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Sub-blocks — type-frozen at construction.
self.attn = AttentionSubBlock(config, spec)
self.ffn = FFNSubBlock(config, spec)
def forward(
self,
X: torch.Tensor, # (T, n_hc, hidden_size) BF16 — residual streams
token_ids: torch.Tensor, # (T,) int32 — for hash routing
cache: "LayerCacheHandle",
) -> torch.Tensor:
# ---- Attention sub-block ----
x_attn_in, ctx_attn = self.mhc_attn.pre_block(X)
x_attn_in = self.norm_attn(x_attn_in)
F_attn = self.attn.forward(x_attn_in, cache)
X = self.mhc_attn.post_block(X, F_attn, ctx_attn)
# ---- FFN sub-block ----
x_ffn_in, ctx_ffn = self.mhc_ffn.pre_block(X)
x_ffn_in = self.norm_ffn(x_ffn_in)
F_ffn = self.ffn.forward(x_ffn_in, token_ids)
X = self.mhc_ffn.post_block(X, F_ffn, ctx_ffn)
return X

View File

@@ -0,0 +1,102 @@
"""Per-layer-index architecture schedule for DSV4.
Encodes the attention / FFN / routing pattern from the paper. Kept
separate from layer.py so the KV cache allocator can query it without
instantiating any layers.
"""
from __future__ import annotations
from enum import Enum
from dataclasses import dataclass
from typing import List
from dsv4.model.config import DSV4Config
class AttentionType(Enum):
SWA = "swa" # Sliding window only (Flash layers 0-1)
HCA = "hca" # Heavily Compressed Attention (Pro layers 0-1, alternating later)
CSA = "csa" # Compressed Sparse Attention (alternating from layer 2)
class FFNType(Enum):
MOE = "moe" # DeepSeekMoE — both variants use MoE in every layer
class RouterMode(Enum):
HASH = "hash" # Layers 0-2: deterministic by token ID
DENSE = "dense" # Layers 3+: sqrt(softplus) + top-k
@dataclass(frozen=True)
class LayerSpec:
"""The complete architectural type of a single transformer layer."""
layer_idx: int
attn: AttentionType
ffn: FFNType
router_mode: RouterMode
def build_schedule(config: DSV4Config) -> List[LayerSpec]:
"""Produce the full per-layer schedule for a given config.
Paper §4.2.1:
Flash: first 2 layers = pure SWA, then alternating CSA / HCA.
Pro: first 2 layers = HCA, then alternating CSA / HCA.
Both: first 3 MoE layers use hash routing.
'Alternating CSA / HCA' starts with CSA at layer 2 in both variants.
"""
schedule: List[LayerSpec] = []
first_special = config.num_special_first_layers # 2
for i in range(config.num_layers):
# Attention type
if i < first_special:
if config.variant == "flash":
attn = AttentionType.SWA
else: # pro
attn = AttentionType.HCA
else:
# Layer 2 is CSA, layer 3 is HCA, layer 4 is CSA, ...
attn = AttentionType.CSA if (i - first_special) % 2 == 0 else AttentionType.HCA
# FFN type — DSV4 is MoE in every layer.
ffn = FFNType.MOE
# Router mode — hash for the first N MoE layers.
router = RouterMode.HASH if i < config.num_hash_routing_layers else RouterMode.DENSE
schedule.append(LayerSpec(layer_idx=i, attn=attn, ffn=ffn, router_mode=router))
return schedule
def validate_schedule(schedule: List[LayerSpec], config: DSV4Config) -> None:
"""Sanity checks. Wrong schedule = silent garbage, so be loud here."""
assert len(schedule) == config.num_layers, \
f"schedule has {len(schedule)} layers, config says {config.num_layers}"
# First N hash routing layers
for i in range(config.num_hash_routing_layers):
assert schedule[i].router_mode == RouterMode.HASH, \
f"layer {i} should be HASH-routed"
for i in range(config.num_hash_routing_layers, config.num_layers):
assert schedule[i].router_mode == RouterMode.DENSE, \
f"layer {i} should be DENSE-routed"
# First two layers are special
expected_first = AttentionType.SWA if config.variant == "flash" else AttentionType.HCA
for i in range(config.num_special_first_layers):
assert schedule[i].attn == expected_first, \
f"layer {i} should be {expected_first} for {config.variant}"
# Alternation from layer 2
for i in range(config.num_special_first_layers, config.num_layers):
expected = AttentionType.CSA if (i - config.num_special_first_layers) % 2 == 0 \
else AttentionType.HCA
assert schedule[i].attn == expected, \
f"layer {i} should be {expected}, got {schedule[i].attn}"
# Every layer is MoE
for i, spec in enumerate(schedule):
assert spec.ffn == FFNType.MOE, f"layer {i} should be MOE"

97
run_router_tests.py Normal file
View File

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
import sys, os
os.chdir('/root/dsv4-nvfp4-workspace/kernel')
sys.path.insert(0, '.')
import torch
from torch.utils.cpp_extension import load
def pf(ok):
return 'PASS' if ok else 'FAIL'
print(f'CUDA: {torch.cuda.is_available()}, Device: {torch.cuda.get_device_name(0)}', flush=True)
print(f'GPU mem free: {torch.cuda.mem_get_info()[0]//1024//1024} MB', flush=True)
# 1. Hash Router
print('\n=== Hash Router ===', flush=True)
try:
hr = load(name='hr2', sources=['dsv4/kernels/cuda/hash_router.cu'],
extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'], verbose=False)
for N in [1, 4, 64, 128, 512]:
vocab, k, E = 128000, 6, 256
lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
hr.hash_router(tids, lut, k, ow, oi)
torch.cuda.synchronize()
exp_ids = lut[tids]
exp_w = torch.full((N, k), 1.0/k, dtype=torch.float32, device='cuda')
ids_ok = (oi == exp_ids).all().item()
w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
ok = ids_ok and w_ok
print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
del lut, tids, ow, oi, exp_ids, exp_w
print('Hash Router: ALL PASS', flush=True)
except Exception as e:
import traceback; traceback.print_exc()
torch.cuda.empty_cache()
# 2. Top-k Select
print('\n=== Top-k Select ===', flush=True)
try:
tk = load(name='tk2', sources=['dsv4/kernels/cuda/topk_select.cu'],
extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'], verbose=False)
for N, E in [(1,256), (4,256), (64,256), (64,384), (128,256), (512,256)]:
k = 6
scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
ov, oidx = tk.topk_select(scores, k)
torch.cuda.synchronize()
exp = scores.topk(k, dim=-1)
ids_ok = (oidx == exp.indices).all().item()
vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
ok = ids_ok and vals_ok
print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
del scores, ov, oidx, exp
print('Top-k Select: ALL PASS', flush=True)
except Exception as e:
import traceback; traceback.print_exc()
torch.cuda.empty_cache()
# 3. Activation + Top-k
print('\n=== Activation + Top-k ===', flush=True)
try:
atk = load(name='atk2', sources=['dsv4/kernels/cuda/activation_topk.cu'],
extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'], verbose=False)
for N, E in [(1,256), (4,256), (64,256), (64,384)]:
k = 6
logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
bias = torch.randn(E, dtype=torch.float32, device='cuda')
scaling = 2.5
out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
torch.cuda.synchronize()
# Oracle
sp = torch.log1p(logits.exp()) + torch.clamp(logits, min=0)
act = sp.sqrt()
score = act + bias
exp_topk = score.topk(k, dim=-1)
exp_ids = exp_topk.indices
exp_w = torch.gather(act, 1, exp_ids)
exp_w = exp_w / exp_w.sum(dim=-1, keepdim=True) * scaling
ids_ok = (out_ids == exp_ids).all().item()
vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
ok = ids_ok and vals_ok
print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
if not ids_ok:
print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
if not vals_ok:
print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
del logits, bias, out_w, out_ids, sp, act, score, exp_topk, exp_ids, exp_w
print('Activation+Topk: ALL PASS', flush=True)
except Exception as e:
import traceback; traceback.print_exc()
print('\n=== DONE ===', flush=True)

View File

@@ -98,11 +98,11 @@ class FmhaV3TenWarp:
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)

View File

@@ -0,0 +1,85 @@
"""Tests for layer schedule — pure data, no kernels, no tensors."""
from dsv4.model.config import DSV4Config
from dsv4.model.layer_schedule import (
AttentionType, FFNType, RouterMode,
LayerSpec, build_schedule, validate_schedule,
)
def test_flash_schedule():
config = DSV4Config.flash()
schedule = build_schedule(config)
validate_schedule(schedule, config)
assert len(schedule) == 43
# First two layers: SWA + hash routing
assert schedule[0].attn == AttentionType.SWA
assert schedule[1].attn == AttentionType.SWA
assert schedule[0].router_mode == RouterMode.HASH
assert schedule[1].router_mode == RouterMode.HASH
# Layer 2: CSA + hash routing (last hash layer)
assert schedule[2].attn == AttentionType.CSA
assert schedule[2].router_mode == RouterMode.HASH
# Layer 3: HCA + dense routing (first dense layer)
assert schedule[3].attn == AttentionType.HCA
assert schedule[3].router_mode == RouterMode.DENSE
# Alternation continues
assert schedule[4].attn == AttentionType.CSA
assert schedule[5].attn == AttentionType.HCA
# All layers are MoE
for spec in schedule:
assert spec.ffn == FFNType.MOE
def test_pro_schedule():
config = DSV4Config.pro()
schedule = build_schedule(config)
validate_schedule(schedule, config)
assert len(schedule) == 61
# First two layers: HCA + hash routing
assert schedule[0].attn == AttentionType.HCA
assert schedule[1].attn == AttentionType.HCA
assert schedule[0].router_mode == RouterMode.HASH
# Layer 2: CSA + hash routing
assert schedule[2].attn == AttentionType.CSA
assert schedule[2].router_mode == RouterMode.HASH
# Layer 3: HCA + dense routing
assert schedule[3].attn == AttentionType.HCA
assert schedule[3].router_mode == RouterMode.DENSE
def test_layer_spec_frozen():
"""LayerSpec is frozen — mutation should raise."""
config = DSV4Config.flash()
spec = build_schedule(config)[0]
try:
spec.attn = AttentionType.HCA
assert False, "should have raised"
except AttributeError:
pass
def test_schedule_indices_match():
"""Each LayerSpec.layer_idx matches its position in the list."""
config = DSV4Config.flash()
schedule = build_schedule(config)
for i, spec in enumerate(schedule):
assert spec.layer_idx == i
if __name__ == "__main__":
test_flash_schedule()
test_pro_schedule()
test_layer_spec_frozen()
test_schedule_indices_match()
print("All schedule tests passed")