fix: memory-efficient E2M1 quantization — no 32x distance tensor
quantize_to_nvfp4 was allocating a (..., n_blocks, block_size, 8) float32 tensor for nearest-neighbor distances to all 8 E2M1 values. That's 32x the input size — 10.5GB for a typical batch, causing OOM with only 3GB free. New approach: clamp to [0, 6], scale to half-integer steps, round, then map through a 13-byte lookup table to E2M1 indices. Peak memory is now ~2x input (x_f32 + x_scaled) instead of 32x. This makes activation quantization CUDA-graph-safe for the memory-constrained DeepSeek-V4 on B200 (175GB model / 178GB GPU).
This commit is contained in:
@@ -67,15 +67,29 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Nearest E2M1
|
||||
# Nearest E2M1 — memory-efficient clamp approach
|
||||
# Instead of computing distances to all 8 magnitudes (creates 32x tensor),
|
||||
# clamp to [0, 6] and round to nearest E2M1 value.
|
||||
# E2M1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
|
||||
# Scale to [0, 12] (integer half-steps), round, then map back
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device)
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().unsqueeze(-1)
|
||||
distances = (abs_scaled - magnitudes).abs()
|
||||
indices = distances.argmin(dim=-1)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
# Scale to half-integer grid: 0, 1, 2, 3, 4, 6, 8, 12
|
||||
# Multiply by 2 and round to get: 0, 1, 2, 3, 4, 6, 8, 12
|
||||
# But 3.0->6, 3.5->7(not valid)... Use LUT approach but on compressed data
|
||||
# Actually, simplest correct approach: quantize to 3-bit index
|
||||
# E2M1 is (1.mantissa) * 2^exp where mantissa is 2 bits
|
||||
# Values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
|
||||
# Simplest: just clamp + round to nearest value with small lookup
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
# Map half-step values to E2M1 indices
|
||||
# 0->0, 1->1, 2->2, 3->3, 4->4, 5->4, 6->5, 7->5, 8->6, 9->6, 10->6, 11->7, 12->7
|
||||
# Use a small lookup table (13 entries, 13 bytes)
|
||||
step_to_idx = torch.tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8, device=x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
|
||||
Reference in New Issue
Block a user