diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 8a58edb3..f8bc021d 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -13,6 +13,25 @@ no dynamic tensor allocation in the forward path, no Python control flow on GPU data. """ import math +import threading + +# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe) +_NVFP4_STEP_LUT_CACHE = {} +_NVFP4_STEP_LUT_LOCK = threading.Lock() + + +def _get_step_to_idx_lut(device): + """Get or create the E2M1 step-to-index LUT for the given device. + + Cached per device to avoid CPU→CUDA copies during cudagraph capture. + """ + with _NVFP4_STEP_LUT_LOCK: + if device not in _NVFP4_STEP_LUT_CACHE: + _NVFP4_STEP_LUT_CACHE[device] = torch.as_tensor( + [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7], + dtype=torch.int8, device=device, + ) + return _NVFP4_STEP_LUT_CACHE[device] import torch import cutlass import cutlass.cute as cute @@ -85,7 +104,7 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): abs_scaled = x_scaled.abs().clamp(max=6.0) half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) - step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8) + step_to_idx = _get_step_to_idx_lut(x_bf16.device) indices = step_to_idx[half_steps.long()] nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) @@ -141,7 +160,7 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): abs_scaled = x_scaled.abs().clamp(max=6.0) half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) - step_to_idx = x_bf16.new_tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], dtype=torch.int8) + step_to_idx = _get_step_to_idx_lut(x_bf16.device) indices = step_to_idx[half_steps.long()] nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)