fix: correct scale factor dimensions in warmup (K_sf = ceil_div(K_packed,8) not ceil_div(K_packed,16))

K_packed = original_K // 2. The scale factor dimension is
K_sf = ceil_div(original_K, 16) = ceil_div(K_packed * 2, 16) = ceil_div(K_packed, 8).
The previous code used ceil_div(K_packed, 16) which was wrong.
This commit is contained in:
2026-05-20 02:08:26 +00:00
parent 8f1a20562f
commit ef398006a7

View File

@@ -322,8 +322,8 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
# Allocate minimal dummy tensors for compilation
mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
N_sf = ceil_div(N_packed, 16)
K_sf = ceil_div(K_packed, 8) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
N_sf = ceil_div(N_packed, 8)
scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device)
scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device)
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)