This commit is contained in:
2026-05-20 07:15:01 +00:00
parent 1b4742a438
commit d3a7e7a286

View File

@@ -402,13 +402,13 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
_warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
del _warmup_a_bf16
_warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
_warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count
mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
del _warmup_b_bf16
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
kernel = ScaledGroupedGemmKernel(
scenario="2Dx3D",
@@ -614,15 +614,15 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
_warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
del _warmup_a_bf16
_warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
_warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count
mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
del _warmup_b_bf16
# BF16 output (Stage 1: the output is still written as BF16)
# The fused kernel writes the intermediate width (N/2), since the gate and up halves are combined into a single SiLU result
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
kernel = FusedSwiGLUScaledGroupedGemmKernel(
scenario="2Dx3D",