fix: correct scale factor dimensions in warmup (K_sf = ceil_div(K_packed,8) not ceil_div(K_packed,16))
K_packed = original_K // 2. The scale factor dimension is K_sf = ceil_div(original_K, 16) = ceil_div(K_packed * 2, 16) = ceil_div(K_packed, 8). The previous code used ceil_div(K_packed, 16) which was wrong.
This commit is contained in:
@@ -322,8 +322,8 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
# Allocate minimal dummy tensors for compilation
|
||||
mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
|
||||
mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
|
||||
K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
|
||||
N_sf = ceil_div(N_packed, 16)
|
||||
K_sf = ceil_div(K_packed, 8) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
|
||||
N_sf = ceil_div(N_packed, 8)
|
||||
scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device)
|
||||
scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device)
|
||||
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
|
||||
|
||||
Reference in New Issue
Block a user