diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index a7a1554b..30ba86f7 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -213,5 +213,55 @@ class Nvfp4Linear: return out[:num_tokens] + def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor: + """Run GEMM with pre-quantized activation (skip quantize step). + + Used when the input has already been quantized by a fused + RMSNorm+quantize kernel. Saves 2 kernel launches per call. + + Args: + quant: QuantizedActivation with x_fp4, x_sf, gsa + """ + from dsv4.ops.quantize import QuantizedActivation + assert isinstance(quant, QuantizedActivation) + + self._ensure_initialized() + num_tokens = quant.num_tokens + padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + self._ensure_buffer_size(num_tokens) + + # Scatter pre-quantized x_fp4 into padded buffer + padded_x_fp4 = self._padded_x_fp4_buf + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8) + + # Assemble A-side scales from pre-quantized sf + scale_a = self._assemble_scales_single_group(quant.x_sf) + + # Expert offsets + expert_offsets = self._expert_offsets_buf + expert_offsets.fill_(padded_rows) + + # Global scales — use the per-row gsa from the fused kernel + # Reshape to (1,) if scalar, or use per-row (M,) broadcast + gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens] + if gsa.shape != self._gsa_buf.shape: + self._gsa_buf = gsa.contiguous() + else: + self._gsa_buf.copy_(gsa) + + # Run GEMM + out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, + mat_b=self._mat_b, + scale_a=scale_a, + scale_b=self._scale_b, + expert_offsets=expert_offsets, + global_scale_a=self._gsa_buf, + global_scale_b=self._gsb, + ) + + return out[:num_tokens] + def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.run(hidden_states) diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index 8f797675..83074220 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -351,6 +351,26 @@ def quantize_nvfp4_gpu(x_bf16, global_scale): return mod.quantize_nvfp4(x_bf16, global_scale) +class QuantizedActivation: + """Pre-quantized NVFP4 activation tensor. + + Carries the FP4 data, block scales, and per-row global scale + so downstream Nvfp4Linear calls can skip quantization and go + straight to GEMM. + + Created by rmsnorm_quantize_nvfp4() or quantize_nvfp4_gpu_fused(). + Consumed by Nvfp4Linear.run_from_quantized(). + """ + __slots__ = ['x_fp4', 'x_sf', 'gsa', 'inv_rms', 'num_tokens'] + + def __init__(self, x_fp4, x_sf, gsa, inv_rms=None): + self.x_fp4 = x_fp4 # (M, N//2) FP4 + self.x_sf = x_sf # (M, N//16) E4M3 + self.gsa = gsa # (M,) FP32 + self.inv_rms = inv_rms # (M,) FP32, optional + self.num_tokens = x_fp4.shape[0] + + def dequantize_nvfp4(x_fp4, x_sf, gsa, shape=None): """Dequantize NVFP4 → BF16 using the CUDA dequant kernel.