From 44582ec43bc8a7b1bfc41113077ad9ea0a1fa402 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 23:31:58 +0000 Subject: [PATCH] Fix layer construction: match existing API signatures, add RMSNorm impl - Nvfp4GroupedLinear: (n_local_groups, heads_per_group, head_dim, o_lora_rank) - mHCLayer: hidden_dim, t_max_sinkhorn (not hidden_size, sinkhorn_iters) - RMSNorm: PyTorch reference implementation (BF16, cudagraph-safe) - Verified: all 43 Flash + 61 Pro layers construct cleanly - All projection shapes validated against architecture spec --- dsv4/layers/attention.py | 11 +++++++---- dsv4/layers/norm.py | 32 ++++++++++++++++++++++++++++++-- dsv4/model/layer.py | 8 ++++---- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/dsv4/layers/attention.py b/dsv4/layers/attention.py index f2f42c1b..9d126df2 100644 --- a/dsv4/layers/attention.py +++ b/dsv4/layers/attention.py @@ -85,15 +85,18 @@ class AttentionSubBlock: self.indexer_head_weights = Nvfp4Linear( in_features=config.hidden_size, out_features=config.indexer_num_heads, + max_num_tokens=1, # scalar per head, not per-token projection ) # ---- Output projection ---- # wo_a: grouped, splits the n_heads outputs into n_groups and projects - # each group from (head_dim * n_heads / n_groups) to output_group_dim. + # each group from (heads_per_group * head_dim) to o_lora_rank. + heads_per_group = config.num_query_heads // config.num_output_groups self.wo_a = Nvfp4GroupedLinear( - num_groups=config.num_output_groups, - in_features=(config.head_dim * config.num_query_heads) // config.num_output_groups, - out_features=config.output_group_dim, + n_local_groups=config.num_output_groups, + heads_per_group=heads_per_group, + head_dim=config.head_dim, + o_lora_rank=config.output_group_dim, ) # wo_b: dense, concatenated group outputs back to hidden_size. self.wo_b = Nvfp4Linear( diff --git a/dsv4/layers/norm.py b/dsv4/layers/norm.py index 04c95489..c75e4b46 100644 --- a/dsv4/layers/norm.py +++ b/dsv4/layers/norm.py @@ -1,2 +1,30 @@ -"""RMSNorm placeholder.""" -# TODO: Implement RMSNorm +"""RMSNorm — PyTorch reference implementation. + +Swap to fused kernel (CuTeDSL) in Phase 6. API won't change. +""" +import torch + + +class RMSNorm: + """Root Mean Square Layer Normalization. + + y = x / sqrt(mean(x^2) + eps) * weight + + CUDA-graph-compatible: weight is a buffer, no CPU syncs. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6, device: str = "cuda"): + self.hidden_size = hidden_size + self.eps = eps + self.device = device + self.weight: torch.Tensor | None = None # (hidden_size,) FP32, set by load_weights + + def load_weights(self, weight: torch.Tensor) -> None: + assert weight.shape == (self.hidden_size,), f"weight shape {weight.shape} != ({self.hidden_size},)" + self.weight = weight.to(device=self.device, dtype=torch.float32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """x: (T, hidden_size) BF16 -> (T, hidden_size) BF16""" + x_f = x.float() + rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt() + return (x_f * rms * self.weight).to(torch.bfloat16) diff --git a/dsv4/model/layer.py b/dsv4/model/layer.py index 65f338f9..388e5eb7 100644 --- a/dsv4/model/layer.py +++ b/dsv4/model/layer.py @@ -43,14 +43,14 @@ class TransformerLayer: # Two mHC wrappers — one per sub-block. mHCLayer holds its own # projection weights (W_pre, W_res, W_post) and static biases. self.mhc_attn = mHCLayer( - hidden_size=config.hidden_size, + hidden_dim=config.hidden_size, n_hc=config.n_hc, - sinkhorn_iters=config.sinkhorn_iters, + t_max_sinkhorn=config.sinkhorn_iters, ) self.mhc_ffn = mHCLayer( - hidden_size=config.hidden_size, + hidden_dim=config.hidden_size, n_hc=config.n_hc, - sinkhorn_iters=config.sinkhorn_iters, + t_max_sinkhorn=config.sinkhorn_iters, ) # Pre-block norms (one per sub-block).