diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 26fb538c..b15a7ff3 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -402,13 +402,13 @@ def warmup_compilation(num_experts, K_packed, N_packed, device, _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16) del _warmup_a_bf16 - _warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 + _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16) del _warmup_b_bf16 out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device) - expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device) - global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device) - global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device) + expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device) + global_scale_a = torch.ones(1, dtype=torch.float32, device=device) + global_scale_b = torch.ones(1, dtype=torch.float32, device=device) kernel = ScaledGroupedGemmKernel( scenario="2Dx3D", @@ -614,15 +614,15 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device, _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16) del _warmup_a_bf16 - _warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 + _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16) del _warmup_b_bf16 # BF16 output (Stage 1: we still write BF16) # The fused kernel writes intermediate (N/2) since gate+up → silu result out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device) - expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device) - global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device) - global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device) + expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device) + global_scale_a = torch.ones(1, dtype=torch.float32, device=device) + global_scale_b = torch.ones(1, dtype=torch.float32, device=device) kernel = FusedSwiGLUScaledGroupedGemmKernel( scenario="2Dx3D",