fix: staging kernel outputs unpacked E2M1 (1 byte/element, not packed 2/byte)

Matches the SMEM layout: float_e2m1_unpacksmem_t is 1 byte/element.
L1→L2 handoff uses unpacked format (same byte count as FP8).
No bandwidth savings at L1→L2 for v1 — can optimize later.
This commit is contained in:
2026-05-11 21:29:33 +00:00
parent 01cfd02759
commit c85b84b0fe
2 changed files with 12 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ RUN apt-get update && apt-get install -y git screen cmake && rm -rf /var/lib/apt
# Clone and build DeepGEMM with NVFP4 mega_moe kernel
# CACHE_BUSTER: increment to force fresh clone
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && CACHE_BUSTER=30
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && CACHE_BUSTER=31
# Build DeepGEMM with proper CUDA/NVRTC paths
ENV CPATH="/usr/local/lib/python3.12/dist-packages/flashinfer/data/cutlass/include:/usr/local/lib/python3.12/dist-packages/nvidia/cu13/include:/usr/local/cuda-13.0/include:${CPATH}"
@@ -19,7 +19,7 @@ RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.1
RUN cd /root/DeepGEMM && python3 setup.py build_ext --inplace
# Bust cache for patch changes — ARG before COPY ensures layer invalidation
ARG PATCH_CACHE_BUSTER=30
ARG PATCH_CACHE_BUSTER=31
# Copy our DeepSeek V4 patch over vLLM's model file
COPY patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py

View File

@@ -12,7 +12,7 @@ import torch
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
x_fp4, # uint8, shape (M, K) — unpacked E2M1, 1 byte per element
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
topk_ids,
topk_weights,
@@ -84,7 +84,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
# ---- E2M1 FP4 quantization ----
# ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# Nearest-neighbor using thresholds (midpoints between consecutive values)
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
@@ -105,19 +105,16 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2])
even = e2m1_pairs[:, 0].to(tl.uint8)
odd = e2m1_pairs[:, 1].to(tl.uint8)
packed_byte = (odd << 4) | even
# Unpacked: 1 byte per E2M1 element (same byte count as FP8)
# Each E2M1 value stored as uint8 (4-bit value in low nibble)
e2m1_flat = tl.reshape(e2m1_4bit, [BLOCK_K]).to(tl.uint8)
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
packed_k_mask = packed_k_offsets < (hidden_size // 2)
k_offsets_out = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
k_mask_out = k_offsets_out < hidden_size
tl.store(
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
packed_byte,
mask=packed_k_mask,
x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
e2m1_flat,
mask=k_mask_out,
)
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)