# Refactor summary:
# - Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
# - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
# - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
# - Moved PyTorch bridges to dsv4/ops/
# - Moved nn.Module layers to dsv4/layers/
# - Moved reference implementations to dsv4/reference/
# - Moved vendored CUTLASS code to vendored/
# - Archived ~190 debug tests to tests/archive/
# - Kept ~15 canonical tests in tests/unit/
# - Updated all import paths
# - Added stubs for future components (model/, cache/, loader/)
# - Updated pyproject.toml: dsv4-inference package name
#
# File: 99 lines, 3.5 KiB, Python
"""Test: Validate gate/up subtile detection and SiLU on gate subtiles.

This test runs the fused kernel with:

- Gate subtiles (0,1): SiLU applied, NOT written to GMEM
- Up subtiles (2,3): kept as-is, written to GMEM at positions (0,1)

Expected output: (M, intermediate) BF16 with up values.
The output should match the up portion of the standard L1 GEMM output.
"""

import sys

import torch

# Hard-coded developer checkout so that `dsv4` resolves without installing
# the package. NOTE(review): machine-specific path; likely unnecessary once
# the package is installed -- confirm before relying on it.
_KERNEL_ROOT = "/root/dsv4-nvfp4-workspace/kernel"
sys.path.insert(0, _KERNEL_ROOT)

# These imports must come after the sys.path tweak above.
from dsv4.ops.quantize import (  # noqa: E402
    quantize_weight_to_nvfp4,
    quantize_activation_nvfp4,
)
from dsv4.ops.layouts import (  # noqa: E402
    make_b_k_major,
    assemble_scales_2d_side,
    assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (  # noqa: E402
    run_nvfp4_grouped_gemm,
    run_fused_swiglu_grouped_gemm,
    warmup_compilation,
)


def test_gate_up_subtile():
    """Check that the fused SwiGLU grouped GEMM writes exactly the L1 "up" half.

    Random BF16 activations ``x`` of shape ``(num_tokens, hidden)`` and
    per-expert L1 weights of shape ``(num_experts, 2 * intermediate, hidden)``
    are quantized to NVFP4. Two kernels then run on the same inputs:

    1. ``run_nvfp4_grouped_gemm``: the standard L1 GEMM. Its output is split
       column-wise into a gate half ``[:, :intermediate]`` and an up half
       ``[:, intermediate:]``.
    2. ``run_fused_swiglu_grouped_gemm``: the fused kernel under test.

    The fused output must have the same shape as the up half. Its relative
    Frobenius-norm error against the up half must be below ``rel_tol``.
    Diagnostics are printed in either case.

    Raises:
        unittest.SkipTest: if no CUDA device is available (pytest reports
            this as a skip rather than an error).
        AssertionError: on a shape mismatch, or when the relative error is
            not below the tolerance (a NaN error also fails).
    """
    import itertools
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("NVFP4 grouped GEMM kernels require a CUDA device")

    device = "cuda"
    num_experts = 4
    hidden = 512
    intermediate = 256
    num_tokens = 32
    rel_tol = 0.05  # pass threshold on ||fused - up_ref|| / ||up_ref||

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    l1_w = torch.randn(
        num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device=device
    )

    # Quantize each expert's L1 weight independently; `.T` hands the
    # quantizer a (hidden, 2 * intermediate) view of that expert's weight.
    l1_fp4_list, l1_sf_list, l1_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w[e].T)
        l1_fp4_list.append(w_fp4)
        l1_sf_list.append(w_sf)
        l1_gs_list.append(w_gs)

    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)

    # Single global activation scale: amax / (6.0 * 448.0). These are presumably
    # the FP4 (E2M1) max and the FP8 (E4M3) block-scale max -- confirm.
    gs_val = x.abs().max().item() / (6.0 * 448.0)
    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)

    # Spread tokens as evenly as possible: the first `rem` experts get one
    # extra token, so every token is assigned even when the division is inexact.
    base, rem = divmod(num_tokens, num_experts)
    tokens_per_expert = [base + 1 if e < rem else base for e in range(num_experts)]
    # End offset (exclusive) of each expert's token group, and the matching starts.
    group_ends = list(itertools.accumulate(tokens_per_expert))
    group_starts = [0, *group_ends[:-1]]
    # Assumes x_sf holds one row of scale factors per token -- TODO confirm
    # against quantize_activation_nvfp4.
    scale_a = assemble_scales_2d_side(
        [x_sf[start:end] for start, end in zip(group_starts, group_ends)]
    )
    expert_offsets = torch.tensor(group_ends, dtype=torch.int32, device=device)
    global_scale_a = torch.full(
        (num_experts,), gs_val, dtype=torch.float32, device=device
    )

    # NOTE(review): the meaning of these arguments is inferred only from this
    # call site; confirm against warmup_compilation's signature.
    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)

    # 1. Standard L1 GEMM: columns [0, intermediate) are gate, the rest are up.
    out_bf16 = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    gate_ref = out_bf16[:, :intermediate]
    up_ref = out_bf16[:, intermediate:]
    print(f"Standard L1 output: shape={out_bf16.shape}")
    print(f"Gate ref amax: {gate_ref.abs().amax().item():.4f}")
    print(f"Up ref amax: {up_ref.abs().amax().item():.4f}")

    # 2. Fused kernel (gate: SiLU, up: as-is, only up written to GMEM).
    print("\nRunning fused kernel...")
    out_fused = run_fused_swiglu_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a, global_scale_b=l1_gs,
    )
    print(f"Fused output: shape={out_fused.shape}, amax={out_fused.abs().amax().item():.4f}")

    # 3. Compare against the up half. Check shapes first, so that a wrong
    # output width fails loudly instead of being broadcast in the subtraction.
    assert out_fused.shape == up_ref.shape, (
        f"fused output shape {tuple(out_fused.shape)} "
        f"!= up half shape {tuple(up_ref.shape)}"
    )
    diff = (out_fused - up_ref).float()
    rel_err = diff.norm() / up_ref.float().norm()
    max_err = diff.abs().max()
    print("\n=== Results ===")
    print(f"Rel error vs up_ref: {rel_err.item():.6f}")
    print(f"Max abs error: {max_err.item():.6f}")
    # `<` is False for NaN, so a NaN error prints FAIL and trips the assert.
    passed = rel_err.item() < rel_tol
    print("PASS" if passed else "FAIL")
    # Assert rather than only print, so pytest reports a mismatch as a failure.
    assert passed, f"rel error {rel_err.item():.6f} not below tolerance {rel_tol}"


if __name__ == "__main__":
    import unittest

    try:
        test_gate_up_subtile()
    except unittest.SkipTest as exc:
        print(f"SKIP: {exc}")