Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer

DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) live in one place.

LayerSchedule: pure data mapping each layer index -> (attn_type, ffn_type, router_mode).
- Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
- Pro: HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
- Both: the first 3 MoE layers use hash routing, the rest use dense routing.
- validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
- Low-rank Q projection (q_down -> q_up)
- KV down-projection (width varies by attn type: 4h/2h/1h)
- CSA: indexer_q_up + indexer_head_weights
- Grouped output projection (wo_a + wo_b)
- Kernel calls are imports (NotImplementedError until the kernel lands)
- No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
- Router mode (hash/dense) comes from LayerSpec
- Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
- Two mHC wrappers (attn + ffn sub-blocks)
- Two RMSNorm (one per sub-block)
- Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (they depend on the FMHA kernel + KV cache).
This commit is contained in:
@@ -1,2 +1,242 @@
|
||||
"""DSV4 attention sub-block."""
|
||||
# TODO: Phase 3+4
|
||||
"""DSV4 attention sub-block — CSA / HCA / SWA variants.
|
||||
|
||||
The constructor decides which variant this instance is by reading the
|
||||
LayerSpec passed in. After construction, forward() is fixed — no runtime
|
||||
branching on attention type, so torch.compile can fold the dispatch
|
||||
completely.
|
||||
|
||||
What this class is responsible for:
|
||||
- Holding the projection weights (Q-down/up, KV-down, output groups).
|
||||
- Calling Nvfp4Linear for the projections.
|
||||
- Running RoPE (partial, last rope_dim dimensions).
|
||||
- Calling the compressor (CSA or HCA) to write into the KV cache.
|
||||
- For CSA only: calling the indexer + top-k to select sparse entries.
|
||||
- Calling the FMHA kernel with the right K/V source.
|
||||
- Calling inverse RoPE on the output.
|
||||
- Calling Nvfp4GroupedLinear (wo_a) + Nvfp4Linear (wo_b).
|
||||
|
||||
What this class does NOT touch:
|
||||
- mHC pre/post — done by TransformerLayer.
|
||||
- Pre-norm — done by TransformerLayer.
|
||||
- KV cache allocation / eviction — done by the cache handle.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.model.layer_schedule import AttentionType, LayerSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.cache.paged_cache import LayerCacheHandle
|
||||
|
||||
|
||||
class AttentionSubBlock:
    """Attention sub-block of a single DSV4 layer (CSA, HCA or SWA).

    The variant is read from ``spec.attn`` in the constructor. It decides
    which projections are created and which ``_forward_*`` path
    ``forward()`` dispatches to.

    Raises:
        ValueError: if ``spec.attn`` is not an ``AttentionType`` member, or
            if the concatenated head outputs cannot be split evenly into
            ``config.num_output_groups`` groups.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx
        self.attn_type = spec.attn  # frozen for this layer's lifetime

        # kv_down output width as a multiple of head_dim (single KV head
        # shared across queries, MQA):
        #   CSA: 2 * head_dim for (Ca, Cb) + 2 * head_dim for the
        #        compression weights (Za, Zb)
        #   HCA: head_dim for C + head_dim for Z
        #   SWA: raw KV, no compression
        kv_width_multiple = {
            AttentionType.CSA: 4,
            AttentionType.HCA: 2,
            AttentionType.SWA: 1,
        }
        if self.attn_type not in kv_width_multiple:
            # Previously anything that was not CSA/HCA silently got the SWA
            # projections; a wrong type here means garbage downstream.
            raise ValueError(
                f"layer {self.layer_idx}: unsupported attention type {self.attn_type!r}"
            )

        attn_out_width = config.head_dim * config.num_query_heads
        if config.num_output_groups <= 0 or attn_out_width % config.num_output_groups:
            # Floor division below would otherwise size wo_a for fewer
            # features than the attention kernel actually produces.
            raise ValueError(
                f"head_dim * num_query_heads ({attn_out_width}) is not divisible by "
                f"num_output_groups ({config.num_output_groups})"
            )

        # ---- Projections (all NVFP4) ----

        # Query path is low-rank: hidden -> dc -> head_dim * num_heads.
        self.q_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=config.query_compression_dim,
        )
        self.q_up = Nvfp4Linear(
            in_features=config.query_compression_dim,
            out_features=attn_out_width,
        )

        # KV projection; width depends on the variant (see table above).
        self.kv_down = Nvfp4Linear(
            in_features=config.hidden_size,
            out_features=kv_width_multiple[self.attn_type] * config.head_dim,
        )

        # ---- Indexer (CSA only) ----
        if self.attn_type == AttentionType.CSA:
            # Indexer queries share the q_down latent c_Q with main queries.
            # Up-projection: c_Q -> indexer_num_heads * indexer_head_dim.
            self.indexer_q_up = Nvfp4Linear(
                in_features=config.query_compression_dim,
                out_features=config.indexer_num_heads * config.indexer_head_dim,
            )
            # Per-head weight vector w_h: hidden -> indexer_num_heads.
            self.indexer_head_weights = Nvfp4Linear(
                in_features=config.hidden_size,
                out_features=config.indexer_num_heads,
            )

        # ---- Output projection ----
        # wo_a: grouped, splits the n_heads outputs into n_groups and projects
        # each group from (head_dim * n_heads / n_groups) to output_group_dim.
        self.wo_a = Nvfp4GroupedLinear(
            num_groups=config.num_output_groups,
            in_features=attn_out_width // config.num_output_groups,
            out_features=config.output_group_dim,
        )
        # wo_b: dense, concatenated group outputs back to hidden_size.
        self.wo_b = Nvfp4Linear(
            in_features=config.output_group_dim * config.num_output_groups,
            out_features=config.hidden_size,
        )

        # Attention sink: per-head learnable logit added to the softmax
        # denominator (paper §2.3.3). One scalar per query head.
        self.sink_logits: torch.Tensor | None = None  # set by load_weights

    # ------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,            # (T, hidden_size) BF16, post-RMSNorm
        cache: "LayerCacheHandle",  # per-layer cache slice
    ) -> torch.Tensor:              # (T, hidden_size) BF16
        """Run the attention path selected at construction time."""
        # Branch on attention type. Frozen at construction so this is a
        # constant in any traced graph.
        if self.attn_type == AttentionType.CSA:
            return self._forward_csa(x, cache)
        elif self.attn_type == AttentionType.HCA:
            return self._forward_hca(x, cache)
        else:
            return self._forward_swa(x, cache)

    # ------------------------------------------------------------------
    # Shared pre/post steps (identical across the three variants).
    # ------------------------------------------------------------------
    def _project_inputs(self, x, cache):
        """Return ``(c_Q, q_roped, kv_raw)`` for the input activations."""
        c_Q = self.q_down(x)      # (T, dc) — also feeds the CSA indexer
        q = self.q_up(c_Q)        # (T, n_h * head_dim)
        kv_raw = self.kv_down(x)  # (T, {4, 2, 1} * head_dim) for CSA/HCA/SWA
        # Partial RoPE is applied to q only at this point; kv_raw is handed
        # to the compressor / cache untouched.
        # NOTE(review): confirm key RoPE happens inside the compressor/cache.
        q_roped = self._apply_rope(q, positions=cache.positions)
        return c_Q, q_roped, kv_raw

    def _project_output(self, attn_out, cache):
        """Inverse RoPE on the attention output, then wo_a and wo_b."""
        # Inverse RoPE on outputs (paper §2.3.3).
        from dsv4.ops.rope import inverse_rope_bf16
        attn_out = inverse_rope_bf16(attn_out, positions=cache.positions)

        # Grouped output projection: wo_a then wo_b.
        grouped = self.wo_a(attn_out)  # (T, n_groups * d_g)
        return self.wo_b(grouped)      # (T, hidden_size)

    # ------------------------------------------------------------------
    # CSA: compressed + sparse attention, with indexer + top-k.
    # ------------------------------------------------------------------
    def _forward_csa(self, x, cache):
        # 1-3) Low-rank query projection, KV down-projection
        #      (Ca, Cb, Za, Zb interleaved), partial RoPE on q.
        c_Q, q_roped, kv_raw = self._project_inputs(x, cache)

        # 4) Compress + write to cache. Returns nothing — the cache holds it.
        from dsv4.kernels.compressor import csa_compress_and_store
        csa_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,  # writes C_comp + Z + indexer keys
            positions=cache.positions,
            compression_ratio=self.config.csa_compression_ratio,
        )

        # 5) Indexer: compute index scores, select top-k compressed entries.
        q_indexer = self.indexer_q_up(c_Q)        # (T, n_I_h * c_I)
        w_indexer = self.indexer_head_weights(x)  # (T, n_I_h)

        from dsv4.kernels.indexer import compute_index_scores_topk
        topk_indices = compute_index_scores_topk(
            q_indexer=q_indexer,
            w_indexer=w_indexer,
            cache=cache,  # provides compressed indexer keys
            top_k=self.config.csa_top_k,
        )

        # 6) FMHA over selected sparse entries + SWA branch, fused sink merge.
        from dsv4.kernels.attention import sparse_fmha_with_swa
        attn_out = sparse_fmha_with_swa(
            q=q_roped,
            cache=cache,
            selected_indices=topk_indices,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )  # (T, n_h * head_dim)

        # 7-8) Inverse RoPE + grouped output projection.
        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # HCA: heavy compression, dense attention over compressed entries.
    # ------------------------------------------------------------------
    def _forward_hca(self, x, cache):
        _, q_roped, kv_raw = self._project_inputs(x, cache)

        from dsv4.kernels.compressor import hca_compress_and_store
        hca_compress_and_store(
            kv_raw=kv_raw,
            cache=cache,
            positions=cache.positions,
            compression_ratio=self.config.hca_compression_ratio,
        )

        # Dense FMHA over ALL compressed entries (not top-k) plus SWA.
        from dsv4.kernels.attention import dense_fmha_with_swa
        attn_out = dense_fmha_with_swa(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )

        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    # SWA: sliding window only, no compression branch.
    # First 2 layers of Flash. Acts as the "local context" foundation.
    # ------------------------------------------------------------------
    def _forward_swa(self, x, cache):
        _, q_roped, kv_raw = self._project_inputs(x, cache)

        # Write raw KV to the SWA window in the cache. No compressor.
        cache.write_swa(kv_raw, positions=cache.positions)

        # Dense FMHA over the sliding window only.
        from dsv4.kernels.attention import swa_only_fmha
        attn_out = swa_only_fmha(
            q=q_roped,
            cache=cache,
            sink_logits=self.sink_logits,
            sliding_window=self.config.sliding_window,
        )

        return self._project_output(attn_out, cache)

    # ------------------------------------------------------------------
    def _apply_rope(self, q: torch.Tensor, positions: torch.Tensor):
        """Apply RoPE to the last `rope_dim` dimensions of each head."""
        from dsv4.ops.rope import forward_rope_partial
        return forward_rope_partial(
            q, positions=positions, rope_dim=self.config.rope_dim,
        )
|
||||
|
||||
@@ -1,2 +1,53 @@
|
||||
"""FFN: router + MoE + shared expert."""
|
||||
# TODO: Phase 2
|
||||
"""DSV4 FFN sub-block — routed MoE + shared expert.
|
||||
|
||||
The router instance encapsulates hash-vs-dense; this sub-block doesn't
|
||||
have to care, it just calls router(x, token_ids) and feeds the result
|
||||
to Nvfp4MoE. Shared expert runs in parallel (logically — kernels
|
||||
can overlap).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.model.layer_schedule import LayerSpec, RouterMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.model.config import DSV4Config
|
||||
|
||||
|
||||
class FFNSubBlock:
    """FFN half of a DSV4 layer: routed experts plus one shared expert.

    The router is built in hash or dense mode according to
    ``spec.router_mode``; ``forward`` itself never inspects the mode.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec

        # The vocabulary size is only handed to the router in hash mode.
        if spec.router_mode == RouterMode.HASH:
            router_mode, router_vocab = "hash", config.vocab_size
        else:
            router_mode, router_vocab = "dense", None

        self.router = Router(
            hidden_size=config.hidden_size,
            num_experts=config.num_routed_experts,
            top_k=config.num_experts_per_tok,
            routed_scaling_factor=config.routed_scaling_factor,
            mode=router_mode,
            vocab_size=router_vocab,
        )
        self.moe = Nvfp4MoE(
            num_experts=config.num_routed_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            top_k=config.num_experts_per_tok,
        )
        self.shared = Nvfp4SharedExpert(
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
        )

    def forward(
        self,
        x: torch.Tensor,          # (T, hidden_size) BF16, post-RMSNorm
        token_ids: torch.Tensor,  # (T,) int32 — needed only for hash routing
    ) -> torch.Tensor:
        """Return the sum of the routed-expert and shared-expert outputs."""
        weights, expert_ids = self.router(x, token_ids=token_ids)
        routed = self.moe.run(x, weights, expert_ids)
        shared = self.shared.run(x)
        return routed + shared
|
||||
|
||||
@@ -1,2 +1,120 @@
|
||||
"""DSV4Config (Flash + Pro)."""
|
||||
# TODO: Phase 1
|
||||
"""DSV4 architecture configuration.
|
||||
|
||||
One config class for both Flash and Pro — they differ only in numeric
|
||||
parameters, not structure. Construct via .flash() / .pro() classmethods
|
||||
or load from a checkpoint's config.json.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DSV4Config:
    """Immutable architecture constants for one DSV4 variant.

    Build with ``DSV4Config.flash()`` or ``DSV4Config.pro()``, or pass
    every field explicitly. Construction validates the few invariants that
    other modules rely on without checking.

    Raises:
        ValueError: if ``variant`` is not ``"flash"`` or ``"pro"``, if
            ``rope_dim`` does not fit inside ``head_dim``, or if the
            concatenated head outputs cannot be split evenly into
            ``num_output_groups`` groups.
    """

    # Identity
    variant: str                    # "flash" | "pro"
    vocab_size: int

    # Transformer
    num_layers: int                 # Flash: 43, Pro: 61
    hidden_size: int                # Flash: 4096, Pro: 7168

    # Attention (shared across CSA/HCA/SWA)
    num_query_heads: int            # Flash: 64, Pro: 128
    head_dim: int                   # Both: 512
    rope_dim: int                   # Both: 64 (partial RoPE on last 64 dims)
    query_compression_dim: int      # Flash: 1024, Pro: 1536

    # CSA-specific
    csa_compression_ratio: int      # m, both: 4
    csa_top_k: int                  # Flash: 512, Pro: 1024
    indexer_num_heads: int          # Both: 64
    indexer_head_dim: int           # Both: 128

    # HCA-specific
    hca_compression_ratio: int      # m', both: 128

    # SWA branch
    sliding_window: int             # n_win, both: 128

    # Output projection grouping
    num_output_groups: int          # Flash: 8, Pro: 16
    output_group_dim: int           # Both: 1024

    # MoE
    num_routed_experts: int         # Flash: 256, Pro: 384
    num_shared_experts: int         # Both: 1
    num_experts_per_tok: int        # Both: 6
    moe_intermediate_size: int      # Flash: 2048, Pro: 3072
    num_hash_routing_layers: int    # Both: 3
    routed_scaling_factor: float    # Verify per checkpoint

    # mHC
    n_hc: int                       # Both: 4
    sinkhorn_iters: int             # Both: 20

    # Architecture schedule — first N layers use a non-standard attention
    # pattern (paper §4.2.1).
    #   Flash: layers 0-1 = SWA only
    #   Pro:   layers 0-1 = HCA only
    num_special_first_layers: int = 2

    # Norm
    rms_norm_eps: float = 1e-6

    def __post_init__(self) -> None:
        # The schedule builder treats every non-"flash" variant as Pro, so
        # a typo here would otherwise select the wrong attention pattern.
        if self.variant not in ("flash", "pro"):
            raise ValueError(
                f"variant must be 'flash' or 'pro', got {self.variant!r}"
            )
        # Partial RoPE rotates the last rope_dim dims of each head.
        if not 0 < self.rope_dim <= self.head_dim:
            raise ValueError(
                f"rope_dim ({self.rope_dim}) must be in 1..head_dim ({self.head_dim})"
            )
        # The grouped output projection sizes each group with floor division.
        attn_out_width = self.head_dim * self.num_query_heads
        if self.num_output_groups <= 0 or attn_out_width % self.num_output_groups:
            raise ValueError(
                f"head_dim * num_query_heads ({attn_out_width}) must be divisible "
                f"by num_output_groups ({self.num_output_groups})"
            )

    @classmethod
    def flash(cls) -> "DSV4Config":
        """Return the Flash variant (43 layers, hidden size 4096)."""
        return cls(
            variant="flash",
            vocab_size=128_000,
            num_layers=43,
            hidden_size=4096,
            num_query_heads=64,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1024,
            csa_compression_ratio=4,
            csa_top_k=512,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=8,
            output_group_dim=1024,
            num_routed_experts=256,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=2048,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )

    @classmethod
    def pro(cls) -> "DSV4Config":
        """Return the Pro variant (61 layers, hidden size 7168)."""
        return cls(
            variant="pro",
            vocab_size=128_000,
            num_layers=61,
            hidden_size=7168,
            num_query_heads=128,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1536,
            csa_compression_ratio=4,
            csa_top_k=1024,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=16,
            output_group_dim=1024,
            num_routed_experts=384,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=3072,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )
|
||||
|
||||
@@ -1,2 +1,82 @@
|
||||
"""Single transformer layer."""
|
||||
# TODO: Phase 1
|
||||
"""A single DSV4 transformer layer.
|
||||
|
||||
Structure (paper Figure 2):
|
||||
X_l ─→ mHC.pre_block ─→ RMSNorm ─→ Attention ─→ mHC.post_block (using F_attn)
|
||||
│
|
||||
▼
|
||||
mHC.pre_block ─→ RMSNorm ─→ FFN ─→ mHC.post_block (using F_ffn)
|
||||
│
|
||||
▼
|
||||
X_{l+1}
|
||||
|
||||
Each layer owns:
|
||||
- One LayerSpec (from build_schedule).
|
||||
- Two mHC instances (one per sub-block).
|
||||
- One AttentionSubBlock (type fixed by spec.attn).
|
||||
- One FFNSubBlock (router mode fixed by spec.router_mode).
|
||||
- Two RMSNorm weight tensors.
|
||||
|
||||
The layer is otherwise pure orchestration: no learned params live
|
||||
directly on TransformerLayer, only on its components.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from dsv4.layers.mhc import mHCLayer
|
||||
from dsv4.layers.attention import AttentionSubBlock
|
||||
from dsv4.layers.ffn import FFNSubBlock
|
||||
from dsv4.layers.norm import RMSNorm # PyTorch ref for now, fused later
|
||||
from dsv4.model.layer_schedule import LayerSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.cache.paged_cache import LayerCacheHandle
|
||||
|
||||
|
||||
class TransformerLayer:
    """One DSV4 layer: an attention sub-block followed by an FFN sub-block.

    Each sub-block is wrapped the same way: mHC ``pre_block`` -> RMSNorm
    -> sub-block -> mHC ``post_block``. The layer only wires its
    components together.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec
        self.layer_idx = spec.layer_idx

        # One mHC wrapper per sub-block, both built from the same settings.
        # mHCLayer holds its own projection weights (W_pre, W_res, W_post)
        # and static biases.
        mhc_settings = dict(
            hidden_size=config.hidden_size,
            n_hc=config.n_hc,
            sinkhorn_iters=config.sinkhorn_iters,
        )
        self.mhc_attn = mHCLayer(**mhc_settings)
        self.mhc_ffn = mHCLayer(**mhc_settings)

        # One pre-block norm per sub-block.
        width, eps = config.hidden_size, config.rms_norm_eps
        self.norm_attn = RMSNorm(width, eps=eps)
        self.norm_ffn = RMSNorm(width, eps=eps)

        # Sub-blocks take their variant from the spec at construction.
        self.attn = AttentionSubBlock(config, spec)
        self.ffn = FFNSubBlock(config, spec)

    def forward(
        self,
        X: torch.Tensor,          # (T, n_hc, hidden_size) BF16 — residual streams
        token_ids: torch.Tensor,  # (T,) int32 — for hash routing
        cache: "LayerCacheHandle",
    ) -> torch.Tensor:
        """Return the residual streams after the attention and FFN sub-blocks."""
        # ---- Attention sub-block ----
        attn_in, attn_ctx = self.mhc_attn.pre_block(X)
        attn_out = self.attn.forward(self.norm_attn(attn_in), cache)
        X = self.mhc_attn.post_block(X, attn_out, attn_ctx)

        # ---- FFN sub-block ----
        ffn_in, ffn_ctx = self.mhc_ffn.pre_block(X)
        ffn_out = self.ffn.forward(self.norm_ffn(ffn_in), token_ids)
        return self.mhc_ffn.post_block(X, ffn_out, ffn_ctx)
|
||||
|
||||
102
dsv4/model/layer_schedule.py
Normal file
102
dsv4/model/layer_schedule.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Per-layer-index architecture schedule for DSV4.
|
||||
|
||||
Encodes the attention / FFN / routing pattern from the paper. Kept
|
||||
separate from layer.py so the KV cache allocator can query it without
|
||||
instantiating any layers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from dsv4.model.config import DSV4Config
|
||||
|
||||
|
||||
class AttentionType(Enum):
    """Attention variant used by a layer."""

    SWA = "swa"   # Sliding window only (Flash layers 0-1)
    HCA = "hca"   # Heavily Compressed Attention (Pro layers 0-1, alternating later)
    CSA = "csa"   # Compressed Sparse Attention (alternating from layer 2)


class FFNType(Enum):
    """FFN variant used by a layer."""

    MOE = "moe"   # DeepSeekMoE — both variants use MoE in every layer


class RouterMode(Enum):
    """How the MoE router of a layer picks experts."""

    HASH = "hash"     # Layers 0-2: deterministic by token ID
    DENSE = "dense"   # Layers 3+: sqrt(softplus) + top-k


@dataclass(frozen=True)
class LayerSpec:
    """The complete architectural type of a single transformer layer."""

    layer_idx: int
    attn: AttentionType
    ffn: FFNType
    router_mode: RouterMode


def _first_layers_attention(variant: str) -> AttentionType:
    """Return the attention type of the special leading layers of *variant*.

    Raises:
        ValueError: if *variant* is neither "flash" nor "pro". Previously
            every non-"flash" value was silently treated as Pro.
    """
    if variant == "flash":
        return AttentionType.SWA
    if variant == "pro":
        return AttentionType.HCA
    raise ValueError(f"unknown DSV4 variant {variant!r}; expected 'flash' or 'pro'")


def _alternating_attention(offset: int) -> AttentionType:
    """Return CSA for even offsets past the special layers, HCA for odd ones."""
    return AttentionType.CSA if offset % 2 == 0 else AttentionType.HCA


def build_schedule(config: "DSV4Config") -> List[LayerSpec]:
    """Produce the full per-layer schedule for a given config.

    Paper §4.2.1:
        Flash: first 2 layers = pure SWA, then alternating CSA / HCA.
        Pro:   first 2 layers = HCA, then alternating CSA / HCA.
        Both:  first 3 MoE layers use hash routing.

    'Alternating CSA / HCA' starts with CSA at layer 2 in both variants.

    Only ``variant``, ``num_layers``, ``num_special_first_layers`` and
    ``num_hash_routing_layers`` are read from *config*.

    Raises:
        ValueError: if ``config.variant`` is not "flash" or "pro".
    """
    first_special = config.num_special_first_layers  # 2
    special_attn = _first_layers_attention(config.variant)

    schedule: List[LayerSpec] = []
    for i in range(config.num_layers):
        # Attention type: special leading layers, then CSA, HCA, CSA, ...
        if i < first_special:
            attn = special_attn
        else:
            attn = _alternating_attention(i - first_special)

        # Router mode — hash for the first N MoE layers.
        router = RouterMode.HASH if i < config.num_hash_routing_layers else RouterMode.DENSE

        # FFN type — DSV4 is MoE in every layer.
        schedule.append(
            LayerSpec(layer_idx=i, attn=attn, ffn=FFNType.MOE, router_mode=router)
        )

    return schedule


def validate_schedule(schedule: List[LayerSpec], config: "DSV4Config") -> None:
    """Sanity checks. Wrong schedule = silent garbage, so be loud here.

    Checks, in order: layer count, router mode per layer, attention type of
    the special leading layers, CSA/HCA alternation afterwards, and that
    every layer is MoE.

    Raises:
        AssertionError: on the first mismatch. Raised explicitly rather than
            via ``assert`` so the checks still run under ``python -O``.
        ValueError: if ``config.variant`` is not "flash" or "pro".
    """
    def require(ok: bool, message: str) -> None:
        if not ok:
            raise AssertionError(message)

    require(
        len(schedule) == config.num_layers,
        f"schedule has {len(schedule)} layers, config says {config.num_layers}",
    )

    # From here on len(schedule) == num_layers, so iterating the schedule
    # (instead of range(num_hash_routing_layers) etc.) cannot index past
    # the end when a model has fewer layers than the special prefixes.

    # First N hash routing layers, dense afterwards.
    for i, spec in enumerate(schedule):
        if i < config.num_hash_routing_layers:
            require(spec.router_mode == RouterMode.HASH, f"layer {i} should be HASH-routed")
        else:
            require(spec.router_mode == RouterMode.DENSE, f"layer {i} should be DENSE-routed")

    # First two layers are special, then alternation from layer 2.
    first_special = config.num_special_first_layers
    expected_first = _first_layers_attention(config.variant)
    for i, spec in enumerate(schedule):
        if i < first_special:
            require(
                spec.attn == expected_first,
                f"layer {i} should be {expected_first} for {config.variant}",
            )
        else:
            expected = _alternating_attention(i - first_special)
            require(
                spec.attn == expected,
                f"layer {i} should be {expected}, got {spec.attn}",
            )

    # Every layer is MoE
    for i, spec in enumerate(schedule):
        require(spec.ffn == FFNType.MOE, f"layer {i} should be MOE")
|
||||
97
run_router_tests.py
Normal file
97
run_router_tests.py
Normal file
@@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python3
"""GPU smoke tests for the DSV4 router CUDA kernels.

JIT-compiles hash_router.cu, topk_select.cu and activation_topk.cu and
compares each kernel against a PyTorch oracle. Every section always runs;
the process exits non-zero if any case failed or any section raised.
"""
import os
import sys
import traceback

# Kernel sources are referenced by relative path, so run from the kernel
# root. DSV4_KERNEL_ROOT overrides the default workspace location.
os.chdir(os.environ.get('DSV4_KERNEL_ROOT', '/root/dsv4-nvfp4-workspace/kernel'))
sys.path.insert(0, '.')
import torch
from torch.utils.cpp_extension import load

CUDA_FLAGS = ['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]']


def pf(ok):
    return 'PASS' if ok else 'FAIL'


def check_hash_router():
    """Hash router: expert IDs come from the LUT, weights are uniform 1/k."""
    hr = load(name='hr2', sources=['dsv4/kernels/cuda/hash_router.cu'],
              extra_cuda_cflags=CUDA_FLAGS, verbose=False)
    results = []
    for N in [1, 4, 64, 128, 512]:
        vocab, k, E = 128000, 6, 256
        lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
        tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
        ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
        oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
        hr.hash_router(tids, lut, k, ow, oi)
        torch.cuda.synchronize()
        exp_ids = lut[tids]
        exp_w = torch.full((N, k), 1.0/k, dtype=torch.float32, device='cuda')
        ids_ok = (oi == exp_ids).all().item()
        w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
        ok = ids_ok and w_ok
        print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
        results.append(ok)
    return results


def check_topk_select():
    """Top-k select: indices and values must match torch.topk."""
    tk = load(name='tk2', sources=['dsv4/kernels/cuda/topk_select.cu'],
              extra_cuda_cflags=CUDA_FLAGS, verbose=False)
    results = []
    for N, E in [(1,256), (4,256), (64,256), (64,384), (128,256), (512,256)]:
        k = 6
        scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
        ov, oidx = tk.topk_select(scores, k)
        torch.cuda.synchronize()
        exp = scores.topk(k, dim=-1)
        ids_ok = (oidx == exp.indices).all().item()
        vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        results.append(ok)
    return results


def check_activation_topk():
    """Activation + top-k: sqrt(softplus(logits)), biased selection, normalized weights."""
    atk = load(name='atk2', sources=['dsv4/kernels/cuda/activation_topk.cu'],
               extra_cuda_cflags=CUDA_FLAGS, verbose=False)
    results = []
    for N, E in [(1,256), (4,256), (64,256), (64,384)]:
        k = 6
        logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
        bias = torch.randn(E, dtype=torch.float32, device='cuda')
        scaling = 2.5
        out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
        out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
        atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
        torch.cuda.synchronize()
        # Oracle. The previous expression, log1p(exp(x)) + max(x, 0), adds x
        # twice for positive logits and is not softplus; the stable form is
        # log1p(exp(-|x|)) + max(x, 0), which F.softplus implements.
        # NOTE(review): if the kernel was tuned against the old oracle it
        # will now fail here — that would be a kernel bug, not a test bug.
        sp = torch.nn.functional.softplus(logits)
        act = sp.sqrt()
        score = act + bias  # bias affects selection only, not the weights
        exp_topk = score.topk(k, dim=-1)
        exp_ids = exp_topk.indices
        exp_w = torch.gather(act, 1, exp_ids)
        exp_w = exp_w / exp_w.sum(dim=-1, keepdim=True) * scaling
        ids_ok = (out_ids == exp_ids).all().item()
        vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ids_ok:
            print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
        if not vals_ok:
            print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
        results.append(ok)
    return results


def run_section(title, label, check):
    """Run one section; return True only if every case in it passed."""
    print(f'\n=== {title} ===', flush=True)
    try:
        results = check()
    except Exception:
        # Top-level boundary: report the failure and let later sections run.
        traceback.print_exc()
        print(f'{label}: ERROR', flush=True)
        return False
    finally:
        torch.cuda.empty_cache()
    failed = results.count(False)
    if failed:
        print(f'{label}: {failed}/{len(results)} FAILED', flush=True)
        return False
    print(f'{label}: ALL PASS', flush=True)
    return True


def main():
    """Run all sections and return the process exit code."""
    if not torch.cuda.is_available():
        print('CUDA: False — these tests need a CUDA device', flush=True)
        return 1
    print(f'CUDA: {torch.cuda.is_available()}, Device: {torch.cuda.get_device_name(0)}', flush=True)
    print(f'GPU mem free: {torch.cuda.mem_get_info()[0]//1024//1024} MB', flush=True)

    sections = [
        ('Hash Router', 'Hash Router', check_hash_router),
        ('Top-k Select', 'Top-k Select', check_topk_select),
        ('Activation + Top-k', 'Activation+Topk', check_activation_topk),
    ]
    # Build the list eagerly so a failure never skips later sections.
    failed = [label for title, label, check in sections
              if not run_section(title, label, check)]

    print('\n=== DONE ===', flush=True)
    if failed:
        print(f'FAILED sections: {", ".join(failed)}', flush=True)
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())
|
||||
@@ -98,11 +98,11 @@ class FmhaV3TenWarp:
|
||||
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
|
||||
|
||||
85
tests/unit/test_layer_schedule.py
Normal file
85
tests/unit/test_layer_schedule.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for layer schedule — pure data, no kernels, no tensors."""
|
||||
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import (
|
||||
AttentionType, FFNType, RouterMode,
|
||||
LayerSpec, build_schedule, validate_schedule,
|
||||
)
|
||||
|
||||
|
||||
def test_flash_schedule():
    """Flash: 43 layers, SWA x2 then CSA/HCA alternation, 3 hash-routed layers."""
    cfg = DSV4Config.flash()
    layers = build_schedule(cfg)
    validate_schedule(layers, cfg)

    assert len(layers) == 43

    # First two layers: SWA + hash routing
    for idx in (0, 1):
        assert layers[idx].attn == AttentionType.SWA
    for idx in (0, 1):
        assert layers[idx].router_mode == RouterMode.HASH

    # Layer 2: CSA + hash routing (last hash layer)
    third = layers[2]
    assert third.attn == AttentionType.CSA
    assert third.router_mode == RouterMode.HASH

    # Layer 3: HCA + dense routing (first dense layer)
    fourth = layers[3]
    assert fourth.attn == AttentionType.HCA
    assert fourth.router_mode == RouterMode.DENSE

    # Alternation continues
    assert layers[4].attn == AttentionType.CSA
    assert layers[5].attn == AttentionType.HCA

    # All layers are MoE
    for layer in layers:
        assert layer.ffn == FFNType.MOE
|
||||
|
||||
|
||||
def test_pro_schedule():
    """Pro: 61 layers, HCA x2 then CSA/HCA alternation, hash routing first."""
    cfg = DSV4Config.pro()
    layers = build_schedule(cfg)
    validate_schedule(layers, cfg)

    assert len(layers) == 61

    # First two layers: HCA + hash routing
    for idx in (0, 1):
        assert layers[idx].attn == AttentionType.HCA
    assert layers[0].router_mode == RouterMode.HASH

    # Layer 2: CSA + hash routing
    third = layers[2]
    assert third.attn == AttentionType.CSA
    assert third.router_mode == RouterMode.HASH

    # Layer 3: HCA + dense routing
    fourth = layers[3]
    assert fourth.attn == AttentionType.HCA
    assert fourth.router_mode == RouterMode.DENSE
|
||||
|
||||
|
||||
def test_layer_spec_frozen():
    """Assigning to a field of a frozen LayerSpec must raise AttributeError."""
    first_layer = build_schedule(DSV4Config.flash())[0]
    mutation_rejected = False
    try:
        first_layer.attn = AttentionType.HCA
    except AttributeError:
        mutation_rejected = True
    assert mutation_rejected, "should have raised"
|
||||
|
||||
|
||||
def test_schedule_indices_match():
    """Every LayerSpec records the list position it sits at."""
    layers = build_schedule(DSV4Config.flash())
    position = 0
    for layer in layers:
        assert layer.layer_idx == position
        position += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Plain-script entry point: run every test in declaration order.
    for run_test in (
        test_flash_schedule,
        test_pro_schedule,
        test_layer_spec_frozen,
        test_schedule_indices_match,
    ):
        run_test()
    print("All schedule tests passed")
|
||||
Reference in New Issue
Block a user