|
|
4453d7475a
|
Fix layer construction: match existing API signatures, add RMSNorm impl
- Nvfp4GroupedLinear: (n_local_groups, heads_per_group, head_dim, o_lora_rank)
- mHCLayer: hidden_dim, t_max_sinkhorn (not hidden_size, sinkhorn_iters)
- RMSNorm: PyTorch reference implementation (BF16, cudagraph-safe)
- Verified: all 43 Flash + 61 Pro layers construct cleanly
- All projection shapes validated against architecture spec
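
For reference, the RMSNorm math can be sketched in plain Python. This is not the PyTorch/BF16 implementation this commit adds. It is a minimal stdlib-only illustration of the standard formula; the function name `rms_norm` and the default `eps` value are assumptions made for the example.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Scale x by the reciprocal of its root-mean-square, then by a per-element weight.

    Hypothetical helper: the name, signature, and eps default are assumptions,
    not the API added in this commit.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    if len(x) == 0:
        raise ValueError("x must be non-empty")
    # Mean of squares over the normalized (last) dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # eps keeps the division stable when x is all zeros.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


# With unit weights and eps=0, the output has a mean square of 1.
normalized = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
```

Unlike LayerNorm, no mean is subtracted and no bias is added; only the scale is learned.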
|
2026-05-21 23:31:58 +00:00 |
|
|
|
66a89859ed
|
Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.
LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
Pro: HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
Both: first 3 MoE layers = hash routing, rest = dense
validate_schedule() enforces correctness at construction.
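
The schedule described above can be sketched as pure data. This is a minimal illustration, not the code in this commit: the field names on `LayerSpec`, the `build_schedule` helper, the signature of `validate_schedule`, and the assumption that every layer is an MoE layer are all hypothetical. Only the layer counts, the attention patterns, and the "first 3 layers use hash routing" rule come from the message.

```python
from dataclasses import dataclass
from typing import List

ATTN_TYPES = {"SWA", "CSA", "HCA"}
N_HASH_LAYERS = 3  # from the message: first 3 MoE layers use hash routing


@dataclass(frozen=True)
class LayerSpec:
    attn_type: str    # "SWA", "CSA", or "HCA"
    ffn_type: str     # assumed "moe" for every layer
    router_mode: str  # "hash" or "dense"


def build_schedule(variant: str) -> List[LayerSpec]:
    """Hypothetical builder: two prefix layers, then CSA/HCA alternating."""
    if variant == "flash":
        n_layers, prefix = 43, ["SWA", "SWA"]
    elif variant == "pro":
        n_layers, prefix = 61, ["HCA", "HCA"]
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    specs = []
    for i in range(n_layers):
        if i < len(prefix):
            attn = prefix[i]
        else:
            attn = "CSA" if (i - len(prefix)) % 2 == 0 else "HCA"
        router = "hash" if i < N_HASH_LAYERS else "dense"
        specs.append(LayerSpec(attn, "moe", router))
    return specs


def validate_schedule(specs: List[LayerSpec], expected_layers: int) -> None:
    """Raise ValueError if the schedule breaks the rules stated above."""
    if len(specs) != expected_layers:
        raise ValueError(f"expected {expected_layers} layers, got {len(specs)}")
    for i, spec in enumerate(specs):
        if spec.attn_type not in ATTN_TYPES:
            raise ValueError(f"layer {i}: bad attn_type {spec.attn_type!r}")
        expected_router = "hash" if i < N_HASH_LAYERS else "dense"
        if spec.router_mode != expected_router:
            raise ValueError(f"layer {i}: expected {expected_router} routing")


flash = build_schedule("flash")
pro = build_schedule("pro")
validate_schedule(flash, 43)
validate_schedule(pro, 61)
```

Keeping the schedule as frozen data means it can be validated at construction time without building any layers.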
AttentionSubBlock: CSA / HCA / SWA variants.
- Low-rank Q projection (q_down -> q_up)
- KV down-projection (varies by attn type: 4h/2h/1h)
- CSA: indexer_q_up + indexer_head_weights
- Grouped output projection (wo_a + wo_b)
- Kernel calls are imported stubs (raise NotImplementedError until the kernels land)
- No PyTorch fallback paths
FFNSubBlock: MoE + shared expert.
- Router (hash/dense) mode from LayerSpec
- Nvfp4MoE + Nvfp4SharedExpert
TransformerLayer: composition of mHC + norm + attention + FFN.
- Two mHC wrappers (attn + ffn sub-blocks)
- Two RMSNorm (one per sub-block)
- Pure orchestration, no learned params on the layer itself
Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
|
2026-05-21 23:11:09 +00:00 |
|
|
|
3fb3c925af
|
Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: package name set to dsv4-inference
|
2026-05-21 17:30:44 +00:00 |
|