From baf44c92f8d628029a2b44707e3eb467e4dddb63 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 07:49:38 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20memory-efficient=20E2M1=20quantization?= =?UTF-8?q?=20=E2=80=94=20no=2032x=20distance=20tensor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit quantize_to_nvfp4 was allocating a (..., n_blocks, block_size, 8) float32 tensor for nearest-neighbor distances to all 8 E2M1 values. That's 32x the input size — 10.5GB for a typical batch, causing OOM with only 3GB free. New approach: clamp to [0, 6], scale to half-integer steps, round, then map through a 13-byte lookup table to E2M1 indices. Peak memory is now ~2x input (x_f32 + x_scaled) instead of 32x. This makes activation quantization CUDA-graph-safe for the memory-constrained DeepSeek-V4 on B200 (175GB model / 178GB GPU). --- cutedsl/bridge.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index a061430a..c5fa3cfb 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -67,15 +67,29 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Nearest E2M1 + # Nearest E2M1 — memory-efficient clamp approach + # Instead of computing distances to all 8 magnitudes (creates 32x tensor), + # clamp to [0, 6] and round to nearest E2M1 value. + # E2M1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 + # Scale to [0, 12] (integer half-steps), round, then map back block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) - - magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device) signs = torch.sign(x_scaled) - abs_scaled = x_scaled.abs().unsqueeze(-1) - distances = (abs_scaled - magnitudes).abs() - indices = distances.argmin(dim=-1) + abs_scaled = x_scaled.abs().clamp(max=6.0) + + # Scale to half-integer grid: 0, 1, 2, 3, 4, 6, 8, 12 + # Multiply by 2 and round to get: 0, 1, 2, 3, 4, 6, 8, 12 + # But 3.0->6, 3.5->7(not valid)... Use LUT approach but on compressed data + # Actually, simplest correct approach: quantize to 3-bit index + # E2M1 is (1.mantissa) * 2^exp where mantissa is 2 bits + # Values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 + # Simplest: just clamp + round to nearest value with small lookup + half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) + # Map half-step values to E2M1 indices + # 0->0, 1->1, 2->2, 3->3, 4->4, 5->4, 6->5, 7->5, 8->6, 9->6, 10->6, 11->7, 12->7 + # Use a small lookup table (13 entries, 13 bytes) + step_to_idx = torch.tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8, device=x_bf16.device) + indices = step_to_idx[half_steps.long()] nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) even = nibbles[..., ::2]