fix: cache E2M1 step_to_idx LUT per device (no CPU->CUDA copy in forward)
Both torch.tensor() and Tensor.new_tensor() trigger a CPU->CUDA copy when called during cudagraph capture. Build the LUT once per device on first use, cache it, and reuse the cached tensor in the forward path.
This commit is contained in:
@@ -13,6 +13,25 @@ no dynamic tensor allocation in the forward path, no Python control flow
|
||||
on GPU data.
|
||||
"""
|
||||
import math
|
||||
import threading
|
||||
|
||||
# Cached LUT for E2M1 quantization: one int8 tensor per device, created on
# first use and reused afterwards so the forward path does not rebuild it
# from a Python list on every call.
_NVFP4_STEP_LUT_CACHE = {}
_NVFP4_STEP_LUT_LOCK = threading.Lock()


def _get_step_to_idx_lut(device):
    """Get or create the E2M1 step-to-index LUT for the given device.

    The LUT is a 13-element int8 tensor, indexed by callers with the
    half-step value (clamped to 0..12) to obtain the E2M1 index.

    Args:
        device: Target device. Anything ``torch.device()`` accepts (a
            ``torch.device`` or a device string such as ``"cpu"``). The value
            is normalized to a ``torch.device`` before it is used as the cache
            key, so ``"cpu"`` and ``torch.device("cpu")`` share one entry.

    Returns:
        The cached ``torch.int8`` tensor living on ``device``. The same tensor
        object is returned on every call for a given device; callers must not
        modify it in place.

    Note:
        The first call for a device still builds the tensor from a Python
        list. If that must not happen inside cudagraph capture, call this
        once for the device before capture starts.
    """
    key = torch.device(device)

    # Fast path: devices that are already cached skip the lock entirely.
    lut = _NVFP4_STEP_LUT_CACHE.get(key)
    if lut is not None:
        return lut

    with _NVFP4_STEP_LUT_LOCK:
        # Re-check under the lock: another thread may have populated the
        # entry between the lookup above and acquiring the lock.
        lut = _NVFP4_STEP_LUT_CACHE.get(key)
        if lut is None:
            lut = torch.as_tensor(
                [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
                dtype=torch.int8,
                device=key,
            )
            _NVFP4_STEP_LUT_CACHE[key] = lut
    return lut
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
@@ -85,7 +104,7 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
@@ -141,7 +160,7 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
Reference in New Issue
Block a user