diff --git a/patches/staging_kernel.py b/patches/staging_kernel.py index d409d14..3bc0fb0 100644 --- a/patches/staging_kernel.py +++ b/patches/staging_kernel.py @@ -63,26 +63,64 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( scale_exp = (scale_bits >> 23) & 0xFF scale_mant = scale_bits & 0x7FFFFF - # Convert FP32 → E4M3 manually - e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7 - e4m3_exp = tl.maximum(e4m3_exp, 0) - e4m3_exp = tl.minimum(e4m3_exp, 15) - e4m3_mant = scale_mant >> 20 - round_bit = (scale_mant >> 19) & 1 - e4m3_mant = e4m3_mant + round_bit + # Convert FP32 → E4M3 manually (with subnormal support) + # FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120 + e4m3_exp_raw = scale_exp - 120 # can be negative → subnormal + + # Normal path: exp >= 1, just truncate mantissa top 3 bits + # RNE rounding: need guard (bit 19), sticky (OR of bits 18:0), and LSB of result + normal_mant = scale_mant >> 20 + guard_bit = (scale_mant >> 19) & 1 + sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0) # OR of bits [18:0] + result_lsb = normal_mant & 1 + # RNE: round up if (guard=1 and sticky=1) or (guard=1 and sticky=0 and lsb=1) + round_up = guard_bit & (sticky_bit | result_lsb) + normal_mant = normal_mant + round_up + normal_exp = e4m3_exp_raw + + # Subnormal path: exp_raw <= 0 + # Insert implicit leading 1 and right-shift by (1 - exp_raw) + # E4M3 subnormal: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6 + # So we need: (1 + mant_fp32/2^23) * 2^(exp_raw - 7) = (shifted_mant/8) * 2^-6 + # shifted_mant = (implicit_1 | mant_fp32) >> (1 - exp_raw - 1) then take top 3 bits + shift = 1 - e4m3_exp_raw # positive when subnormal + mant_with_leading = (0x800000 | scale_mant) # insert implicit 1 + # Right-shift to get into the 3-bit E4M3 mantissa window + # We want bits [shift+19 : shift+23) of mant_with_leading for 3 mantissa bits + 1 round bit + subnormal_mant = (mant_with_leading >> (shift.to(tl.int32) + 20)) & 0x7 + sub_guard_bit = (mant_with_leading >> (shift.to(tl.int32) + 19)) & 1 + # Sticky: OR of all bits below the guard bit in the shifted result + # shift ≤ 8 in practice (amax floor = 1e-4 → scale ≈ 2^-15 → exp_raw ≈ -7), so mask ≤ 2^27 + sub_sticky_mask = (1 << (shift.to(tl.int32) + 19)) - 1 + sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0) + sub_result_lsb = subnormal_mant & 1 + sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb) + subnormal_mant = subnormal_mant + sub_round_up + + is_normal = e4m3_exp_raw >= 1 + e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant) + e4m3_exp = tl.where(is_normal, normal_exp, 0) # exp=0 for subnormals + + # Handle mantissa overflow after rounding overflow = e4m3_mant >= 8 e4m3_mant = tl.where(overflow, 0, e4m3_mant) e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) + e4m3_exp = tl.maximum(e4m3_exp, 0) e4m3_exp = tl.minimum(e4m3_exp, 15) scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant - # Reconstruct dequantized scale for E2M1 quantization - e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126) + # Reconstruct dequantized scale by decoding the STORED E4M3 bits. + # This guarantees the E2M1 quantization divides by exactly the value + # the CUDA kernel will multiply back — same bits, single decode, no + # possibility of encode/decode disagreement. + stored_exp = (scale_e4m3_bits >> 3) & 0xF + stored_mant = scale_e4m3_bits & 0x7 + e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126) two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23 two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True) - normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp - subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 - e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value) + normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp + subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625 + e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value) # ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ---- # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6] diff --git a/patches/test_nvfp4_mega_moe.py b/patches/test_nvfp4_mega_moe.py deleted file mode 100644 index 5a0a1d4..0000000 --- a/patches/test_nvfp4_mega_moe.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data. - -Run inside the vllm container: - python3 /patches/test_nvfp4_mega_moe.py -""" -import torch -import torch.distributed as dist -import os -import sys - -def test_nvfp4_mega_moe(): - # Use dimensions that satisfy all alignment requirements: - # - hidden and intermediate_hidden must be multiples of 128 and 64 - # - block_m will be at least 32 (SMEM alignment: 32 * 64 = 2048 >= 1024) - num_experts = 2 - num_tokens = 32 # must be multiple of alignment - top_k = 2 - hidden = 512 # multiple of 128 and 64 - intermediate_hidden = 1024 # multiple of 128 and 64 - - device = "cuda" - torch.cuda.set_device(0) - - # Single-rank process group for SymmBuffer - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", "29501") - os.environ.setdefault("RANK", "0") - os.environ.setdefault("WORLD_SIZE", "1") - if not dist.is_initialized(): - dist.init_process_group("nccl") - group = dist.new_group() - - from deep_gemm.mega import ( - fp8_nvfp4_mega_moe, - get_symm_buffer_for_nvfp4_mega_moe, - transform_nvfp4_weights_for_mega_moe, - ) - - # --- Weights: random NVFP4 --- - w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2), - dtype=torch.uint8, device=device).view(torch.int8) - w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16, - device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn) - w13_weight_scale_2 = torch.ones(num_experts, device=device) # global scale = 1 for simplicity - w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2), - dtype=torch.uint8, device=device).view(torch.int8) - w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16, - device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn) - w2_weight_scale_2 = torch.ones(num_experts, device=device) - - print("Transforming weights...") - l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe( - (w13_weight, w13_weight_scale), - (w2_weight, w2_weight_scale), - l1_weight_scale_2=w13_weight_scale_2, - l2_weight_scale_2=w2_weight_scale_2, - ) - for name, t in [("l1_w", l1_weights[0]), ("l1_w_sf", l1_weights[1]), - ("l2_w", l2_weights[0]), ("l2_w_sf", l2_weights[1])]: - print(f" {name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()} contig={t.is_contiguous()}") - - # --- Symm buffer --- - print("Creating symm buffer...") - symm_buffer = get_symm_buffer_for_nvfp4_mega_moe( - group, num_experts, num_tokens, top_k, hidden, intermediate_hidden) - for name, t in [("x", symm_buffer.x), ("x_sf", symm_buffer.x_sf), - ("l1_acts", symm_buffer.l1_acts), ("l1_acts_sf", symm_buffer.l1_acts_sf), - ("l2_acts", symm_buffer.l2_acts), ("l2_acts_sf", symm_buffer.l2_acts_sf)]: - print(f" symm_{name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()}") - - # --- Stage inputs (BF16 hidden_states → FP4 packed + UE4M3 scales) --- - print("Staging inputs...") - hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.5 - topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1) - topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device) - - # Import the staging kernel directly (can't import from vllm model without full init) - import triton - import triton.language as tl - # Just manually pack a few tokens into FP4 for testing - # Write actual FP4 data to the activation buffer (random but valid packed E2M1) - symm_buffer.x[:num_tokens].copy_( - torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8)) - # Write valid UE4M3 scales (random but non-zero) - # x_sf shape is (tokens, hidden//64) as int32 — each int32 = 4 packed UE4M3 bytes - # Just fill with simple non-zero int32 values (the data doesn't need to be - # perfectly valid UE4M3 for a launch test, just non-garbage) - symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C) # repeating 0x3C = ~0.5 in E4M3 - # Write topk data directly - for i in range(num_tokens): - for j in range(top_k): - symm_buffer.topk_idx[i, j] = topk_ids[i, j].item() - symm_buffer.topk_weights[i, j] = topk_weights[i, j].item() - torch.cuda.synchronize() - print("Buffer populated with random FP4 data") - - # --- Run kernel --- - y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device) - print("Calling fp8_nvfp4_mega_moe...", flush=True) - import signal - timed_out = False - def handler(signum, frame): - nonlocal timed_out - timed_out = True - raise TimeoutError("Kernel timeout") - signal.signal(signal.SIGALRM, handler) - signal.alarm(15) # 15 second timeout - try: - fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer) - torch.cuda.synchronize() - signal.alarm(0) - print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}") - except TimeoutError: - print("TIMEOUT: kernel did not complete in 15s (GPU hang?)") - except Exception as e: - signal.alarm(0) - print(f"FAILED: {e}") - raise - -if __name__ == "__main__": - test_nvfp4_mega_moe()