diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index a061430a..c5fa3cfb 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -67,15 +67,29 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Nearest E2M1 + # Nearest E2M1 — memory-efficient clamp approach + # Instead of computing distances to all 8 magnitudes (creates 32x tensor), + # clamp to [0, 6] and round to nearest E2M1 value. + # E2M1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 + # Scale to [0, 12] (integer half-steps), round, then map back block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) - - magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device) signs = torch.sign(x_scaled) - abs_scaled = x_scaled.abs().unsqueeze(-1) - distances = (abs_scaled - magnitudes).abs() - indices = distances.argmin(dim=-1) + abs_scaled = x_scaled.abs().clamp(max=6.0) + + # Scale to half-integer grid: 0, 1, 2, 3, 4, 6, 8, 12 + # Multiply by 2 and round to get: 0, 1, 2, 3, 4, 6, 8, 12 + # But 3.0->6, 3.5->7(not valid)... Use LUT approach but on compressed data + # Actually, simplest correct approach: quantize to 3-bit index + # E2M1 is (1.mantissa) * 2^exp where mantissa is 2 bits + # Values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 + # Simplest: just clamp + round to nearest value with small lookup + half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) + # Map half-step values to E2M1 indices + # 0->0, 1->1, 2->2, 3->3, 4->4, 5->4, 6->5, 7->5, 8->6, 9->6, 10->6, 11->7, 12->7 + # Use a small lookup table (13 entries, 13 bytes) + step_to_idx = torch.tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8, device=x_bf16.device) + indices = step_to_idx[half_steps.long()] nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) even = nibbles[..., ::2]