- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4/layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
124 lines
4.3 KiB
Python
124 lines
4.3 KiB
Python
"""Tensor layout helpers: scale swizzle, gate/up interleave, K-major, offsets."""
|
|
import torch
|
|
|
|
from dsv4.kernels.gemm.grouped import (
|
|
pad_and_swizzle_single,
|
|
assemble_raw_scales_2d3d_2d_side,
|
|
assemble_raw_scales_2d3d_3d_side,
|
|
)
|
|
|
|
def ceil_div(a, b):
    """Return ``a / b`` rounded up, using only floor division.

    The numerator is bumped by ``b - 1`` before the floor division so that
    any non-zero remainder pushes the quotient up by one.
    """
    bumped = a + b - 1
    return bumped // b
|
|
|
|
|
|
def round_up(a, b):
    """Return the smallest multiple of ``b`` that is not less than ``a``.

    Computed as the ceiling quotient of ``a`` by ``b`` scaled back by ``b``.
    """
    multiples = ceil_div(a, b)
    return multiples * b
|
|
|
|
def interleave_l1_weights(w_ekn, granularity_bf16=8):
    """Interleave the gate and up halves of a fused weight along the last axis.

    The last axis of ``w_ekn`` is treated as ``[gate | up]``: the first half
    holds the gate columns and the second half the up columns. Each half is
    cut into groups of ``granularity_bf16`` columns and the groups are then
    alternated, gate first.

    Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
    After:  [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]
            (shown for the default granularity of 8)

    Per the original design notes, this ordering exists so the fused SwiGLU
    epilogue finds matching gate/up values adjacent in the MMA accumulator.
    The group size is applied to the last axis as-is; no packing factor is
    divided out here (the original notes state that FP4 packing is along K,
    not N).

    Args:
        w_ekn: 3-D tensor of shape (E, K, N). Each half of the last axis
            (N // 2 columns) must be divisible by ``granularity_bf16``,
            otherwise the reshape fails.
        granularity_bf16: number of consecutive columns per interleave
            group (default 8).

    Returns:
        A new tensor of shape (E, K, N) with gate/up groups alternating.
    """
    num_experts, k_dim, n_dim = w_ekn.shape
    half = n_dim // 2
    group = granularity_bf16

    # Both halves are viewed as (E, K, groups, group) before being paired.
    grouped_shape = (num_experts, k_dim, half // group, group)
    gate_groups = w_ekn[..., :half].reshape(grouped_shape)
    up_groups = w_ekn[..., half:].reshape(grouped_shape)

    # Stacking on a new axis 3 gives (E, K, groups, 2, group); flattening
    # the trailing three axes places each up group right after its gate group.
    paired = torch.stack((gate_groups, up_groups), dim=3)
    return paired.reshape(num_experts, k_dim, n_dim)
|
|
|
|
|
|
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
    """Undo the gate/up interleave and restore the ``[gate | up]`` layout.

    The last axis is read as alternating groups of ``granularity_bf16``
    columns (gate group, up group, gate group, ...). All gate groups are
    gathered into the first half of the output and all up groups into the
    second half. The original notes mark this as a testing/verification aid.

    Args:
        w_ekn: 3-D tensor of shape (E, K, N) with interleaved groups; N must
            be divisible by ``2 * granularity_bf16`` for the reshape to work.
        granularity_bf16: number of consecutive columns per group (default 8).

    Returns:
        A new tensor of shape (E, K, N) laid out as gate half then up half.
    """
    num_experts, k_dim, n_dim = w_ekn.shape
    group = granularity_bf16
    half = n_dim // 2

    # (E, K, pairs, 2, group): index 0 on axis 3 is gate, index 1 is up.
    pairs = w_ekn.reshape(num_experts, k_dim, n_dim // (2 * group), 2, group)
    gate_half = pairs.select(3, 0).reshape(num_experts, k_dim, half)
    up_half = pairs.select(3, 1).reshape(num_experts, k_dim, half)
    return torch.cat((gate_half, up_half), dim=2)
|
|
|
|
|
|
def assemble_scales_2d_side(raw_scales):
    """Build the activation-side (2D operand) scale tensor for a 2Dx3D GEMM.

    This is a thin pass-through to ``assemble_raw_scales_2d3d_2d_side`` from
    ``dsv4.kernels.gemm.grouped``; the input list is forwarded unchanged.

    Args:
        raw_scales: per-expert scale tensors, one list entry per expert.
            Expected to be (M_e, K_sf) float8_e4m3fn tensors according to
            the original documentation.

    Returns:
        Whatever the underlying helper returns (documented upstream as the
        assembled and swizzled scale tensor).
    """
    assembled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return assembled
|
|
|
|
|
|
def assemble_scales_3d_side(raw_scales):
    """Build the weight-side (3D operand) scale tensor for a 2Dx3D GEMM.

    Each per-expert scale tensor is transposed and made contiguous before
    the list is passed to ``assemble_raw_scales_2d3d_3d_side``.

    Args:
        raw_scales: per-expert scale tensors, one list entry per expert.
            Expected to be (K_sf, N) float8_e4m3fn tensors according to the
            original documentation; they are flipped to (N, K_sf) here
            because the kernel is documented as wanting N as the non-K
            dimension.

    Returns:
        Whatever the underlying helper returns (documented upstream as the
        assembled and swizzled scale tensor).
    """
    # Materialize the transposed layout; .T alone would only be a strided view.
    n_major_scales = [scale.T.contiguous() for scale in raw_scales]
    return assemble_raw_scales_2d3d_3d_side(n_major_scales)
|
|
|
|
|
|
# ── Tensor Layout Conversion ──────────────────────────────────────────
|
|
|
|
def make_b_k_major(b_tensor):
    """Return ``b_tensor`` with the same shape but K-contiguous strides.

    For an (E, K, N) input, a plain contiguous tensor has strides
    (K*N, N, 1), i.e. N is the fastest-varying axis. The result of this
    function has strides (K*N, 1, K), so that K is the fastest-varying axis
    while indexing and values are unchanged.

    Args:
        b_tensor: 3-D tensor of shape (E, K, N). The original notes describe
            it as a float4_e2m1fn_x2 weight stack in N-major layout.

    Returns:
        A tensor of shape (E, K, N) whose memory is laid out K-major.
    """
    # View as (E, N, K), pack that view densely, then swap the axes back so
    # callers still see (E, K, N) on top of the K-contiguous storage.
    as_enk = b_tensor.permute(0, 2, 1)
    packed = as_enk.contiguous()
    return packed.permute(0, 2, 1)
|
|
|
|
|
|
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative (inclusive) token offsets for the grouped GEMM.

    Entry ``e`` of the result is the total number of tokens routed to
    experts ``0..e``. The prefix sums are built in a single pass instead of
    re-summing a growing slice for every expert.

    Args:
        tokens_per_expert: sequence of int, one token count per expert.
        num_experts: number of offsets to produce. If it is smaller than
            ``len(tokens_per_expert)`` the extra counts are ignored; if it
            is larger, the trailing entries repeat the overall total (0 when
            the sequence is empty).
        device: device on which the result tensor is created
            (default ``"cuda"``).

    Returns:
        offs: (num_experts,) int32 tensor holding the inclusive cumulative sum.
    """
    from itertools import accumulate

    # Clamp so a non-positive num_experts yields an empty result rather than
    # a negative slice bound.
    counts = tokens_per_expert[:max(num_experts, 0)]
    running = list(accumulate(counts))

    # Fewer counts than experts: the remaining experts contribute no tokens,
    # so their offset equals the total accumulated so far.
    missing = num_experts - len(running)
    if missing > 0:
        total = running[-1] if running else 0
        running.extend([total] * missing)

    return torch.tensor(running, dtype=torch.int32, device=device)
|
|
|
|
|
|
# ── Kernel Launch ─────────────────────────────────────────────────────
|
|
|