fix: correct weight quantization for CuTeDSL kernel
Weight K dimension (hidden) must be the packed dimension, not N. Block scales computed along K dim. FP4 packing along K.
This commit is contained in:
@@ -47,7 +47,7 @@ def quantize_bf16_to_nvfp4(x_bf16, block_size=16):
|
||||
"""Quantize BF16 tensor to NVFP4 (E2M1 + E4M3 block scales + global scale).
|
||||
|
||||
Returns (x_fp4, block_scales, global_scale) where:
|
||||
x_fp4: torch.float4_e2m1fn_x2 with same shape (packed along last dim)
|
||||
x_fp4: torch.float4_e2m1fn_x2 with same logical shape (packed along last dim)
|
||||
block_scales: float8_e4m3fn with shape (..., ceil_div(last_dim, block_size))
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
@@ -72,19 +72,16 @@ def quantize_bf16_to_nvfp4(x_bf16, block_size=16):
|
||||
|
||||
# Quantize to E2M1
|
||||
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
# For each value, find nearest E2M1 magnitude
|
||||
x_blocks = x_reshaped # (..., n_blocks, block_size)
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1) # (..., n_blocks, 1)
|
||||
x_scaled = x_blocks / block_sf_expanded.clamp(min=1e-8) # normalize by block scale
|
||||
x_blocks = x_reshaped
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_blocks / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
# Nearest E2M1
|
||||
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device)
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().unsqueeze(-1) # (..., block_size, 1)
|
||||
distances = (abs_scaled - magnitudes).abs() # (..., block_size, 8)
|
||||
indices = distances.argmin(dim=-1) # (..., block_size)
|
||||
abs_scaled = x_scaled.abs().unsqueeze(-1)
|
||||
distances = (abs_scaled - magnitudes).abs()
|
||||
indices = distances.argmin(dim=-1)
|
||||
|
||||
# Sign: positive = 0-7, negative = 8-15
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
# Pack pairs: byte = (odd_nibble << 4) | even_nibble
|
||||
@@ -92,10 +89,13 @@ def quantize_bf16_to_nvfp4(x_bf16, block_size=16):
|
||||
odd = nibbles[..., 1::2]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
# Reshape back to original shape (with packed last dim)
|
||||
orig_shape = list(x_bf16.shape)
|
||||
orig_shape[-1] = ceil_div(orig_shape[-1], 2)
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(orig_shape)
|
||||
# View as float4_e2m1fn_x2 — same logical shape, packed last dim halved
|
||||
# The logical shape has the original last_dim, but stored packed
|
||||
# float4_e2m1fn_x2: each element is 1 byte = 2 FP4 values
|
||||
# Shape: (..., last_dim // 2) in float4_e2m1fn_x2
|
||||
packed_shape = list(x_bf16.shape)
|
||||
packed_shape[-1] = last_dim // 2
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
|
||||
|
||||
# Reshape block scales
|
||||
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
|
||||
@@ -161,11 +161,51 @@ def main():
|
||||
|
||||
# ── Quantize to NVFP4 ──
|
||||
x_fp4, x_sf, x_gs = quantize_bf16_to_nvfp4(x_bf16)
|
||||
|
||||
# For weights: the kernel expects (experts, hidden, intermediate) with
|
||||
# packed_dim=1 (the hidden/K dimension is packed).
|
||||
# w_bf16[e] is (hidden, intermediate).
|
||||
# We quantize each expert weight, keeping the packed dim as hidden.
|
||||
w_fp4_list, w_sf_list, w_gs_list = [], [], []
|
||||
for e in range(num_experts):
|
||||
w_fp4, w_sf, w_gs = quantize_bf16_to_nvfp4(w_bf16[e])
|
||||
w = w_bf16[e] # (hidden, intermediate) — K=hidden, N=intermediate
|
||||
w_f32 = w.float()
|
||||
w_amax = w_f32.abs().max().clamp(min=1e-8).float()
|
||||
w_gs = w_amax / (6.0 * 448.0)
|
||||
w_norm = w_f32 / w_gs
|
||||
|
||||
# Block scales along the K dimension (dim 0 = hidden)
|
||||
# Scale shape: (ceil_div(hidden, 16), intermediate)
|
||||
k_blocks = ceil_div(hidden, block_size)
|
||||
if hidden % block_size != 0:
|
||||
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - hidden))
|
||||
|
||||
w_reshaped = w_norm.reshape(k_blocks, block_size, intermediate)
|
||||
w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8) # (k_blocks, intermediate)
|
||||
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Quantize to E2M1 along K (dim 0)
|
||||
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
w_block_sf = w_sf.float().unsqueeze(1) # (k_blocks, 1, intermediate)
|
||||
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
|
||||
|
||||
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=device)
|
||||
signs = torch.sign(w_scaled)
|
||||
abs_scaled = w_scaled.abs().unsqueeze(-1)
|
||||
distances = (abs_scaled - magnitudes).abs()
|
||||
indices = distances.argmin(dim=-1)
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
# Pack pairs along K (block_size dim, which is dim 1 after reshape)
|
||||
even = nibbles[:, ::2, :]
|
||||
odd = nibbles[:, 1::2, :]
|
||||
packed = (odd << 4) | even # (k_blocks, block_size//2, intermediate)
|
||||
|
||||
# Reshape to (hidden//2, intermediate) in float4_e2m1fn_x2
|
||||
w_fp4 = packed.reshape(hidden // 2, intermediate).view(torch.float4_e2m1fn_x2)
|
||||
|
||||
w_fp4_list.append(w_fp4)
|
||||
w_sf_list.append(w_sf)
|
||||
w_sf_list.append(w_sf) # (k_blocks, intermediate) = (hidden//16, intermediate)
|
||||
w_gs_list.append(w_gs)
|
||||
|
||||
# Verify quantization roundtrip
|
||||
@@ -201,10 +241,15 @@ def main():
|
||||
global_scale_a = torch.tensor([x_gs] * num_experts, dtype=torch.float32, device=device)
|
||||
global_scale_b = torch.tensor([w_gs_list[e] for e in range(num_experts)], dtype=torch.float32, device=device)
|
||||
|
||||
# mat_a is already (tokens_sum, K_packed)
|
||||
# mat_a is already (tokens_sum, K_packed) in float4_e2m1fn_x2
|
||||
# The kernel's 2Dx3D scenario expects mat_a: (tokens, hidden) where
|
||||
# hidden is the LOGICAL K dimension (packed as float4_e2m1fn_x2)
|
||||
mat_a = x_fp4
|
||||
|
||||
# mat_b needs to be (experts, K_packed, N_packed) — K-major
|
||||
# mat_b: (experts, hidden, intermediate) in float4_e2m1fn_x2
|
||||
# packed_dim=1 means hidden (K) is packed
|
||||
# w_bf16[e] is (hidden, intermediate) — we need (hidden, intermediate) in FP4
|
||||
# with K (hidden) as the packed dimension
|
||||
mat_b = torch.stack(w_fp4_list) # (experts, K_packed, N_packed)
|
||||
|
||||
print(f"\nKernel inputs:")
|
||||
|
||||
Reference in New Issue
Block a user