fix: staging kernel outputs unpacked E2M1 (1 byte/element, not packed 2/byte)
Matches the SMEM layout: float_e2m1_unpacksmem_t is 1 byte/element. L1→L2 handoff uses unpacked format (same byte count as FP8). No bandwidth savings at L1→L2 for v1 — can optimize later.
This commit is contained in:
@@ -8,7 +8,7 @@ RUN apt-get update && apt-get install -y git screen cmake && rm -rf /var/lib/apt
|
||||
|
||||
# Clone and build DeepGEMM with NVFP4 mega_moe kernel
|
||||
# CACHE_BUSTER: increment to force fresh clone
|
||||
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && CACHE_BUSTER=30
|
||||
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && CACHE_BUSTER=31
|
||||
|
||||
# Build DeepGEMM with proper CUDA/NVRTC paths
|
||||
ENV CPATH="/usr/local/lib/python3.12/dist-packages/flashinfer/data/cutlass/include:/usr/local/lib/python3.12/dist-packages/nvidia/cu13/include:/usr/local/cuda-13.0/include:${CPATH}"
|
||||
@@ -19,7 +19,7 @@ RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.1
|
||||
RUN cd /root/DeepGEMM && python3 setup.py build_ext --inplace
|
||||
|
||||
# Bust cache for patch changes — ARG before COPY ensures layer invalidation
|
||||
ARG PATCH_CACHE_BUSTER=30
|
||||
ARG PATCH_CACHE_BUSTER=31
|
||||
# Copy our DeepSeek V4 patch over vLLM's model file
|
||||
COPY patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import torch
|
||||
@triton.jit
|
||||
def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
hidden_states,
|
||||
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
|
||||
x_fp4, # uint8, shape (M, K) — unpacked E2M1, 1 byte per element
|
||||
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
@@ -84,7 +84,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
|
||||
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
|
||||
|
||||
# ---- E2M1 FP4 quantization ----
|
||||
# ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ----
|
||||
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# Nearest-neighbor using thresholds (midpoints between consecutive values)
|
||||
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
|
||||
@@ -105,19 +105,16 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
sign_bit = (scaled < 0).to(tl.int32)
|
||||
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
|
||||
|
||||
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
|
||||
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
|
||||
e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2])
|
||||
even = e2m1_pairs[:, 0].to(tl.uint8)
|
||||
odd = e2m1_pairs[:, 1].to(tl.uint8)
|
||||
packed_byte = (odd << 4) | even
|
||||
# Unpacked: 1 byte per E2M1 element (same byte count as FP8)
|
||||
# Each E2M1 value stored as uint8 (4-bit value in low nibble)
|
||||
e2m1_flat = tl.reshape(e2m1_4bit, [BLOCK_K]).to(tl.uint8)
|
||||
|
||||
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
|
||||
packed_k_mask = packed_k_offsets < (hidden_size // 2)
|
||||
k_offsets_out = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
k_mask_out = k_offsets_out < hidden_size
|
||||
tl.store(
|
||||
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
|
||||
packed_byte,
|
||||
mask=packed_k_mask,
|
||||
x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
|
||||
e2m1_flat,
|
||||
mask=k_mask_out,
|
||||
)
|
||||
|
||||
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
|
||||
|
||||
Reference in New Issue
Block a user