P4: Integrate fused RMSNorm+quantize into single_shot (attention path)
- forward_layer: use rmsnorm_quantize_nvfp4 for attn_norm - forward_attention: accept x_quant, use run_from_quantized for q_a/kv - Dequantize for compressor/indexer (still saves 2+ launches per site) - FFN path kept unfused — MoE internal quantization needs refactoring (P5) - _use_fused_rmsnorm_quantize flag to toggle (default True)
This commit is contained in:
@@ -651,6 +651,7 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
# =====================================================================
|
||||
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
x_quant=None,
|
||||
_profile_detail=False, _profile_times=None):
|
||||
dev = x_normed.device; T = x_normed.shape[0]
|
||||
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
|
||||
@@ -668,7 +669,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
|
||||
_pt('q_a_start')
|
||||
# 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm
|
||||
q_a = prod_lin['q_a'](x_normed)
|
||||
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
|
||||
_pt('q_a_end')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
# Compare q_a with PyTorch reference
|
||||
@@ -686,7 +687,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
|
||||
# 2. KV (NVFP4 GEMM, MQA, single KV head)
|
||||
_pt('kv_start')
|
||||
kv = prod_lin['kv'](x_normed)
|
||||
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
|
||||
_pt('kv_end')
|
||||
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
@@ -850,15 +851,29 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, token_id,
|
||||
compressor=None, indexer=None,
|
||||
moe_runner=None, se_runner=None, router=None,
|
||||
prod_lin=None, _profile_detail=False, _profile_times=None):
|
||||
x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w)
|
||||
prod_lin=None, _profile_detail=False, _profile_times=None,
|
||||
_use_fused_rmsnorm_quantize=True):
|
||||
# P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token
|
||||
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4
|
||||
x_in, ctx_a = attn_mhc.pre_block(X_l)
|
||||
if _use_fused_rmsnorm_quantize:
|
||||
x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
|
||||
# Dequantize for compressor/indexer (1 kernel launch)
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
else:
|
||||
x_normed = rmsnorm(x_in, attn_norm_w)
|
||||
x_quant_attn = None
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter()
|
||||
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
x_quant=x_quant_attn,
|
||||
_profile_detail=_profile_detail, _profile_times=_profile_times)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
|
||||
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
|
||||
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w)
|
||||
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid)
|
||||
# FFN path: keep unfused rmsnorm for now — moe_forward internally quantizes,
|
||||
# and passing quantized activation through MoE requires deeper refactoring (P5)
|
||||
x_ffn = rmsnorm(x_in_f, ffn_norm_w)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter()
|
||||
F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user