diff --git a/single_shot_inference.py b/single_shot_inference.py index a587a871..68e9441d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -651,6 +651,7 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w # ===================================================================== def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer, prod_lin, + x_quant=None, _profile_detail=False, _profile_times=None): dev = x_normed.device; T = x_normed.shape[0] n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64) @@ -668,7 +669,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, _pt('q_a_start') # 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm - q_a = prod_lin['q_a'](x_normed) + q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed) _pt('q_a_end') if VERBOSE >= 2 and li < 3: # Compare q_a with PyTorch reference @@ -686,7 +687,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 2. KV (NVFP4 GEMM, MQA, single KV head) _pt('kv_start') - kv = prod_lin['kv'](x_normed) + kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed) _pt('kv_end') kv_norm_w = w.get(f"{pfx}.kv_norm.weight") if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) @@ -850,15 +851,29 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, token_id, compressor=None, indexer=None, moe_runner=None, se_runner=None, router=None, - prod_lin=None, _profile_detail=False, _profile_times=None): - x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w) + prod_lin=None, _profile_detail=False, _profile_times=None, + _use_fused_rmsnorm_quantize=True): + # P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token + from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4 + x_in, ctx_a = attn_mhc.pre_block(X_l) + if _use_fused_rmsnorm_quantize: + x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32)) + # Dequantize for compressor/indexer (1 kernel launch) + x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) + else: + x_normed = rmsnorm(x_in, attn_norm_w) + x_quant_attn = None if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter() F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer, prod_lin, + x_quant=x_quant_attn, _profile_detail=_profile_detail, _profile_times=_profile_times) if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter() X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) - x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w) + x_in_f, ctx_f = ffn_mhc.pre_block(X_mid) + # FFN path: keep unfused rmsnorm for now — moe_forward internally quantizes, + # and passing quantized activation through MoE requires deeper refactoring (P5) + x_ffn = rmsnorm(x_in_f, ffn_norm_w) if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter() F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id) if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter()