- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
- model/{dsv4,mtp,layer,layer_schedule}
- layers/{embedding,attention,ffn,norm} (kept linear, mhc, router, moe, shared_expert, grouped_linear — still live)
- cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
- kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
- ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
- reference/{attention,compressor,csa_attention,moe_pipeline}
- kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead code; it referenced a nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
- test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- Confirmed in Step 0 that vLLM has no imports of dsv4
31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
"""RMSNorm — PyTorch reference implementation.

To be replaced by a fused CuTeDSL kernel in Phase 6; the public API will not change.
"""
|
|
import torch
|
|
|
|
|
|
class RMSNorm:
    """Root Mean Square Layer Normalization.

        y = x / sqrt(mean(x^2) + eps) * weight

    The statistic is computed in FP32 over the last dimension and the result
    is cast to BF16.

    CUDA-graph-compatible: weight is a buffer, no CPU syncs.

    Args:
        hidden_size: Size of the normalized (last) dimension; the weight must
            have shape (hidden_size,).
        eps: Added to mean(x^2) before the reciprocal square root.
        device: Device that ``load_weights`` moves the FP32 weight to.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, device: str = "cuda"):
        self.hidden_size = hidden_size
        self.eps = eps
        self.device = device
        # (hidden_size,) FP32 scale; stays None until load_weights() is called.
        self.weight: torch.Tensor | None = None

    def load_weights(self, weight: torch.Tensor) -> None:
        """Validate and install the per-channel scale.

        Args:
            weight: Tensor of shape (hidden_size,), of any dtype and on any
                device. It is converted to FP32 on ``self.device``.

        Raises:
            ValueError: If ``weight`` does not have shape (hidden_size,).
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if weight.shape != (self.hidden_size,):
            raise ValueError(f"weight shape {weight.shape} != ({self.hidden_size},)")
        # NOTE(review): this rebinds self.weight to the tensor returned by .to().
        # A CUDA graph captured before a later reload would keep reading the old
        # tensor. Confirm that weights are only loaded before capture.
        self.weight = weight.to(device=self.device, dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension.

        x: (T, hidden_size) BF16 -> (T, hidden_size) BF16

        The reduction is over the last dimension only, so extra leading
        dimensions broadcast against the (hidden_size,) weight.

        Raises:
            RuntimeError: If called before ``load_weights``.
        """
        weight = self.weight
        # Python-level check only, so no device sync is introduced.
        if weight is None:
            raise RuntimeError("RMSNorm.forward() called before load_weights()")
        # Upcast so the mean of squares is accumulated in FP32.
        x_f = x.float()
        rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x_f * rms * weight).to(torch.bfloat16)