fix: cache E2M1 step_to_idx LUT per device (no CPU->CUDA copy in forward)

torch.tensor() and new_tensor() both trigger CPU->CUDA copies during
cudagraph capture. Pre-cache the LUT on first use per device.
This commit is contained in:
2026-05-16 18:48:31 +00:00
parent 6c298be842
commit 2f68c7ba77

View File

@@ -13,6 +13,25 @@ no dynamic tensor allocation in the forward path, no Python control flow
on GPU data.
"""
import math
import threading
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
_NVFP4_STEP_LUT_CACHE = {}
_NVFP4_STEP_LUT_LOCK = threading.Lock()
def _get_step_to_idx_lut(device):
"""Get or create the E2M1 step-to-index LUT for the given device.
Cached per device to avoid CPU→CUDA copies during cudagraph capture.
"""
with _NVFP4_STEP_LUT_LOCK:
if device not in _NVFP4_STEP_LUT_CACHE:
_NVFP4_STEP_LUT_CACHE[device] = torch.as_tensor(
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
dtype=torch.int8, device=device,
)
return _NVFP4_STEP_LUT_CACHE[device]
import torch
import cutlass
import cutlass.cute as cute
@@ -85,7 +104,7 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
@@ -141,7 +160,7 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)