diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index b59f60b6..07e3048b 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -322,8 +322,8 @@ def warmup_compilation(num_experts, K_packed, N_packed, device, # Allocate minimal dummy tensors for compilation mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) - K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original) - N_sf = ceil_div(N_packed, 16) + K_sf = ceil_div(K_packed, 8) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original) + N_sf = ceil_div(N_packed, 8) scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device) scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device) out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)