271 lines
10 KiB
Python
271 lines
10 KiB
Python
"""
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 scales (one per 16-element block).

The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This module quantizes BF16 activations to packed E2M1 (two values per uint8)
with one UE4M3 scale per 16-element group, and copies the top-k ids and
weights into the MegaMoE staging buffers.
"""
import triton
|
|
import triton.language as tl
|
|
import torch
|
|
|
|
|
|
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # 128 elements (loaded from hidden)
    GROUP_K: tl.constexpr,  # 16 (NVFP4 group_size)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one (token, BLOCK_K) activation tile to NVFP4 and stage top-k routing.

    Program (token_id, k_block_id) handles elements
    [k_block_id * BLOCK_K, (k_block_id + 1) * BLOCK_K) of row ``token_id``:

    * splits them into BLOCK_K // GROUP_K groups and computes one UE4M3 scale
      per group (amax / 6, round-to-nearest-even, saturated to 448);
    * quantizes each group to E2M1 codes and writes two codes per uint8 into
      ``x_fp4`` (even element -> low nibble);
    * writes four scale bytes per int32 into ``x_sf`` (group g of the tile ->
      byte g % 4 of word g // 4, two words per 128-element tile).

    The ``k_block_id == 0`` program of each row also copies that row's
    ``top_k`` ids (cast to int64) and weights into ``topk_idx_out`` /
    ``topk_weights_out``. There is no per-tensor global scale: scales below
    2^-10 round to 0, so groups with amax below ~6 * 2^-10 quantize to zeros.
    """
    # int64 row index so `row * stride` cannot overflow int32 for large batches.
    token_id = tl.program_id(0).to(tl.int64)
    k_block_id = tl.program_id(1)

    # ---- Load one BLOCK_K slice of the activation row ----
    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    # Floor amax so an all-zero group still yields a finite, positive FP32 scale.
    amax = tl.maximum(tl.max(tl.abs(hidden_groups), axis=1), 1.0e-4)

    # ---- UE4M3 scale computation: scale = amax / 6 (E2M1 max value = 6) ----
    scale = amax / 6.0
    scale_bits = scale.to(tl.uint32, bitcast=True)
    # Decode the FP32 fields as *signed* int32: the re-biased exponent below is
    # negative for E4M3-subnormal scales. In uint32 it would wrap to a huge value
    # that the `>= 1` normal test would then accept, saturating tiny scales.
    scale_exp = ((scale_bits >> 23) & 0xFF).to(tl.int32)
    scale_mant = (scale_bits & 0x7FFFFF).to(tl.int32)
    # FP32 bias=127, E4M3 bias=7. A value <= 0 means the scale is E4M3-subnormal.
    e4m3_exp_raw = scale_exp - 120

    # Normal path: keep the top 3 mantissa bits with round-to-nearest-even,
    # using guard = bit 19 and sticky = OR of bits [18:0].
    normal_mant = scale_mant >> 20
    guard_bit = (scale_mant >> 19) & 1
    sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0)
    normal_mant = normal_mant + (guard_bit & (sticky_bit | (normal_mant & 1)))

    # Subnormal path: an E4M3 subnormal is (m / 8) * 2^-6. With the implicit
    # leading 1 restored, m = mant_with_leading >> (shift + 20), where
    # shift = 1 - exp_raw. Clamp shift to [1, 11]:
    #   * lanes on the normal path would otherwise get negative shift amounts;
    #   * for shift >= 5 the result is already 0 (mant_with_leading < 2^24).
    # The clamp therefore keeps every shift amount below 32 without changing
    # any selected subnormal result.
    shift = tl.minimum(tl.maximum(1 - e4m3_exp_raw, 1), 11)
    mant_with_leading = 0x800000 | scale_mant
    subnormal_mant = (mant_with_leading >> (shift + 20)) & 0x7
    sub_guard_bit = (mant_with_leading >> (shift + 19)) & 1
    sub_sticky_mask = (1 << (shift + 19)) - 1  # at most 2^30 - 1
    sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
    subnormal_mant = subnormal_mant + (
        sub_guard_bit & (sub_sticky_bit | (subnormal_mant & 1))
    )

    is_normal = e4m3_exp_raw >= 1
    e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
    e4m3_exp = tl.where(is_normal, e4m3_exp_raw, 0)  # exponent field 0 for subnormals

    # A rounding carry out of the 3-bit mantissa bumps the exponent. This also
    # promotes the largest subnormal to the smallest normal (2^-6).
    overflow = e4m3_mant >= 8
    e4m3_mant = tl.where(overflow, 0, e4m3_mant)
    e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)

    # Saturate to the largest finite E4M3 value, 448 (exp=15, mant=6 -> 0x7E).
    # exp=15 with mant=7 (0x7F) is the NaN encoding, and larger exponents do
    # not fit in 4 bits, so clamping only the exponent could emit NaN scales.
    saturate = (e4m3_exp > 15) | ((e4m3_exp == 15) & (e4m3_mant > 6))
    e4m3_exp = tl.where(saturate, 15, e4m3_exp)
    e4m3_mant = tl.where(saturate, 6, e4m3_mant)
    scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant  # int32 in [0x00, 0x7E]

    # Reconstruct the dequantized scale by decoding the STORED E4M3 bits, so the
    # E2M1 quantization divides by exactly the value a decoder of these bits
    # multiplies back.
    stored_exp = (scale_e4m3_bits >> 3) & 0xF
    stored_mant = scale_e4m3_bits & 0x7
    e4m3_exp_for_recon = tl.maximum(stored_exp - 7, -126)
    two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
    two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
    normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
    subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625  # 2^-6
    e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)

    # ---- E2M1 FP4 quantization ----
    # A group whose scale rounded to 0 dequantizes to 0 whatever its codes are,
    # so emit zero codes for it instead of dividing by an epsilon.
    inv_scale = tl.where(e4m3_value > 0.0, 1.0 / e4m3_value, 0.0)
    scaled = hidden_groups * inv_scale[:, None]
    scaled = tl.minimum(tl.maximum(scaled, -6.0), 6.0)  # E2M1 range [-6, 6]

    # E2M1 magnitudes (unsigned LUT): [0, 0.5, 1, 1.5, 2, 3, 4, 6].
    # Nearest neighbour via the midpoints [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0];
    # `<` means an exact midpoint rounds to the larger magnitude.
    abs_s = tl.abs(scaled)
    e2m1_idx = tl.where(
        abs_s < 0.25, 0,
        tl.where(
            abs_s < 0.75, 1,
            tl.where(
                abs_s < 1.25, 2,
                tl.where(
                    abs_s < 1.75, 3,
                    tl.where(
                        abs_s < 2.5, 4,
                        tl.where(abs_s < 3.5, 5, tl.where(abs_s < 5.0, 6, 7)),
                    ),
                ),
            ),
        ),
    )
    sign_bit = (scaled < 0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit code: (sign << 3) | index

    # Pack two E2M1 codes per byte, even element in the low nibble. Triton
    # has no strided slicing (x[0::2]), so view the codes as (BLOCK_K // 2, 2)
    # pairs and combine each pair by shift-and-sum (the nibbles never overlap).
    e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
    nibble_shift = tl.arange(0, 2) * 4
    e2m1_packed = tl.sum(e2m1_pairs << nibble_shift[None, :], axis=1).to(tl.uint8)

    k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
    k_mask_out = k_offsets_out < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
        e2m1_packed,
        mask=k_mask_out,
    )

    # ---- Pack UE4M3 scale bytes, four per int32 word ----
    # Group g of this tile goes to byte g % 4 of word g // 4. Shifts stay in
    # [0, 24], and each byte is <= 0x7E, so the int32 sum never touches the
    # sign bit.
    tl.static_assert(num_groups % 4 == 0, "BLOCK_K must be a multiple of 4 * GROUP_K")
    sf_words_per_block: tl.constexpr = num_groups // 4  # 2
    sf_bytes = tl.reshape(scale_e4m3_bits, [sf_words_per_block, 4])
    byte_shift = tl.arange(0, 4) * 8
    sf_words = tl.sum(sf_bytes << byte_shift[None, :], axis=1).to(tl.int32)
    sf_cols = k_block_id * sf_words_per_block + tl.arange(0, sf_words_per_block)
    tl.store(x_sf + token_id * x_sf_stride_m + sf_cols * x_sf_stride_k, sf_words)

    # ---- Top-k routing copy (once per token, by the first K block) ----
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
|
|
|
|
def _stage_deepseek_v4_mega_moe_inputs(
|
|
hidden_states: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
|
|
x_sf: torch.Tensor, # int32, shape (M, K//64)
|
|
topk_idx_out: torch.Tensor,
|
|
topk_weights_out: torch.Tensor,
|
|
) -> None:
|
|
num_tokens, hidden_size = hidden_states.shape
|
|
if num_tokens == 0:
|
|
return
|
|
if hidden_size % 128 != 0:
|
|
raise ValueError(
|
|
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
|
|
"a multiple of 128."
|
|
)
|
|
top_k = topk_ids.shape[1]
|
|
if topk_weights.shape != topk_ids.shape:
|
|
raise ValueError(
|
|
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
|
|
"topk_ids to have the same shape."
|
|
)
|
|
|
|
block_k = 128
|
|
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
|
|
block_topk = triton.next_power_of_2(top_k)
|
|
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
|
|
hidden_states,
|
|
x_fp4,
|
|
x_sf,
|
|
topk_ids,
|
|
topk_weights,
|
|
topk_idx_out,
|
|
topk_weights_out,
|
|
hidden_states.stride(0),
|
|
hidden_states.stride(1),
|
|
x_fp4.stride(0),
|
|
x_fp4.stride(1),
|
|
x_sf.stride(0),
|
|
x_sf.stride(1),
|
|
topk_ids.stride(0),
|
|
topk_ids.stride(1),
|
|
topk_weights.stride(0),
|
|
topk_weights.stride(1),
|
|
topk_idx_out.stride(0),
|
|
topk_idx_out.stride(1),
|
|
topk_weights_out.stride(0),
|
|
topk_weights_out.stride(1),
|
|
hidden_size,
|
|
top_k,
|
|
BLOCK_K=block_k,
|
|
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
|
|
BLOCK_TOPK=block_topk,
|
|
num_warps=4,
|
|
)
|