- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: package renamed to dsv4-inference
353 lines
14 KiB
Python
353 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CuTeDSL NVFP4 GEMM test — verify the reference kernel works with our data.
|
|
|
|
Uses NVIDIA's ScaledGroupedGemmKernel from the CUTLASS CuTeDSL examples
|
|
with NVFP4 (Float4E2M1FN + Float8E4M3FN, sf_vec_size=16).
|
|
|
|
This tests one grouped GEMM launch: for each expert e, C[tokens of e] = A[tokens of e] @ B[e], with A(tokens, K), B(experts, K, N), C(tokens, N),
|
|
with proper scale factor padding/swizzling.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import math
|
|
import torch
|
|
|
|
# Make the repository root importable so `from dsv4.kernels...` resolves when
# this script is run directly, without the package being installed.
#
# The root is located by walking up from this file to the nearest directory
# containing pyproject.toml, so the script keeps working wherever it sits under
# tests/ (e.g. tests/unit/ vs tests/archive/). If no pyproject.toml is found,
# fall back to this file's grandparent directory (the previous behaviour).
def find_repo_root(start_file, marker="pyproject.toml"):
    """Return the nearest ancestor directory of *start_file* that contains *marker*.

    Args:
        start_file: Path of a file inside the repository.
        marker: File name whose presence identifies the repository root.

    Returns:
        Absolute path of the first ancestor directory (starting with the
        directory holding *start_file*) that contains *marker*; if none does,
        the grandparent directory of *start_file*.
    """
    start_file = os.path.abspath(start_file)
    directory = os.path.dirname(start_file)
    while True:
        if os.path.isfile(os.path.join(directory, marker)):
            return directory
        parent = os.path.dirname(directory)
        if parent == directory:  # reached the filesystem root without a match
            return os.path.dirname(os.path.dirname(start_file))
        directory = parent


REPO_ROOT = find_repo_root(__file__)
sys.path.insert(0, REPO_ROOT)
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils as utils
|
|
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
|
|
|
from dsv4.kernels.gemm.grouped import (
|
|
ScaledGroupedGemmKernel,
|
|
pad_and_swizzle_single,
|
|
assemble_raw_scales_2d3d_2d_side,
|
|
assemble_raw_scales_2d3d_3d_side,
|
|
cat_byte_reinterpretable_tensors,
|
|
stack_byte_reinterpretable_tensors,
|
|
offs_to_group_sizes,
|
|
)
|
|
|
|
# ── Helpers ────────────────────────────────────────────────────────────
|
|
|
|
def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed with floor division.

    For integer ``a`` and positive ``b`` this equals ``math.ceil(a / b)``
    without a round-trip through floating point.
    """
    numerator = a + b - 1
    return numerator // b
|
|
|
|
def round_up(a, b):
    """Round ``a`` up to the nearest multiple of ``b``."""
    blocks = (a + b - 1) // b  # same as ceil_div(a, b)
    return blocks * b
|
|
|
|
|
|
def quantize_bf16_to_nvfp4(x_bf16, block_size=16):
    """Quantize a BF16 tensor to NVFP4 (E2M1 values + E4M3 block scales + global scale).

    Quantization is two-level along the last dimension:

    * a per-tensor ``global_scale`` maps the tensor amax onto
      ``E2M1_MAX * E4M3_MAX`` (6 * 448), and
    * one E4M3 scale per ``block_size`` consecutive elements maps each block's
      amax onto ``E2M1_MAX`` (6).

    Each scaled value is then rounded to the nearest E2M1 magnitude (on a tie
    the smaller magnitude wins, via ``argmin``). The sign is stored in bit 3
    of the 4-bit code.

    Args:
        x_bf16: Input tensor; quantized along its last dimension.
        block_size: Number of elements sharing one block scale. Must be a
            positive even integer (two FP4 codes are packed per byte).

    Returns:
        (x_fp4, block_scales, global_scale) where:
            x_fp4: ``torch.float4_e2m1fn_x2`` tensor of shape
                ``(..., n_blocks * block_size // 2)``. Each byte holds two FP4
                codes: the even element in the low nibble, the odd element in
                the high nibble. If the last dim is not a multiple of
                ``block_size``, it is zero-padded to ``n_blocks * block_size``
                before packing.
            block_scales: ``torch.float8_e4m3fn`` tensor of shape
                ``(..., n_blocks)``, with ``n_blocks = ceil(last_dim / block_size)``.
            global_scale: 0-dim float32 tensor.

    Raises:
        ValueError: if ``block_size`` is not a positive even integer.
    """
    if block_size <= 0 or block_size % 2 != 0:
        raise ValueError(f"block_size must be a positive even integer, got {block_size}")

    e2m1_max = 6.0    # largest E2M1 magnitude
    e4m3_max = 448.0  # largest finite float8_e4m3fn value

    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (e2m1_max * e4m3_max)

    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = (last_dim + block_size - 1) // block_size  # ceil division
    padded_dim = n_blocks * block_size

    # Zero-pad the last dim so it splits evenly into blocks.
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))

    x_blocks = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)

    # Per-block scales; shape (..., n_blocks) already matches the return contract.
    block_amax = x_blocks.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / e2m1_max).to(torch.float8_e4m3fn)

    # Divide by the *rounded* FP8 scale so the codes match what dequantization sees.
    # The clamp guards blocks whose scale underflowed to zero in FP8.
    x_scaled = x_blocks / block_scale.float().unsqueeze(-1).clamp(min=1e-8)

    # Round to the nearest E2M1 magnitude; codes 8-15 are the negated 0-7.
    magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=x_bf16.device,
    )
    indices = (x_scaled.abs().unsqueeze(-1) - magnitudes).abs().argmin(dim=-1)
    nibbles = torch.where(x_scaled < 0, indices + 8, indices).to(torch.uint8)

    # Pack pairs: byte = (odd_nibble << 4) | even_nibble.
    packed = (nibbles[..., 1::2] << 4) | nibbles[..., ::2]

    # Flatten blocks back into one packed last dim covering the (padded) row,
    # reshaping while still uint8, then reinterpret the bytes as FP4x2.
    packed_shape = list(x_bf16.shape[:-1]) + [padded_dim // 2]
    x_fp4 = packed.reshape(packed_shape).view(torch.float4_e2m1fn_x2)

    return x_fp4, block_scale, global_scale
|
|
|
|
|
|
def dequantize_nvfp4(x_fp4, block_scales, global_scale):
    """Expand packed NVFP4 data back to BF16 (reference path for comparisons).

    Each byte of ``x_fp4`` carries two E2M1 codes: the low nibble is the even
    element, the high nibble the odd one. Codes 0-7 are non-negative
    magnitudes and 8-15 their negations. Consecutive runs of decoded values
    share one entry of ``block_scales``. The run length is inferred as the
    unpacked last dim divided by the number of scales. The result is finally
    multiplied by ``global_scale``.
    """
    codebook = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.float32,
        device=x_fp4.device,
    )

    as_bytes = x_fp4.view(torch.uint8)
    low_vals = codebook[(as_bytes & 0x0F).long()]
    high_vals = codebook[((as_bytes >> 4) & 0x0F).long()]

    # Interleave (low, high) pairs back into a single last dimension.
    values = torch.stack((low_vals, high_vals), dim=-1).flatten(start_dim=-2)

    # Broadcast every block scale across the elements of its block.
    per_block = values.shape[-1] // block_scales.shape[-1]
    scales = block_scales.float().repeat_interleave(per_block, dim=-1).reshape(values.shape)

    return (values * scales * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
# ── Main Test ──────────────────────────────────────────────────────────
|
|
|
|
def quantize_weight_k_major_nvfp4(w_bf16, block_size=16):
    """Quantize one ``(K, N)`` expert weight to NVFP4, blocking and packing along K (dim 0).

    This is the dim-0 counterpart of ``quantize_bf16_to_nvfp4`` (which works
    along the last dim). ``main`` uses it to build the 3D (weight) side of the
    grouped GEMM.

    Args:
        w_bf16: 2-D weight tensor of shape ``(K, N)``.
        block_size: Elements per block scale along K (must be even).

    Returns:
        (w_fp4, w_sf, w_gs):
            w_fp4: ``float4_e2m1fn_x2`` of shape ``(K_pad // 2, N)``, where
                ``K_pad`` is K rounded up to a multiple of ``block_size``.
                Each byte holds two K-adjacent codes: the even row in the low
                nibble, the odd row in the high nibble.
            w_sf: ``float8_e4m3fn`` block scales, shape ``(K_pad // block_size, N)``.
            w_gs: 0-dim float32 global scale.
    """
    k_dim, n_dim = w_bf16.shape
    w_f32 = w_bf16.float()
    w_amax = w_f32.abs().max().clamp(min=1e-8).float()
    w_gs = w_amax / (6.0 * 448.0)  # maps amax onto E2M1_MAX * E4M3_MAX
    w_norm = w_f32 / w_gs

    k_blocks = (k_dim + block_size - 1) // block_size  # ceil division
    k_padded = k_blocks * block_size
    if k_padded != k_dim:
        # Zero-pad K (dim 0) at the end; the leading (0, 0) leaves N untouched.
        w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_padded - k_dim))

    w_blocks = w_norm.reshape(k_blocks, block_size, n_dim)
    w_block_amax = w_blocks.abs().amax(dim=1).clamp(min=1e-8)  # (k_blocks, N)
    w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the rounded FP8 scale so the codes match what dequantization sees.
    w_scaled = w_blocks / w_sf.float().unsqueeze(1).clamp(min=1e-8)

    magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=w_bf16.device,
    )
    indices = (w_scaled.abs().unsqueeze(-1) - magnitudes).abs().argmin(dim=-1)
    nibbles = torch.where(w_scaled < 0, indices + 8, indices).to(torch.uint8)

    # Pack K-adjacent pairs (dim 1 within each block): byte = (odd << 4) | even.
    packed = (nibbles[:, 1::2, :] << 4) | nibbles[:, ::2, :]  # (k_blocks, block_size // 2, N)
    w_fp4 = packed.reshape(k_padded // 2, n_dim).view(torch.float4_e2m1fn_x2)
    return w_fp4, w_sf, w_gs


def _to_cute(tensor):
    """Wrap a torch tensor for CuTeDSL and mark its layout dynamic on its leading dim."""
    cute_tensor = cutlass_torch.from_dlpack(tensor)
    return cute_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(tensor))


def _flat_cosine(a, b):
    """Cosine similarity of two tensors viewed as flat float32 vectors."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


def main():
    """End-to-end check of the NVFP4 2Dx3D grouped GEMM against BF16 references.

    Steps:
      1. Build random BF16 activations and per-expert weights.
      2. Quantize both operands to NVFP4.
      3. Run ``ScaledGroupedGemmKernel``.
      4. Compare the output with (a) the plain BF16 matmul and (b) a matmul of
         the dequantized operands. Reference (b) separates kernel/layout bugs
         from quantization loss.

    Exits the process with status 1 when the cosine similarity against the
    BF16 reference is <= 0.95, so CI and shell callers see the failure.
    """
    torch.manual_seed(42)
    device = "cuda"

    # Problem sizes
    num_experts = 2
    tokens_per_expert = 64
    hidden = 256  # K dimension
    intermediate = 128  # N dimension
    sf_vec_size = 16
    block_size = 16

    tokens_sum = num_experts * tokens_per_expert
    # Row range of each expert's tokens inside the concatenated activations.
    expert_rows = [
        slice(e * tokens_per_expert, (e + 1) * tokens_per_expert)
        for e in range(num_experts)
    ]

    print(f"Test: {num_experts} experts, {tokens_per_expert} tokens each, K={hidden}, N={intermediate}")

    # ── Create BF16 reference data ──
    x_bf16 = torch.randn(tokens_sum, hidden, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(num_experts, hidden, intermediate, dtype=torch.bfloat16, device=device) * 0.5

    # BF16 reference: each expert's tokens times that expert's weight.
    ref_out = torch.zeros(tokens_sum, intermediate, dtype=torch.bfloat16, device=device)
    for e, rows in enumerate(expert_rows):
        ref_out[rows] = x_bf16[rows] @ w_bf16[e]

    print(f"BF16 ref: amax={ref_out.abs().max():.4f} mean={ref_out.float().mean():.6f}")

    # ── Quantize to NVFP4 ──
    x_fp4, x_sf, x_gs = quantize_bf16_to_nvfp4(x_bf16, block_size=block_size)

    # Weights are (experts, hidden, intermediate); the kernel expects the
    # hidden/K dimension to be the packed one, so quantize along dim 0.
    w_fp4_list, w_sf_list, w_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_k_major_nvfp4(w_bf16[e], block_size)
        w_fp4_list.append(w_fp4)
        w_sf_list.append(w_sf)  # (hidden // block_size, intermediate)
        w_gs_list.append(w_gs)

    # Verify quantization roundtrip
    x_deq = dequantize_nvfp4(x_fp4, x_sf, x_gs)
    cos_quant = _flat_cosine(x_bf16, x_deq)
    print(f"Quantization roundtrip cosine: {cos_quant:.6f}")

    # Reference from the *dequantized* operands: the kernel should match this
    # almost exactly, whereas the BF16 reference also includes quantization loss.
    deq_ref_out = torch.zeros_like(ref_out)
    for e, rows in enumerate(expert_rows):
        # Weights are packed along K (dim 0); transpose so K is the last dim,
        # which is the dim dequantize_nvfp4 unpacks along, then transpose back.
        w_deq = dequantize_nvfp4(w_fp4_list[e].T, w_sf_list[e].T, w_gs_list[e]).T
        deq_ref_out[rows] = x_deq[rows] @ w_deq

    # ── Prepare CuTeDSL kernel inputs ──
    # The kernel expects:
    #   mat_a: (tokens_sum, K_packed) float4_e2m1fn_x2
    #   mat_b: (experts, K_packed, N) float4_e2m1fn_x2 — K-major
    #   scale_a: assembled 2D side (padded + swizzled)
    #   scale_b: assembled 3D side (padded + swizzled per expert)
    #   offs: (experts,) int32 cumulative token offsets
    #   global_scale_a: (experts,) float32
    #   global_scale_b: (experts,) float32

    # Expert offsets (cumulative sum of tokens per expert)
    offs = torch.tensor(
        [tokens_per_expert * (e + 1) for e in range(num_experts)],
        dtype=torch.int32,
        device=device,
    )

    # Assemble scale_a (2D side: concatenate per-expert, pad to 128, swizzle)
    raw_scale_a = [x_sf[rows] for rows in expert_rows]
    scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a)

    # Assemble scale_b (3D side: per-expert, pad and swizzle each).
    # The reference uses (N, K_sf) = (intermediate, hidden // 16) per expert,
    # while w_sf is (K_sf, intermediate), so transpose.
    w_sf_t = [sf.T.contiguous() for sf in w_sf_list]
    scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_t)

    # Global scales: build on-device from the 0-dim scale tensors directly.
    global_scale_a = x_gs.reshape(1).repeat(num_experts).to(torch.float32)
    global_scale_b = torch.stack(w_gs_list).to(torch.float32)

    # mat_a is (tokens_sum, K_packed) in float4_e2m1fn_x2, row-major (K-major).
    # This matches the reference: A shape=(128,128) stride=(128,1)
    mat_a = x_fp4

    # mat_b: (experts, K_packed, N) in float4_e2m1fn_x2, K-major.
    # Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1.
    # torch.stack gives stride (16384, 128, 1) — N is stride-1 (wrong).
    # For K-major: permute, make contiguous, permute back.
    mat_b = torch.stack(w_fp4_list).permute(0, 2, 1).contiguous().permute(0, 2, 1)

    print("\nKernel inputs:")
    print(f" mat_a: {mat_a.shape} {mat_a.dtype}")
    print(f" mat_b: {mat_b.shape} {mat_b.dtype}")
    print(f" scale_a: {scale_a.shape} {scale_a.dtype}")
    print(f" scale_b: {scale_b.shape} {scale_b.dtype}")
    print(f" offs: {offs.tolist()}")
    print(f" global_scale_a: {global_scale_a.tolist()}")
    print(f" global_scale_b: {[f'{v:.6e}' for v in global_scale_b.tolist()]}")

    # ── Run CuTeDSL kernel ──
    print("\nCompiling and running CuTeDSL kernel (first run takes ~1 min to compile)...")

    out = torch.zeros(tokens_sum, intermediate, dtype=torch.bfloat16, device=device)

    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=sf_vec_size,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(128, 128, 256),
        cluster_shape_mnk=(1, 1, 1),
    )

    # Convert to CuTe tensors
    a_cute = _to_cute(mat_a)
    b_cute = _to_cute(mat_b)
    sfa_cute = _to_cute(scale_a)
    sfb_cute = _to_cute(scale_b)
    c_cute = _to_cute(out)
    offs_cute = _to_cute(offs)

    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
    ws_cute = _to_cute(workspace)

    gsa_cute = _to_cute(global_scale_a)
    gsb_cute = _to_cute(global_scale_b)

    import cuda.bindings.driver as cuda

    cluster_size = 1
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    compiled = cute.compile(
        kernel,
        a_cute, b_cute, sfa_cute, sfb_cute, c_cute, offs_cute, ws_cute,
        max_active_clusters, stream,
        global_scale_a=gsa_cute,
        global_scale_b=gsb_cute,
    )

    compiled(
        a_cute, b_cute, sfa_cute, sfb_cute, c_cute, offs_cute, ws_cute,
        stream,
        global_scale_a=gsa_cute,
        global_scale_b=gsb_cute,
    )
    torch.cuda.synchronize()

    # ── Compare ──
    cosine = _flat_cosine(out, ref_out)
    mse = (out.float() - ref_out.float()).pow(2).mean().item()
    cosine_deq = _flat_cosine(out, deq_ref_out)

    print(f"\n{'='*70}")
    print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
    print(f" vs dequantized-operand reference: cosine={cosine_deq:.6f}")
    print(f"{'='*70}")

    if cosine > 0.99:
        print(" ✅ PASS: CuTeDSL kernel matches BF16 reference")
    elif cosine > 0.95:
        print(" ⚠️ Close but not perfect — quantization loss?")
    else:
        print(" ❌ FAIL: kernel output doesn't match reference")
        # Non-zero exit so CI / shell callers actually see the failure.
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the end-to-end NVFP4 grouped-GEMM check in main().
    main()
|