P4: Integrate fused RMSNorm+quantize into single_shot (attention path)

- forward_layer: use rmsnorm_quantize_nvfp4 for attn_norm
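- forward_attention: accept an optional x_quant; when it is provided, use run_from_quantized for the q_a/kv projections, otherwise fall back to the regular call on x_normed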
- Dequantize the fused output back into x_normed for the compressor/indexer (still saves 2+ launches per site)
- FFN path kept unfused — MoE internal quantization needs refactoring (P5)
- Add a _use_fused_rmsnorm_quantize flag to forward_layer to toggle the fused attention-norm path (default True)
This commit is contained in:
2026-06-02 16:38:44 +00:00
parent 0d1cd1e216
commit 7bb3207347

View File

@@ -651,6 +651,7 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
# =====================================================================
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer, prod_lin,
x_quant=None,
_profile_detail=False, _profile_times=None):
dev = x_normed.device; T = x_normed.shape[0]
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
@@ -668,7 +669,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
_pt('q_a_start')
# 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm
q_a = prod_lin['q_a'](x_normed)
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
_pt('q_a_end')
if VERBOSE >= 2 and li < 3:
# Compare q_a with PyTorch reference
@@ -686,7 +687,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 2. KV (NVFP4 GEMM, MQA, single KV head)
_pt('kv_start')
kv = prod_lin['kv'](x_normed)
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
_pt('kv_end')
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
@@ -850,15 +851,29 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, token_id,
compressor=None, indexer=None,
moe_runner=None, se_runner=None, router=None,
prod_lin=None, _profile_detail=False, _profile_times=None):
x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w)
prod_lin=None, _profile_detail=False, _profile_times=None,
_use_fused_rmsnorm_quantize=True):
# P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4
x_in, ctx_a = attn_mhc.pre_block(X_l)
if _use_fused_rmsnorm_quantize:
x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
# Dequantize for compressor/indexer (1 kernel launch)
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
else:
x_normed = rmsnorm(x_in, attn_norm_w)
x_quant_attn = None
if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter()
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer, prod_lin,
x_quant=x_quant_attn,
_profile_detail=_profile_detail, _profile_times=_profile_times)
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w)
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid)
# FFN path: keep unfused rmsnorm for now — moe_forward internally quantizes,
# and passing quantized activation through MoE requires deeper refactoring (P5)
x_ffn = rmsnorm(x_in_f, ffn_norm_w)
if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter()
F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id)
if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter()