"""NVFP4 staging kernel: E2M1 (FP4) activations with UE4M3 block-16 scales.

The ``mxf4nvf4`` PTX MMA instruction requires BOTH operands A and B to be FP4
(E2M1, two values packed per byte). This module quantizes BF16 activations into
packed E2M1 bytes plus one UE4M3 scale byte per group of 16 elements, and copies
the router top-k ids/weights into their staging buffers in the same launch.
"""

import triton
import triton.language as tl
import torch


# One program handles one token (program_id 0) and one BLOCK_K slice of its
# hidden dimension (program_id 1). The k_block_id == 0 program of each token
# additionally copies that token's top-k ids and weights.
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) -- E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) -- UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # elements loaded from hidden per program (128)
    GROUP_K: tl.constexpr,  # elements sharing one UE4M3 scale (16 for NVFP4)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    # Four scale bytes are packed per int32 word, so a k-block must cover a
    # whole number of 4-group words.
    tl.static_assert(BLOCK_K % (4 * GROUP_K) == 0)

    # 64-bit row index: token_id * stride can exceed int32 for large batches.
    token_id = tl.program_id(0).to(tl.int64)
    k_block_id = tl.program_id(1)

    # ---- Load one BLOCK_K slice of this token's activations ----
    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8 for 128 / 16
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    amax = tl.max(tl.abs(hidden_groups), axis=1)

    # ---- Per-group UE4M3 scale ----
    # The ideal scale maps the group amax onto the E2M1 maximum (6.0). Clamp it
    # into the positive finite UE4M3 range [2^-9, 448] so the encoded byte is
    # never zero (which would zero the whole group) nor NaN (0x7F in E4M3).
    # All-zero groups get the minimum scale and all-zero codes.
    scale = tl.minimum(tl.maximum(amax / 6.0, 0.001953125), 448.0)

    # Unbiased FP32 exponent of the (positive, normal) scale, in [-9, 8].
    # Work in signed int32 so subtracting the bias cannot wrap around.
    scale_exp = ((scale.to(tl.int32, bitcast=True) >> 23) & 0xFF) - 127
    # E4M3 values with exponent below -6 are subnormal and share the 2^-9
    # quantum, so clamp the effective exponent at the minimum normal exponent.
    eff_exp = tl.maximum(scale_exp, -6)
    # Powers of two built directly from exponent bits (exact, unlike exp2).
    inv_quantum = ((130 - eff_exp) << 23).to(tl.float32, bitcast=True)  # 2^(3 - eff_exp)
    quantum = ((124 + eff_exp) << 23).to(tl.float32, bitcast=True)  # 2^(eff_exp - 3)
    # Round the significand to 3 fraction bits (ties away from zero). Normal
    # scales give q in [8, 16]; subnormal scales give q in [1, 8].
    q = tl.floor(scale * inv_quantum + 0.5)
    # ((eff_exp + 7) << 3) + (q - 8) is the biased-exponent/mantissa byte for
    # normal values; a rounding carry (q == 16) rolls into the exponent field.
    # For subnormals (eff_exp == -6) it reduces to q itself, and q == 8 becomes
    # the smallest normal 0x08. The 448 clamp above keeps the result <= 0x7E.
    scale_e4m3_bits = ((eff_exp + 7) << 3) + q.to(tl.int32) - 8
    # Value the hardware decodes from that byte; quantize against it so stored
    # codes and scales are mutually consistent.
    scale_deq = q * quantum

    # ---- E2M1 quantization ----
    # Unsigned E2M1 magnitudes by 3-bit code: [0, 0.5, 1, 1.5, 2, 3, 4, 6].
    # Round to nearest via midpoints [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
    # (ties away from zero); magnitudes >= 5.0 saturate to code 7 (= 6.0).
    scaled = hidden_groups / scale_deq[:, None]
    abs_s = tl.abs(scaled)
    e2m1_idx = tl.where(abs_s < 0.25, 0,
               tl.where(abs_s < 0.75, 1,
               tl.where(abs_s < 1.25, 2,
               tl.where(abs_s < 1.75, 3,
               tl.where(abs_s < 2.5, 4,
               tl.where(abs_s < 3.5, 5,
               tl.where(abs_s < 5.0, 6, 7)))))))
    sign_bit = (scaled < 0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit code: (sign << 3) | index

    # ---- Pack two E2M1 codes per byte: even element -> low nibble ----
    # Triton tensors do not support strided slicing (x[0::2]), so pair up
    # neighbours with a reshape and merge them with shifts. The nibbles occupy
    # disjoint bits, so the sum equals a bitwise OR.
    code_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
    nibble_shift = tl.arange(0, 2) * 4
    e2m1_packed = tl.sum(code_pairs << nibble_shift[None, :], axis=1).to(tl.uint8)

    k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
    k_mask_out = k_offsets_out < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
        e2m1_packed,
        mask=k_mask_out,
    )

    # ---- Pack UE4M3 bytes, 4 per int32 ----
    # Group 4*i + j of this k-block goes into byte j (little end first) of word
    # i; with BLOCK_K=128 / GROUP_K=16 that is 2 int32 words per k-block.
    # Bytes are <= 0x7E, so the top byte never sets the int32 sign bit.
    words_per_block: tl.constexpr = num_groups // 4
    scale_quads = tl.reshape(scale_e4m3_bits, [words_per_block, 4])
    byte_shift = tl.arange(0, 4) * 8
    packed_sf = tl.sum(scale_quads << byte_shift[None, :], axis=1)
    sf_offsets = k_block_id * words_per_block + tl.arange(0, words_per_block)
    tl.store(x_sf + token_id * x_sf_stride_m + sf_offsets * x_sf_stride_k, packed_sf)

    # ---- Copy routing data once per token ----
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k
        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out + token_id * topk_idx_stride_m + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )
        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )


def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp4: torch.Tensor,  # uint8, shape (M, K//2)
    x_sf: torch.Tensor,  # int32, shape (M, K//64)
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Quantize ``hidden_states`` to NVFP4 and copy top-k routing data in one launch.

    Args:
        hidden_states: (M, K) activations; K must be a multiple of 128.
        topk_weights: (M, top_k) routing weights, copied into ``topk_weights_out``.
        topk_ids: (M, top_k) expert ids, copied (as int64 values) into
            ``topk_idx_out``.
        x_fp4: Output for packed E2M1 codes, two per byte (low nibble holds the
            even-indexed element).
        x_sf: Output for the UE4M3 scale of every 16-element group, four scale
            bytes per int32.
        topk_idx_out: Output for the top-k ids.
        topk_weights_out: Output for the top-k weights.

    Raises:
        ValueError: If K is not a multiple of 128, or ``topk_weights`` and
            ``topk_ids`` differ in shape.
    """
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    if hidden_size % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )
    top_k = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    block_k = 128
    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    block_topk = triton.next_power_of_2(top_k)
    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
        hidden_states,
        x_fp4,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
        hidden_states.stride(0),
        hidden_states.stride(1),
        x_fp4.stride(0),
        x_fp4.stride(1),
        x_sf.stride(0),
        x_sf.stride(1),
        topk_ids.stride(0),
        topk_ids.stride(1),
        topk_weights.stride(0),
        topk_weights.stride(1),
        topk_idx_out.stride(0),
        topk_idx_out.stride(1),
        topk_weights_out.stride(0),
        topk_weights_out.stride(1),
        hidden_size,
        top_k,
        BLOCK_K=block_k,
        GROUP_K=16,  # NVFP4: group_size=16 (scale_vec::4X)
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )