P4: Add QuantizedActivation + Nvfp4Linear.run_from_quantized

- QuantizedActivation: carries (x_fp4, x_sf, gsa) for skip-quantize path
- Nvfp4Linear.run_from_quantized(): runs GEMM with pre-quantized input
- Enables fused RMSNorm+quantize to feed directly into all downstream
  linears (q_a, kv, o_proj, etc.) without re-quantizing
This commit is contained in:
2026-06-02 16:37:38 +00:00
parent 149ecefb56
commit 0d1cd1e216
2 changed files with 70 additions and 0 deletions

View File

@@ -213,5 +213,55 @@ class Nvfp4Linear:
return out[:num_tokens]
def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor:
"""Run GEMM with pre-quantized activation (skip quantize step).
Used when the input has already been quantized by a fused
RMSNorm+quantize kernel. Saves 2 kernel launches per call.
Args:
quant: QuantizedActivation with x_fp4, x_sf, gsa
"""
from dsv4.ops.quantize import QuantizedActivation
assert isinstance(quant, QuantizedActivation)
self._ensure_initialized()
num_tokens = quant.num_tokens
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
self._ensure_buffer_size(num_tokens)
# Scatter pre-quantized x_fp4 into padded buffer
padded_x_fp4 = self._padded_x_fp4_buf
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8)
# Assemble A-side scales from pre-quantized sf
scale_a = self._assemble_scales_single_group(quant.x_sf)
# Expert offsets
expert_offsets = self._expert_offsets_buf
expert_offsets.fill_(padded_rows)
# Global scales — use the per-row gsa from the fused kernel
# Reshape to (1,) if scalar, or use per-row (M,) broadcast
gsa = quant.gsa[:1].reshape(1) if quant.gsa.shape[0] == 1 else quant.gsa[:num_tokens]
if gsa.shape != self._gsa_buf.shape:
self._gsa_buf = gsa.contiguous()
else:
self._gsa_buf.copy_(gsa)
# Run GEMM
out = run_nvfp4_grouped_gemm(
mat_a=padded_x_fp4,
mat_b=self._mat_b,
scale_a=scale_a,
scale_b=self._scale_b,
expert_offsets=expert_offsets,
global_scale_a=self._gsa_buf,
global_scale_b=self._gsb,
)
return out[:num_tokens]
def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.run(hidden_states)

View File

@@ -351,6 +351,26 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
return mod.quantize_nvfp4(x_bf16, global_scale)
class QuantizedActivation:
"""Pre-quantized NVFP4 activation tensor.
Carries the FP4 data, block scales, and per-row global scale
so downstream Nvfp4Linear calls can skip quantization and go
straight to GEMM.
Created by rmsnorm_quantize_nvfp4() or quantize_nvfp4_gpu_fused().
Consumed by Nvfp4Linear.run_from_quantized().
"""
__slots__ = ['x_fp4', 'x_sf', 'gsa', 'inv_rms', 'num_tokens']
def __init__(self, x_fp4, x_sf, gsa, inv_rms=None):
self.x_fp4 = x_fp4 # (M, N//2) FP4
self.x_sf = x_sf # (M, N//16) E4M3
self.gsa = gsa # (M,) FP32
self.inv_rms = inv_rms # (M,) FP32, optional
self.num_tokens = x_fp4.shape[0]
def dequantize_nvfp4(x_fp4, x_sf, gsa, shape=None):
"""Dequantize NVFP4 → BF16 using the CUDA dequant kernel.