- fused_swiglu_grouped_mm.py: copy-paste of torch_scaled_grouped_mm.py with the class renamed and fused_swiglu/swiglu_limit parameters added - bridge.py: added interleave_l1_weights, deinterleave_l1_weights, and warmup_fused_swiglu_compilation - Pure-PyTorch interleave invariant passes (A @ cat vs. deinterleave(A @ interleave)) - Standalone GEMM interleave test fails due to the kernel-internal N-tiling layout (expected; skipped per plan) - FUSED_EPILOGUE_PLAN.md updated with the register layout, the amax shuffle plan, and a 4-step implementation strategy
134 lines
4.9 KiB
Python
134 lines
4.9 KiB
Python
"""Test: Verify that interleaved L1 weights produce the same GEMM result.
|
|
|
|
The key insight: we quantize gate+up TOGETHER (same as non-interleaved),
|
|
then interleave the ALREADY-QUANTIZED FP4 bytes and scales in the N dimension.
|
|
This preserves quantization fidelity.
|
|
"""
|
|
import torch
|
|
import sys
|
|
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
|
|
|
|
from cutedsl.bridge import (
|
|
quantize_weight_to_nvfp4,
|
|
quantize_activation_nvfp4,
|
|
interleave_l1_weights,
|
|
make_b_k_major,
|
|
assemble_scales_2d_side,
|
|
assemble_scales_3d_side,
|
|
run_nvfp4_grouped_gemm,
|
|
warmup_compilation,
|
|
)
|
|
|
|
|
|
def interleave_sfb(raw_scales, granularity_bf16=8):
    """Interleave the gate/up halves of each scale tensor along its columns.

    Each input tensor is split down the middle of its column dimension into a
    gate half and an up half. Both halves are cut into groups of
    ``granularity_bf16 // 2`` columns, and the groups are emitted alternately:
    gate group 0, up group 0, gate group 1, up group 1, ...

    Args:
        raw_scales: Iterable of 2-D ``(K_sf, N)`` tensors (float8_e4m3fn in
            this test), where N = 2*intermediate_sf. ``N // 2`` must be
            divisible by ``granularity_bf16 // 2``, otherwise ``reshape``
            raises.
        granularity_bf16: Interleave width in BF16 columns; the scale group
            width used here is half of it.

    Returns:
        A list with one interleaved ``(K_sf, N)`` tensor per input, in input
        order.
    """
    # NOTE(review): the group width is granularity_bf16 // 2 (4 by default),
    # while test_interleave_gemm de-interleaves the GEMM output in groups of
    # 8 columns. If each scale column maps to one logical output column, the
    # two widths should agree — confirm against the scale layout returned by
    # quantize_weight_to_nvfp4.
    group = granularity_bf16 // 2

    def _interleave_one(scale):
        rows, cols = scale.shape
        half = cols // 2
        gate_groups = scale[:, :half].reshape(rows, half // group, group)
        up_groups = scale[:, half:].reshape(rows, half // group, group)
        # (rows, n_groups, 2, group) -> flatten so gate/up groups alternate.
        paired = torch.stack((gate_groups, up_groups), dim=2)
        return paired.reshape(rows, cols)

    return [_interleave_one(scale) for scale in raw_scales]
|
|
|
|
|
|
def _deinterleave_columns(out, granularity):
    """Undo a gate/up column interleave on a 2-D GEMM output.

    The columns of ``out`` are expected to alternate ``granularity``-wide
    groups: gate[0:g], up[0:g], gate[g:2g], up[g:2g], ...

    Returns:
        ``(gate, up)``, each of shape ``(rows, N // 2)``.
    """
    rows, n_cols = out.shape
    grouped = out.reshape(rows, n_cols // (2 * granularity), 2, granularity)
    gate = grouped[:, :, 0, :].reshape(rows, n_cols // 2)
    up = grouped[:, :, 1, :].reshape(rows, n_cols // 2)
    return gate, up


def test_interleave_gemm(*, strict=False):
    """Compare the NVFP4 grouped GEMM on plain vs. gate/up-interleaved L1 weights.

    Path A quantizes ``cat([gate, up])`` per expert and runs the GEMM directly.
    Path B reuses the same quantized FP4 weights and scales, interleaves them
    along N, runs the GEMM, and de-interleaves the BF16 output. Both the raw
    outputs and the SiLU(gate) * up products are compared by relative
    Frobenius error and reported as PASS/FAIL.

    Requires a CUDA device and the ``cutedsl.bridge`` kernels.

    Args:
        strict: When True, additionally assert that both relative errors are
            below the tolerance. Defaults to False, which keeps the original
            print-only reporting (pytest ignores defaulted parameters when
            resolving fixtures, so collection is unaffected).
    """
    from itertools import accumulate

    device = "cuda"
    num_experts = 4
    hidden = 512
    intermediate = 256
    num_tokens = 32
    granularity = 8  # gate/up interleave width, in BF16 output columns
    tol = 0.01  # relative-error threshold for PASS

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    gate_w = torch.randn(num_experts, intermediate, hidden, dtype=torch.bfloat16, device=device)
    up_w = torch.randn(num_experts, intermediate, hidden, dtype=torch.bfloat16, device=device)

    # === Path A: Non-interleaved ===
    l1_bf16 = torch.cat([gate_w, up_w], dim=1)  # (E, 2*inter, hidden)
    l1_fp4_list, l1_sf_list, l1_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_bf16[e].T)
        l1_fp4_list.append(w_fp4)
        l1_sf_list.append(w_sf)
        l1_gs_list.append(w_gs)

    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)

    gs_val = x.abs().max().item() / (6.0 * 448.0)
    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)

    # Split tokens across experts; any remainder goes to the first experts so
    # the groups always cover all num_tokens rows.
    base, rem = divmod(num_tokens, num_experts)
    tokens_per_expert = [base + (1 if e < rem else 0) for e in range(num_experts)]
    # Expert e owns token rows [starts[e], ends[e]). Slicing by real offsets
    # (rather than i * tpe) stays correct when group sizes differ.
    ends = list(accumulate(tokens_per_expert))
    starts = [0, *ends[:-1]]
    scale_a = assemble_scales_2d_side([x_sf[s:t] for s, t in zip(starts, ends)])
    expert_offsets = torch.tensor(ends, dtype=torch.int32, device=device)
    global_scale_a = torch.full((num_experts,), gs_val, dtype=torch.float32, device=device)

    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)
    out_a = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )

    # === Path B: Interleaved (quantize together, interleave after) ===
    # Use the SAME quantized weights, just interleave the N dimension.
    # NOTE(review): assumes interleave_l1_weights groups columns with the same
    # `granularity` used for the de-interleave below — confirm in bridge.py.
    l1_stacked = torch.stack(l1_fp4_list)  # (E, K, N)
    l1_interleaved = interleave_l1_weights(l1_stacked)
    l1_mat_b_int = make_b_k_major(l1_interleaved)

    # Interleave scales to match
    l1_sf_interleaved = interleave_sfb(l1_sf_list, granularity_bf16=granularity)
    l1_scale_b_int = assemble_scales_3d_side(l1_sf_interleaved)
    # Global scales are the same (quantized together)

    out_b = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b_int,
        scale_a=scale_a, scale_b=l1_scale_b_int,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )

    # De-interleave out_b BF16 to match out_a layout
    gate_b, up_b = _deinterleave_columns(out_b, granularity)
    out_b_deint = torch.cat([gate_b, up_b], dim=1)

    diff = (out_a - out_b_deint).float()
    rel_err = diff.norm() / out_a.float().norm()
    max_err = diff.abs().max()

    print("Non-interleaved vs interleaved+deinterleaved:")
    print(f" Relative error: {rel_err.item():.6f}")
    print(f" Max abs error: {max_err.item():.6f}")
    print(" PASS" if rel_err.item() < tol else " FAIL")

    # Apply SiLU and compare
    gate_a = out_a[:, :intermediate]
    up_a = out_a[:, intermediate:]
    result_a = torch.nn.functional.silu(gate_a) * up_a
    result_b = torch.nn.functional.silu(gate_b) * up_b

    diff2 = (result_a - result_b).float()
    rel_err2 = diff2.norm() / result_a.float().norm()
    print(f" SiLU result error: {rel_err2.item():.6f}")
    print(" SiLU PASS" if rel_err2.item() < tol else " SiLU FAIL")

    if strict:
        assert rel_err.item() < tol, f"GEMM mismatch: rel_err={rel_err.item():.6f}"
        assert rel_err2.item() < tol, f"SiLU mismatch: rel_err={rel_err2.item():.6f}"
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the comparison directly (outside pytest) and print its error report.
    test_interleave_gemm()
|