P4: Add QuantizedActivation + Nvfp4Linear.run_from_quantized
- QuantizedActivation: carries (x_fp4, x_sf, gsa) for the skip-quantize path
- Nvfp4Linear.run_from_quantized(): runs the GEMM with pre-quantized input
- Enables the fused RMSNorm+quantize kernel to feed directly into all downstream linears (q_a, kv, o_proj, etc.) without re-quantizing
This commit is contained in:
@@ -213,5 +213,55 @@ class Nvfp4Linear:
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
def run_from_quantized(self, quant: 'QuantizedActivation') -> torch.Tensor:
    """Run the GEMM on a pre-quantized activation (skip the quantize step).

    Used when the input has already been quantized by a fused
    RMSNorm+quantize kernel, so this linear does not quantize again.
    (The original author notes this saves 2 kernel launches per call.)

    Args:
        quant: QuantizedActivation carrying ``x_fp4``, ``x_sf`` and ``gsa``.

    Returns:
        The GEMM output, truncated to the first ``quant.num_tokens`` rows.

    Raises:
        TypeError: If ``quant`` is not a QuantizedActivation.
    """
    # Function-scope import kept from the original; presumably it avoids an
    # import cycle between the linear and quantize modules — TODO confirm.
    from dsv4.ops.quantize import QuantizedActivation

    # Explicit raise instead of `assert`: asserts are stripped under `-O`,
    # which would let a wrong argument type reach the GEMM unchecked.
    if not isinstance(quant, QuantizedActivation):
        raise TypeError(
            "run_from_quantized expects a QuantizedActivation, got "
            f"{type(quant).__name__}"
        )

    self._ensure_initialized()
    num_tokens = quant.num_tokens
    # Row count rounded up to the next multiple of 128.
    padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
    self._ensure_buffer_size(num_tokens)

    # Copy the pre-quantized x_fp4 into the zeroed padded buffer. Both sides
    # are reinterpreted as raw bytes so the copy is a plain byte move.
    padded_x_fp4 = self._padded_x_fp4_buf
    padded_bytes = padded_x_fp4.view(torch.uint8)
    padded_bytes.zero_()
    padded_bytes[:quant.x_fp4.shape[0]] = quant.x_fp4.view(torch.uint8)

    # Assemble A-side scales from the pre-quantized block scales.
    scale_a = self._assemble_scales_single_group(quant.x_sf)

    # Expert offsets: every entry is set to the padded row count.
    expert_offsets = self._expert_offsets_buf
    expert_offsets.fill_(padded_rows)

    # Global scales: a single-element gsa is passed as shape (1,); otherwise
    # the first num_tokens entries are used.
    if quant.gsa.shape[0] == 1:
        gsa = quant.gsa[:1].reshape(1)
    else:
        gsa = quant.gsa[:num_tokens]

    if gsa.shape != self._gsa_buf.shape:
        # clone() rather than contiguous(): contiguous() returns the very same
        # tensor when the input is already contiguous, which would make the
        # buffer share storage with the caller's gsa. A later copy_() into the
        # buffer would then silently overwrite the caller's activation.
        self._gsa_buf = gsa.clone()
    else:
        self._gsa_buf.copy_(gsa)

    # Run GEMM
    out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4,
        mat_b=self._mat_b,
        scale_a=scale_a,
        scale_b=self._scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=self._gsa_buf,
        global_scale_b=self._gsb,
    )

    # Drop the padding rows before returning.
    return out[:num_tokens]
|
||||
|
||||
def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Make the layer callable; forwards ``hidden_states`` to ``self.run``."""
    return self.run(hidden_states)
|
||||
|
||||
@@ -351,6 +351,26 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
|
||||
return mod.quantize_nvfp4(x_bf16, global_scale)
|
||||
|
||||
|
||||
class QuantizedActivation:
    """Pre-quantized NVFP4 activation tensor.

    Carries the FP4 data, block scales, and per-row global scale
    so downstream Nvfp4Linear calls can skip quantization and go
    straight to GEMM.

    Created by rmsnorm_quantize_nvfp4() or quantize_nvfp4_gpu_fused().
    Consumed by Nvfp4Linear.run_from_quantized().
    """

    __slots__ = ['x_fp4', 'x_sf', 'gsa', 'inv_rms', 'num_tokens']

    def __init__(self, x_fp4, x_sf, gsa, inv_rms=None):
        """Store the quantized tensors and record the token (row) count.

        Args:
            x_fp4: Packed FP4 data; its leading dimension is the token count.
            x_sf: Block scales for ``x_fp4``.
            gsa: Global scale(s) for the activation.
            inv_rms: Optional inverse-RMS values; defaults to None.
        """
        self.x_fp4 = x_fp4      # (M, N//2) FP4
        self.x_sf = x_sf        # (M, N//16) E4M3
        self.gsa = gsa          # (M,) FP32
        self.inv_rms = inv_rms  # (M,) FP32, optional
        # Cached so consumers do not need to re-derive it from x_fp4.
        self.num_tokens = x_fp4.shape[0]

    def __repr__(self):
        """Return a compact summary showing shapes only, never tensor data."""
        def _shape(value):
            shape = getattr(value, "shape", None)
            return None if shape is None else tuple(shape)

        return (
            f"{type(self).__name__}(num_tokens={self.num_tokens}, "
            f"x_fp4={_shape(self.x_fp4)}, x_sf={_shape(self.x_sf)}, "
            f"gsa={_shape(self.gsa)}, inv_rms={_shape(self.inv_rms)})"
        )
|
||||
|
||||
|
||||
def dequantize_nvfp4(x_fp4, x_sf, gsa, shape=None):
|
||||
"""Dequantize NVFP4 → BF16 using the CUDA dequant kernel.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user