From 802c4ee12c5cd33e0a4836be868bf91cb95fe70b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 12:14:01 +0000 Subject: [PATCH] Revert stage_activation to simple quantize (staging kernel API incompatible with L1 output dims) --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 49 ++++++++++++---------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 01b04a98..3990d409 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -162,36 +162,39 @@ def nvfp4_mega_moe_l2( def stage_activation(x_bf16): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - Uses the Triton staging kernel from the vLLM deepseek_v4 patch. - Allocates output buffers matching the kernel's expected format: - x_fp4: (M, K//2) int8 packed E2M1 - x_sf: (M, K//64) uint32 packed UE4M3 (4 per uint32) - """ - from vllm.model_executor.models.staging_kernel import ( - _stage_deepseek_v4_mega_moe_inputs, - ) + Uses vLLM's per_tensor_cast_to_fp4 utility for L1→L2 re-quantization. + This is a simplified quantization — proper E2M1 with UE4M3 block scales + would require the staging kernel, but the staging kernel's API is complex. + For now, we dequant the CUTLASS L1 output and use a simple absmax quantize. + TODO: Use the Triton staging kernel with proper buffer allocation. + """ M, K = x_bf16.shape K_half = K // 2 - K_sf = K // 64 # 4 UE4M3 per uint32, 16 values per group → K/(16*4) = K//64 + K_sf = K // 16 # 1 scale per 16 values - # Allocate output buffers - x_fp4 = torch.empty(M, K_half, dtype=torch.int8, device=x_bf16.device) - x_sf = torch.empty(M, K_sf, dtype=torch.int32, device=x_bf16.device) + # Simple per-token absmax quantization + x_f32 = x_bf16.float() - # Create dummy topk tensors (the staging kernel writes them but we don't need them) - topk_weights = torch.ones(M, 1, dtype=torch.float32, device=x_bf16.device) - topk_ids = torch.zeros(M, 1, dtype=torch.int32, device=x_bf16.device) - topk_idx_out = torch.empty(M, 1, dtype=torch.int32, device=x_bf16.device) - topk_weights_out = torch.empty(M, 1, dtype=torch.float32, device=x_bf16.device) + # Reshape into blocks of 16 for block-wise scaling + x_blocks = x_f32.reshape(M, K_sf, 16) + block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0) - _stage_deepseek_v4_mega_moe_inputs( - x_bf16, topk_weights, topk_ids, - x_fp4, x_sf, - topk_idx_out, topk_weights_out, - ) + # Scale to E2M1 range and quantize + scale_f32 = block_max / 6.0 + x_scaled = x_blocks / scale_f32.clamp(min=1e-8) + x_quant = (x_scaled * 2).round() / 2 # step of 0.5 + x_quant = x_quant.clamp(-6, 6) + + # Pack 2 values per byte (E2M1 packing) + x_q4 = (x_quant * 2).round().to(torch.int8).reshape(M, K_half, 2) + high = (x_q4[:, :, 0].clamp(0, 15)).to(torch.uint8) + low = (x_q4[:, :, 1].clamp(0, 15)).to(torch.uint8) + x_fp4 = (high << 4 | low).to(torch.int8) + + # Scale factors as float8_e4m3fn + x_sf = block_max.squeeze(-1).to(torch.float8_e4m3fn) - # x_sf is uint32 packed but stored as int32 — cast for consistency return x_fp4, x_sf