fix: memory-efficient E2M1 quantization — no 32x distance tensor

quantize_to_nvfp4 was allocating a (..., n_blocks, block_size, 8)
float32 tensor for nearest-neighbor distances to all 8 E2M1 values.
That's 32x the input size — 10.5GB for a typical batch, causing OOM
with only 3GB free.

New approach: clamp to [0, 6], scale to half-integer steps, round,
then map through a 13-byte lookup table to E2M1 indices.
Peak memory is now ~2x input (x_f32 + x_scaled) instead of 32x.

This makes activation quantization CUDA-graph-safe for the
memory-constrained DeepSeek-V4 on B200 (175GB model / 178GB GPU).
This commit is contained in:
2026-05-16 07:49:38 +00:00
parent a2cac7a7fe
commit baf44c92f8

View File

@@ -67,15 +67,29 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Nearest E2M1
# Nearest E2M1 — memory-efficient clamp approach
# Instead of computing distances to all 8 magnitudes (creates 32x tensor),
# clamp to [0, 6] and round to nearest E2M1 value.
# E2M1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
# Scale to [0, 12] (integer half-steps), round, then map back
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device)
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs().unsqueeze(-1)
distances = (abs_scaled - magnitudes).abs()
indices = distances.argmin(dim=-1)
abs_scaled = x_scaled.abs().clamp(max=6.0)
# Scale to half-integer grid: 0, 1, 2, 3, 4, 6, 8, 12
# Multiply by 2 and round to get: 0, 1, 2, 3, 4, 6, 8, 12
# But 3.0->6, 3.5->7(not valid)... Use LUT approach but on compressed data
# Actually, simplest correct approach: quantize to 3-bit index
# E2M1 is (1.mantissa) * 2^exp where mantissa is 2 bits
# Values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
# Simplest: just clamp + round to nearest value with small lookup
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
# Map half-step values to E2M1 indices
# 0->0, 1->1, 2->2, 3->3, 4->4, 5->4, 6->5, 7->5, 8->6, 9->6, 10->6, 11->7, 12->7
# Use a small lookup table (13 entries, 13 bytes)
step_to_idx = torch.tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8, device=x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[..., ::2]