Files
nvfp4-megamoe-kernel/dsv4/ops/layouts.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

124 lines
4.3 KiB
Python

"""Tensor layout helpers: scale swizzle, gate/up interleave, K-major, offsets."""
import itertools

import torch
from dsv4.kernels.gemm.grouped import (
    pad_and_swizzle_single,
    assemble_raw_scales_2d3d_2d_side,
    assemble_raw_scales_2d3d_3d_side,
)
def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed with floor division.

    Yields the true ceiling for an integer ``a`` and a positive integer ``b``.
    """
    numerator = a + b - 1
    return numerator // b
def round_up(a, b):
    """Round ``a`` up to the nearest multiple of ``b`` (via ``ceil_div``)."""
    num_blocks = ceil_div(a, b)
    return num_blocks * b
def interleave_l1_weights(w_ekn, granularity_bf16=8):
    """Interleave gate/up weights along N in groups of ``granularity_bf16`` columns.

    The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
    MMA accumulator. With interleaved weights, the MMA tile produces
    gate[i*g..i*g+g-1] and up[i*g..i*g+g-1] next to each other in registers,
    enabling a single-register SwiGLU without SMEM round-trips.

    Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
    After:  [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]  (g = 8)

    The interleave operates along the N dimension, where each column is one
    output (BF16) column. FP4 packing is along K, not N, so the group size
    along N is ``granularity_bf16`` directly, with no halving for FP4 packing.

    Args:
        w_ekn: (E, K_packed, N) weight tensor in which the first N/2 columns
            are the gate projection and the last N/2 columns are the up
            projection (gate+up fused).
        granularity_bf16: interleave group size in N columns (default 8).

    Returns:
        (E, K_packed, N) tensor with gate/up column groups interleaved.

    Raises:
        ValueError: if ``granularity_bf16`` is not positive, or N is not a
            multiple of ``2 * granularity_bf16``.
    """
    E, K, N = w_ekn.shape
    g = granularity_bf16
    if g <= 0:
        raise ValueError(f"granularity_bf16 must be positive, got {g}")
    # Both halves must split evenly into groups of g columns.
    if N % (2 * g) != 0:
        raise ValueError(
            f"N={N} must be a multiple of 2 * granularity_bf16 = {2 * g}"
        )
    N_half = N // 2  # gate and up each own N/2 columns
    gate = w_ekn[:, :, :N_half].reshape(E, K, N_half // g, g)
    up = w_ekn[:, :, N_half:].reshape(E, K, N_half // g, g)
    # stack -> (E, K, N_half // g, 2, g): each group of g gate columns is
    # immediately followed by the matching group of g up columns.
    return torch.stack([gate, up], dim=3).reshape(E, K, N)
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
    """De-interleave gate/up weights (inverse of interleave_l1_weights).

    Used for testing/verification only. Splits the N axis into alternating
    groups of ``granularity_bf16`` gate columns and ``granularity_bf16`` up
    columns, and concatenates all gate columns followed by all up columns.

    Args:
        w_ekn: (E, K_packed, N) tensor with interleaved gate/up column groups.
        granularity_bf16: interleave group size in N columns (default 8).

    Returns:
        (E, K_packed, N) tensor laid out as [gate (N/2 cols) | up (N/2 cols)].

    Raises:
        ValueError: if ``granularity_bf16`` is not positive, or N is not a
            multiple of ``2 * granularity_bf16``.
    """
    g = granularity_bf16  # N-axis: each N-col = 1 BF16 col
    E, K, N = w_ekn.shape
    if g <= 0:
        raise ValueError(f"granularity_bf16 must be positive, got {g}")
    if N % (2 * g) != 0:
        raise ValueError(
            f"N={N} must be a multiple of 2 * granularity_bf16 = {2 * g}"
        )
    # (E, K, groups, 2, g): index 0 on the pair axis is gate, 1 is up.
    w_reshaped = w_ekn.reshape(E, K, N // (2 * g), 2, g)
    gate = w_reshaped[:, :, :, 0, :].reshape(E, K, N // 2)
    up = w_reshaped[:, :, :, 1, :].reshape(E, K, N // 2)
    return torch.cat([gate, up], dim=2)
def assemble_scales_2d_side(raw_scales):
    """Build the activation-side scale tensor for the 2Dx3D grouped GEMM.

    Thin wrapper that forwards the per-expert scales unchanged to
    ``assemble_raw_scales_2d3d_2d_side``.

    Args:
        raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert

    Returns:
        The assembled, swizzled scale tensor produced by the kernel helper
    """
    swizzled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return swizzled
def assemble_scales_3d_side(raw_scales):
    """Build the weight-side scale tensor for the 2Dx3D grouped GEMM.

    Each per-expert scale matrix is transposed from (K_sf, N) to (N, K_sf)
    and made contiguous before being handed to
    ``assemble_raw_scales_2d3d_3d_side``, because the kernel treats N as the
    non-K dimension.

    Args:
        raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert

    Returns:
        The assembled, swizzled scale tensor produced by the kernel helper
    """
    transposed = []
    for scale in raw_scales:
        # (K_sf, N) -> (N, K_sf), densified so the helper sees a plain layout
        transposed.append(scale.T.contiguous())
    return assemble_raw_scales_2d3d_3d_side(transposed)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
    """Convert a B tensor from N-major to K-major layout (same shape and values).

    For an (E, K, N) tensor, the kernel expects strides (K*N, 1, K), i.e. K is
    the contiguous dimension. A contiguous tensor (e.g. the output of
    torch.stack) has strides (K*N, N, 1), i.e. N is contiguous.

    The tensor is viewed as (E, N, K), made contiguous, then viewed back as
    (E, K, N). ``contiguous()`` copies only when needed: an input that is
    already K-major is returned as a view without copying.

    Args:
        b_tensor: (experts, K_packed, N_packed) tensor, N-major
            (float4_e2m1fn_x2 in the kernel path)

    Returns:
        Same shape and values, with K-major strides (K*N, 1, K)
    """
    ekn_as_enk = b_tensor.permute(0, 2, 1)  # (E, N, K) view
    return ekn_as_enk.contiguous().permute(0, 2, 1)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative (inclusive) token offsets for the grouped GEMM.

    ``offs[e] == sum(tokens_per_expert[:e + 1])``, so ``offs[-1]`` is the total
    token count. If ``tokens_per_expert`` has fewer than ``num_experts``
    entries, the trailing offsets repeat the running total (those experts
    get zero tokens). Entries beyond ``num_experts`` are ignored.

    Args:
        tokens_per_expert: list of int, one per expert
        num_experts: number of offsets to produce (non-positive -> empty)
        device: device for the returned tensor (default "cuda")

    Returns:
        offs: (num_experts,) int32 — inclusive cumulative sum
    """
    count = max(num_experts, 0)
    # One O(E) running sum instead of re-summing every prefix (O(E^2)).
    offsets = list(itertools.accumulate(tokens_per_expert[:count]))
    total = offsets[-1] if offsets else 0
    offsets.extend([total] * (count - len(offsets)))
    return torch.tensor(offsets, dtype=torch.int32, device=device)
# ── Kernel Launch ─────────────────────────────────────────────────────