"""NVFP4 staging kernel: full FP4 (E2M1) activations + UE4M3 block-16 scales.

The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This module quantizes BF16 activations into packed E2M1 bytes (two values per
byte) with one UE4M3 scale per 16-element group, and copies the top-k routing
ids/weights into the output buffers.
"""

import triton
import triton.language as tl
import torch


@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) - E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) - UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # elements loaded from hidden per program (128)
    GROUP_K: tl.constexpr,  # NVFP4 scale group size (16)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one BLOCK_K slice of one token to NVFP4 and stage routing data.

    Program (token_id, k_block_id) = (program_id(0), program_id(1)) loads
    BLOCK_K elements of row ``token_id`` and splits them into BLOCK_K // GROUP_K
    groups. For each group it:

    * computes a UE4M3 scale ~= amax / 6 (6 is the largest E2M1 magnitude),
      rounded to nearest-even and saturated to the largest finite E4M3 (448);
    * divides the group by the *decoded* stored scale, rounds to E2M1 and
      packs pairs of 4-bit codes into bytes (even index -> low nibble);
    * packs four scale bytes per int32 word (group 0 in the low byte).

    Only programs with k_block_id == 0 copy ``topk_ids`` (cast to int64) and
    ``topk_weights`` into ``topk_idx_out`` / ``topk_weights_out``.
    """
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8 for 128 / 16
    sf_words: tl.constexpr = num_groups // 4  # int32 scale words per k-block

    # ---- load one K block and compute per-group amax ----
    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    amax = tl.max(tl.abs(hidden_groups), axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # ---- FP32 -> UE4M3 scale encoding (round-to-nearest-even) ----
    scale = amax / 6.0
    scale_bits = scale.to(tl.uint32, bitcast=True)
    # Switch to signed int32 before any subtraction: the E4M3 exponent is
    # negative for subnormal scales, and uint32 arithmetic would wrap it to a
    # huge positive value and send small groups down the normal path.
    scale_exp = ((scale_bits >> 23) & 0xFF).to(tl.int32)
    scale_mant = (scale_bits & 0x7FFFFF).to(tl.int32)
    # FP32 bias is 127 and E4M3 bias is 7, so the E4M3 exponent is exp - 120.
    e4m3_exp_raw = scale_exp - 120

    # Normal path (e4m3_exp_raw >= 1): keep the top 3 mantissa bits. The guard
    # is bit 19, sticky is OR(bits 18..0), and ties go to the even result.
    normal_mant = scale_mant >> 20
    guard_bit = (scale_mant >> 19) & 1
    sticky_bit = ((scale_mant & 0x7FFFF) != 0).to(tl.int32)
    normal_mant = normal_mant + (guard_bit & (sticky_bit | (normal_mant & 1)))

    # Subnormal path (e4m3_exp_raw <= 0): value = (m / 8) * 2^-6, so
    # m = (0x800000 | mant) >> (20 + shift) with shift = 1 - e4m3_exp_raw.
    # For shift >= 5 every bit (guard included) is shifted out, so clamping to
    # [1, 11] leaves the result unchanged while keeping all shift amounts in
    # 0..31 (normal lanes would otherwise produce negative shifts).
    sub_shift = tl.minimum(tl.maximum(1 - e4m3_exp_raw, 1), 11)
    mant_with_leading = scale_mant | 0x800000  # insert the implicit leading 1
    subnormal_mant = (mant_with_leading >> (sub_shift + 20)) & 0x7
    sub_guard_bit = (mant_with_leading >> (sub_shift + 19)) & 1
    sub_sticky_mask = (1 << (sub_shift + 19)) - 1  # at most 2^30 - 1
    sub_sticky_bit = ((mant_with_leading & sub_sticky_mask) != 0).to(tl.int32)
    subnormal_mant = subnormal_mant + (
        sub_guard_bit & (sub_sticky_bit | (subnormal_mant & 1))
    )

    is_normal = e4m3_exp_raw >= 1
    e4m3_exp = tl.where(is_normal, e4m3_exp_raw, 0)  # exp field 0 = subnormal
    e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
    # Add (not OR) the mantissa so a rounding carry (mant == 8) propagates into
    # the exponent. This covers both subnormal -> smallest normal and x.111 ->
    # next binade.
    scale_e4m3_bits = (e4m3_exp << 3) + e4m3_mant
    # Saturate to the largest finite E4M3 value, 0x7E = 448. 0x7F is NaN in
    # E4M3FN and exponents above 15 are unrepresentable. Positive E4M3 bit
    # patterns are monotonic in value, so an integer min clamps correctly.
    scale_e4m3_bits = tl.minimum(scale_e4m3_bits, 0x7E)

    # Decode the STORED bits so the E2M1 quantization divides by exactly the
    # value the consumer kernel will multiply back.
    stored_exp = scale_e4m3_bits >> 3
    stored_mant = (scale_e4m3_bits & 0x7).to(tl.float32)
    # 2^(stored_exp - 7) built directly as FP32 bits: biased exp = stored_exp + 120.
    two_pow_exp = ((stored_exp + 120) << 23).to(tl.float32, bitcast=True)
    normal_value = (1.0 + stored_mant / 8.0) * two_pow_exp
    subnormal_value = (stored_mant / 8.0) * 0.015625  # (m / 8) * 2^-6
    e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)

    # ---- E2M1 FP4 quantization ----
    # A zero stored scale dequantizes the whole group to zero; emit zero codes
    # instead of saturating every element against a tiny divisor.
    inv_scale = tl.where(e4m3_value > 0.0, 1.0 / e4m3_value, 0.0)
    scaled = hidden_groups * inv_scale[:, None]
    scaled = tl.minimum(tl.maximum(scaled, -6.0), 6.0)  # E2M1 range [-6, 6]
    abs_s = tl.abs(scaled)

    # E2M1 magnitudes by code: [0, 0.5, 1, 1.5, 2, 3, 4, 6]. Midpoints are
    # [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]. Ties round to the even code,
    # matching the RNE rule used for the scale: `<=` where the lower
    # neighbour's code is even, `<` where it is odd.
    e2m1_idx = tl.where(
        abs_s <= 0.25, 0,
        tl.where(
            abs_s < 0.75, 1,
            tl.where(
                abs_s <= 1.25, 2,
                tl.where(
                    abs_s < 1.75, 3,
                    tl.where(
                        abs_s <= 2.5, 4,
                        tl.where(abs_s < 3.5, 5, tl.where(abs_s <= 5.0, 6, 7)),
                    ),
                ),
            ),
        ),
    )
    sign_bit = (scaled < 0.0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit code: (sign << 3) | index

    # Pack pairs (2i, 2i+1) into one byte, even element in the low nibble.
    # Triton has no strided slicing, so pair up via reshape + split.
    e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
    e2m1_lo, e2m1_hi = tl.split(e2m1_pairs)
    e2m1_packed = ((e2m1_hi << 4) | e2m1_lo).to(tl.uint8)  # [BLOCK_K // 2]

    k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
    k_mask_out = k_offsets_out < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
        e2m1_packed,
        mask=k_mask_out,
    )

    # ---- pack UE4M3 bytes, 4 per int32 (group 0 in the low byte) ----
    # With BLOCK_K=128 and GROUP_K=16 this writes 2 words per k-block. All
    # shifts stay below 32.
    sf_bytes = tl.reshape(scale_e4m3_bits, [sf_words, 4])
    byte_shift = tl.arange(0, 4) * 8
    # Bytes occupy disjoint bit ranges, so the sum equals a bitwise OR.
    sf_packed = tl.sum(sf_bytes << byte_shift[None, :], axis=1).to(tl.int32)
    sf_offsets = k_block_id * sf_words + tl.arange(0, sf_words)
    tl.store(
        x_sf + token_id * x_sf_stride_m + sf_offsets * x_sf_stride_k,
        sf_packed,
    )

    # ---- routing data: copied once per token (first k-block only) ----
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k
        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out + token_id * topk_idx_stride_m + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )
        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )


def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp4: torch.Tensor,  # uint8, shape (M, K//2)
    x_sf: torch.Tensor,  # int32, shape (M, K//64)
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Launch the NVFP4 staging kernel over all tokens and K blocks.

    Args:
        hidden_states: (num_tokens, hidden_size) activations to quantize.
        topk_weights: (num_tokens, top_k) routing weights, copied as-is.
        topk_ids: (num_tokens, top_k) expert ids, copied and cast to int64.
        x_fp4: output for packed E2M1 codes, (num_tokens, hidden_size // 2).
        x_sf: output for packed UE4M3 scales, (num_tokens, hidden_size // 64).
        topk_idx_out: output for the copied expert ids.
        topk_weights_out: output for the copied routing weights.

    Raises:
        ValueError: if hidden_size is not a multiple of 128, or if
            topk_weights and topk_ids differ in shape.

    Does nothing when there are no tokens.
    """
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    block_k = 128
    if hidden_size % block_k != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )
    top_k = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )
    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    block_topk = triton.next_power_of_2(top_k)
    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
        hidden_states,
        x_fp4,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
        hidden_states.stride(0),
        hidden_states.stride(1),
        x_fp4.stride(0),
        x_fp4.stride(1),
        x_sf.stride(0),
        x_sf.stride(1),
        topk_ids.stride(0),
        topk_ids.stride(1),
        topk_weights.stride(0),
        topk_weights.stride(1),
        topk_idx_out.stride(0),
        topk_idx_out.stride(1),
        topk_weights_out.stride(0),
        topk_weights_out.stride(1),
        hidden_size,
        top_k,
        BLOCK_K=block_k,
        GROUP_K=16,  # NVFP4: group_size=16 (scale_vec::4X)
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )