The checkpoint's input_scale was calibrated for training-time FP8 quantization, not for NVFP4 activation quantization. Using it as the activation global scale (gsa) makes the per-block scale max(|x_block|) / (6.0 · gsa) exceed the E4M3 maximum (448), and the resulting saturated block scales cause a systematic magnitude loss in every projection. The error accumulates over 61 layers, compressing the logit range and producing garbage tokens. Fix: compute gsa at runtime from the actual activation magnitude, gsa = max(|x|) / (6.0 * 448.0). This guarantees |x| / gsa ≤ 6.0 × 448 = 2688, the largest magnitude representable as an E2M1 element (≤ 6.0) times an E4M3 block scale (≤ 448). Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate.
205 lines
7.7 KiB
Python
205 lines
7.7 KiB
Python
"""CuTeDSL NVFP4 Linear (single GEMM)
|
|
|
|
Generic NVFP4 GEMM runner for attention projections and any single
|
|
linear layer. Uses ScaledGroupedGemmKernel with num_groups=1.
|
|
|
|
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
|
"""
|
|
|
|
import torch
|
|
|
|
from dsv4.ops.quantize import (
|
|
quantize_activation_nvfp4,
|
|
quantize_to_nvfp4,
|
|
)
|
|
from dsv4.ops.layouts import (
|
|
make_b_k_major,
|
|
)
|
|
from dsv4.ops.gemm_runner import (
|
|
run_nvfp4_grouped_gemm,
|
|
)
|
|
from dsv4.kernels.gemm.grouped import (
|
|
ceil_div as cutedsl_ceil_div,
|
|
pad_and_swizzle_single,
|
|
)
|
|
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
|
|
|
|
|
class Nvfp4Linear:
    """Single NVFP4 GEMM using CuTeDSL (num_groups=1).

    Handles any (K, N) weight matrix in NVFP4 format.
    Simple: quantize activation -> GEMM -> BF16 output.
    No SiLU, no fusion, no routing.

    Usage: assign ``fp4`` / ``sf`` / ``gs`` (and optionally ``ws2``), call
    ``finalize_weights()`` (or let the first forward do it), then call the
    instance on an activation tensor whose first dimension is the token count
    (assumed (num_tokens, in_features) — the padded buffer is allocated as
    (rows, in_features // 2) packed FP4).

    The padded activation buffer and the one-element expert-offset and
    activation-global-scale buffers are allocated lazily and then reused.
    With a static activation global scale, the code in this class issues no
    CPU-GPU sync. With runtime activation scaling enabled
    (``use_runtime_gsa=True`` or ``self._use_runtime_gsa = True``), each
    forward reads the activation amax back to the host via ``.item()``. That
    path synchronizes and cannot be captured in a CUDA graph.
    """

    # Largest magnitude representable by an FP4 E2M1 element.
    _FP4_E2M1_MAX = 6.0
    # Largest finite FP8 E4M3 value, i.e. the largest per-block scale factor.
    _FP8_E4M3_MAX = 448.0
    # Lower bound on the measured activation amax, so an all-zero input
    # cannot produce a zero global scale.
    _AMAX_FLOOR = 1e-8
    # Activation rows (and activation scale rows) are padded to this multiple.
    _ROW_ALIGN = 128
    # Activation scale columns are padded to this multiple.
    _SF_COL_ALIGN = 4

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
        *,
        use_runtime_gsa: bool = False,
    ):
        """Create the layer; weights are attached afterwards.

        Args:
            in_features: K, the input feature dimension.
            out_features: N, the output feature dimension.
            max_num_tokens: stored on the instance; the activation buffer is
                sized from the actual token count of each call.
            device: device on which all internal buffers are allocated.
            use_runtime_gsa: if True, recompute the activation global scale
                from every input's amax (see ``_run_impl``). Equivalent to
                setting ``self._use_runtime_gsa = True`` after construction.
        """
        self.in_features = in_features
        self.out_features = out_features
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Raw weights (set after construction, then call finalize_weights).
        self.fp4 = None  # list of 1 tensor
        self.sf = None  # list of 1 tensor
        self.gs = None  # list of 1 float
        self.ws2 = None  # list of 1 tensor — weight_scale_2 (scalar, folded into global_scale_b)

        # Processed weights.
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Activation global scale. The default corresponds to an activation
        # amax of 1.0, i.e. 1 / (6 * 448).
        self._activation_global_scale = 1.0 / (self._FP4_E2M1_MAX * self._FP8_E4M3_MAX)
        # When True, _run_impl recomputes the activation global scale per call.
        self._use_runtime_gsa = use_runtime_gsa

        # Lazily allocated, reused buffers.
        self._padded_x_fp4_buf = None
        self._expert_offsets_buf = None
        self._gsa_buf = None
        self._buffers_allocated = False

    def finalize_weights(self):
        """Convert the raw checkpoint weights into the CuTeDSL GEMM layout.

        Consumes ``self.fp4``, ``self.sf``, ``self.gs`` and (optionally)
        ``self.ws2``, fills ``_mat_b`` / ``_scale_b`` / ``_gsb`` and then
        drops the raw references.

        Raises:
            RuntimeError: if ``fp4`` or ``gs`` was never assigned (or was
                already consumed by a previous call).
        """
        if self.fp4 is None or self.gs is None:
            raise RuntimeError(
                "Nvfp4Linear.finalize_weights(): fp4 and gs must be set before finalizing"
            )

        # Convert uint8 checkpoint weights to a float4_e2m1fn_x2 view.
        fp4_view = [
            w.view(torch.float4_e2m1fn_x2) if w.dtype == torch.uint8 else w
            for w in self.fp4
        ]
        # Checkpoint weight is (out_features//2, in_features//2) = (N_packed, K_packed).
        # make_b_k_major expects (E, K_packed, N_packed), so permute the last two dims.
        stacked = torch.stack(fp4_view).permute(0, 2, 1).contiguous()  # (1, K_packed, N_packed)
        self._mat_b = make_b_k_major(stacked)

        # Checkpoint scale is (N_packed, K_sf) — already in the right row order for
        # the kernel's swizzle. Use assemble_raw_scales_2d3d_3d_side (no transpose),
        # NOT assemble_scales_3d_side (which transposes K_sf <-> N).
        from dsv4.ops.layouts import assemble_raw_scales_2d3d_3d_side

        self._scale_b = assemble_raw_scales_2d3d_3d_side(self.sf)
        self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)

        # Fold weight_scale_2 into global_scale_b.
        #   Dequant formula: w = lut[w_packed] * weight_scale * weight_scale_2
        #   Production GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb)
        # so gsb = gs * weight_scale_2. The original note describes gs as the
        # checkpoint input_scale (gsb = input_scale * weight_scale_2).
        # NOTE(review): if gsa already carries the activation dequant scale
        # (runtime path: gsa = amax / 2688), also folding input_scale into gsb
        # looks like it would apply it twice — confirm against the kernel's
        # dequant convention and against what callers store in ``gs``.
        if self.ws2 is not None and len(self.ws2) > 0 and self.ws2[0] is not None:
            ws2_val = self.ws2[0].float().item()  # load-time sync only
            self._gsb = self._gsb * ws2_val

        # Free raw weights.
        self.fp4 = None
        self.sf = None
        self.gs = None
        self.ws2 = None

        # The GEMM kernel for this (K, N) shape is JIT-compiled lazily on the
        # first real forward rather than eagerly here.

    def _ensure_buffer_size(self, num_tokens: int):
        """Make sure the reusable buffers can hold ``num_tokens`` rows.

        The padded activation buffer only ever grows, in multiples of
        ``_ROW_ALIGN`` rows. The one-element expert-offset and global-scale
        buffers are allocated once and kept, so their addresses stay stable.
        """
        needed_rows = cutedsl_ceil_div(num_tokens, self._ROW_ALIGN) * self._ROW_ALIGN

        if self._expert_offsets_buf is None:
            self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        if self._gsa_buf is None:
            self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)

        current = self._padded_x_fp4_buf
        if current is not None and current.shape[0] >= needed_rows:
            return  # Already big enough.

        self._padded_x_fp4_buf = torch.zeros(
            needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Run ``finalize_weights`` once, on first use."""
        if self._mat_b is None:
            self.finalize_weights()

    def _assemble_scales_single_group(self, x_sf):
        """Pad and swizzle the activation block scales for num_groups=1.

        Args:
            x_sf: (num_rows, num_cols) block scales from the activation
                quantizer; copied into an E4M3 buffer.

        Returns:
            The swizzled scales reshaped to (padded_rows, padded_cols). Rows
            are padded to a multiple of ``_ROW_ALIGN`` and columns to a
            multiple of ``_SF_COL_ALIGN``, with zeros in the padding.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, self._ROW_ALIGN) * self._ROW_ALIGN
        padded_cols = cutedsl_ceil_div(num_cols, self._SF_COL_ALIGN) * self._SF_COL_ALIGN

        # Zero-filled E4M3 staging buffer (built in float16, then cast).
        buf = torch.zeros(
            padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device
        ).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, hidden_states_sample):
        """Set the static activation global scale from a sample activation.

        Uses the global scale returned by ``quantize_to_nvfp4`` on the sample.
        If runtime activation scaling is enabled, ``_run_impl`` overwrites the
        value on every forward.

        NOTE(review): the value must follow the same convention as the
        runtime path (gsa = amax / (6 * 448)); confirm that
        ``quantize_to_nvfp4`` returns it in that form.
        """
        self._ensure_initialized()
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(hidden_states_sample)
        self._activation_global_scale = gs

    @classmethod
    def _runtime_activation_global_scale(cls, hidden_states: torch.Tensor) -> float:
        """Return gsa = max(|x|) / (6 * 448) for a non-empty activation tensor.

        Intended so that, assuming the quantizer forms each block scale as
        max(|x_block|) / (6 * gsa), no block scale exceeds the E4M3 maximum
        (448). The amax is floored at ``_AMAX_FLOOR``. The reduction runs in
        the input dtype and only the scalar is cast to float32 (same value as
        upcasting first, without a full-size fp32 copy). ``.item()`` forces a
        host sync.
        """
        amax = (
            hidden_states.detach().abs().max().float().clamp(min=cls._AMAX_FLOOR).item()
        )
        return amax / (cls._FP4_E2M1_MAX * cls._FP8_E4M3_MAX)

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 input -> NVFP4 GEMM -> BF16 output.

        Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile
        treats this as an opaque op. The custom op calls _run_impl internally.
        """
        if not hasattr(self, '_runner_id'):
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            hidden_states, self._runner_id, self.out_features,
        )

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Quantize, run the GEMM and return the first ``num_tokens`` output rows.

        Invoked through the ``nvfp4_linear_gemm`` custom op registered in
        ``run``, so torch.compile treats it as opaque. A zero-token input
        returns an empty (0, out_features) BF16 tensor without quantizing or
        launching the GEMM.
        """
        self._ensure_initialized()

        num_tokens = hidden_states.shape[0]
        if num_tokens == 0:
            # Nothing to do. Also avoids torch.max() on an empty tensor, which
            # raises in the runtime-gsa path.
            return torch.empty(
                0, self.out_features, dtype=torch.bfloat16, device=hidden_states.device
            )

        padded_rows = cutedsl_ceil_div(num_tokens, self._ROW_ALIGN) * self._ROW_ALIGN
        self._ensure_buffer_size(num_tokens)

        # Optionally derive the activation global scale from this input, to
        # avoid E4M3 block-scale overflow when a static scale is too small
        # for the actual activation magnitudes. Host sync: see class docstring.
        if getattr(self, '_use_runtime_gsa', False):
            self._activation_global_scale = self._runtime_activation_global_scale(hidden_states)

        x_fp4, x_sf = quantize_activation_nvfp4(
            hidden_states, self._activation_global_scale
        )

        # Copy the quantized rows into the padded buffer and zero the tail.
        # Rows [:x_rows] are fully overwritten, so only the tail is cleared.
        # The final buffer state is the same as zeroing everything first.
        padded_x_fp4 = self._padded_x_fp4_buf
        buf_u8 = padded_x_fp4.view(torch.uint8)
        x_rows = x_fp4.shape[0]
        buf_u8[:x_rows] = x_fp4.view(torch.uint8)
        buf_u8[x_rows:].zero_()

        scale_a = self._assemble_scales_single_group(x_sf)

        # Expert offsets: [padded_rows] for the single group.
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)

        gsa = self._gsa_buf.fill_(self._activation_global_scale)

        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )
        return out[:num_tokens]

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Alias for ``run``."""
        return self.run(hidden_states)
|