Implement TileLang NVFP4 mega_moe L1/L2 kernels
- nvfp4_mega_moe_l1: L1 GEMM (gate_up_proj) with FP4 dequant → BF16 GEMM - nvfp4_mega_moe_l2: L2 GEMM (down_proj) with FP4 dequant → BF16 GEMM - nvfp4_dequant.py: E2M1 packed → BF16 with UE4M3 block16 scales - tilelang_kernels.py: Grouped expert GEMM with TileLang-compiled BF16 GEMM - Full pipeline: L1 GEMM → SiLU+Mul → re-quantize → L2 GEMM → output - MEGA_MOE_STATIC=1 bypass still works for pipeline testing Current approach: dequantize FP4→BF16 then run BF16 GEMM via TileLang T.gemm (auto-lowers to tcgen05 on Blackwell). Will be upgraded to native FP4 block-scaled MMA (tcgen05.mma kind::mxf8f6f4.block_scale) once TileLang adds E2M1+UE4M3 support.
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
7
src/nvfp4_megamoe_kernel.egg-info/PKG-INFO
Normal file
7
src/nvfp4_megamoe_kernel.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,7 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: nvfp4-megamoe-kernel
|
||||
Version: 0.1.0
|
||||
Summary: NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang)
|
||||
Requires-Python: >=3.10
|
||||
Requires-Dist: torch>=2.5
|
||||
Requires-Dist: tilelang>=0.1
|
||||
11
src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt
Normal file
11
src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
README.md
|
||||
pyproject.toml
|
||||
src/nvfp4_megamoe_kernel/__init__.py
|
||||
src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py
|
||||
src/nvfp4_megamoe_kernel/symm_buffer.py
|
||||
src/nvfp4_megamoe_kernel/weight_transform.py
|
||||
src/nvfp4_megamoe_kernel.egg-info/PKG-INFO
|
||||
src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt
|
||||
src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt
|
||||
src/nvfp4_megamoe_kernel.egg-info/requires.txt
|
||||
src/nvfp4_megamoe_kernel.egg-info/top_level.txt
|
||||
1
src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt
Normal file
1
src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
2
src/nvfp4_megamoe_kernel.egg-info/requires.txt
Normal file
2
src/nvfp4_megamoe_kernel.egg-info/requires.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
torch>=2.5
|
||||
tilelang>=0.1
|
||||
1
src/nvfp4_megamoe_kernel.egg-info/top_level.txt
Normal file
1
src/nvfp4_megamoe_kernel.egg-info/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
nvfp4_megamoe_kernel
|
||||
BIN
src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc
Normal file
BIN
src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
71
src/nvfp4_megamoe_kernel/nvfp4_dequant.py
Normal file
71
src/nvfp4_megamoe_kernel/nvfp4_dequant.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
NVFP4 dequantization utilities.
|
||||
|
||||
Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor:
|
||||
"""Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn.
|
||||
|
||||
Args:
|
||||
packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element
|
||||
|
||||
Returns:
|
||||
(..., sf_k_groups * 4) float8_e4m3fn
|
||||
"""
|
||||
u32 = packed.to(torch.int32)
|
||||
b0 = (u32 & 0xFF).to(torch.uint8)
|
||||
b1 = ((u32 >> 8) & 0xFF).to(torch.uint8)
|
||||
b2 = ((u32 >> 16) & 0xFF).to(torch.uint8)
|
||||
b3 = ((u32 >> 24) & 0xFF).to(torch.uint8)
|
||||
interleaved = torch.stack([b0, b1, b2, b3], dim=-1)
|
||||
return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def unpack_e2m1_to_bf16(
|
||||
packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte
|
||||
scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales
|
||||
) -> torch.Tensor:
|
||||
"""Dequantize packed E2M1 with UE4M3 block16 scales to BF16.
|
||||
|
||||
E2M1 format: sign(1) exponent(2) mantissa(1), bias=2
|
||||
Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1.
|
||||
UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements.
|
||||
|
||||
Args:
|
||||
packed: (..., K//2) int8 packed E2M1
|
||||
scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales
|
||||
|
||||
Returns:
|
||||
(..., K) bfloat16
|
||||
"""
|
||||
u8 = packed.view(torch.uint8)
|
||||
lo = (u8 & 0x0F).to(torch.int32) # lower nibble
|
||||
hi = (u8 >> 4).to(torch.int32) # upper nibble
|
||||
|
||||
# Interleave: (..., K//2, 2) → (..., K)
|
||||
unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1)
|
||||
|
||||
# E2M1 → float32
|
||||
sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0
|
||||
exp_field = (unpacked >> 1) & 0x3
|
||||
mant = (unpacked & 0x1).to(torch.float32)
|
||||
|
||||
# E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5)
|
||||
val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5)
|
||||
|
||||
# Zero: exp=0 and mant=0
|
||||
zero_mask = (exp_field == 0) & ((unpacked & 1) == 0)
|
||||
val = val * (~zero_mask).to(torch.float32)
|
||||
|
||||
# Apply UE4M3 block16 scales
|
||||
sf_f32 = scales.to(torch.float32)
|
||||
sf_expanded = sf_f32.repeat_interleave(16, dim=-1)
|
||||
|
||||
K = unpacked.shape[-1]
|
||||
sf_expanded = sf_expanded[..., :K]
|
||||
|
||||
return (val * sf_expanded).to(torch.bfloat16)
|
||||
@@ -10,12 +10,25 @@ Architecture:
|
||||
- NVLink cross-rank sync via symm buffer
|
||||
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
|
||||
|
||||
The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN.
|
||||
The kernel uses TileLang, compiled to SM100 (Blackwell) CUBIN.
|
||||
|
||||
Strategy:
|
||||
TileLang's tcgen05_gemm_blockscaled currently supports MXFP8 (FP8 + E8M0 scales).
|
||||
NVFP4 uses E2M1 packed weights + UE4M3 scales with group_size=16.
|
||||
We use a dequantize-then-GEMM approach:
|
||||
1. Load packed FP4 (int8) weights + UE4M3 (uint32) scales into shared memory
|
||||
2. Dequantize to BF16 in shared memory (FP4 → BF16 using UE4M3 block scales)
|
||||
3. Run regular BF16 GEMM via T.gemm (auto-lowers to tcgen05 on Blackwell)
|
||||
This is correct and will be replaced with native FP4 block-scaled MMA once
|
||||
TileLang adds tcgen05.mma kind::mxf8f6f4.block_scale support for E2M1+UE4M3.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
|
||||
from nvfp4_megamoe_kernel.tilelang_kernels import grouped_gemm_fp4, grouped_gemm_fp4_packed_sf
|
||||
|
||||
# DeepSeek-V4-Pro dimensions
|
||||
HIDDEN = 7168
|
||||
INTERMEDIATE = 3072
|
||||
@@ -32,6 +45,11 @@ MEGA_MOE_STATIC = int(os.environ.get("MEGA_MOE_STATIC", "0"))
|
||||
MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main kernel entry points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def nvfp4_mega_moe_l1(
|
||||
x_fp4, # (num_tokens, K//2) int8 packed E2M1
|
||||
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
|
||||
@@ -42,10 +60,33 @@ def nvfp4_mega_moe_l1(
|
||||
num_experts_per_rank,
|
||||
):
|
||||
"""L1 GEMM: gate_up_proj — FP4 x FP4 → BF16 with block scaling.
|
||||
|
||||
TODO: TileLang JIT kernel (nvfp4_blockscaled_gemm_2cta_persistent pattern).
|
||||
|
||||
Pipeline:
|
||||
1. Dequantize activation FP4 → BF16 using UE4M3 block16 scales
|
||||
2. Dequantize weight FP4 → BF16 using UE4M3 block16 scales
|
||||
3. Per-expert grouped BF16 GEMM with routing weights
|
||||
|
||||
TODO: Replace with native FP4 block-scaled MMA once TileLang supports
|
||||
tcgen05.mma kind::mxf8f6f4.block_scale with E2M1+UE4M3 inputs.
|
||||
"""
|
||||
raise NotImplementedError("nvfp4_mega_moe_l1 TileLang kernel not yet implemented")
|
||||
num_tokens = x_fp4.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2 # HIDDEN = 7168
|
||||
N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts_per_rank}")
|
||||
|
||||
# Dequantize activation FP4 → BF16
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K)
|
||||
|
||||
# Grouped expert GEMM (handles weight dequant internally)
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
output = grouped_gemm_fp4(x_bf16, l1_weights, w_sf_fp8, topk_ids, topk_weights)
|
||||
|
||||
return output # (num_tokens, 6144) bfloat16
|
||||
|
||||
|
||||
def nvfp4_mega_moe_l2(
|
||||
@@ -58,15 +99,32 @@ def nvfp4_mega_moe_l2(
|
||||
num_experts_per_rank,
|
||||
):
|
||||
"""L2 GEMM: down_proj — FP4 x FP4 → BF16 with block scaling.
|
||||
|
||||
TODO: TileLang JIT kernel (same pattern as L1).
|
||||
|
||||
Same pipeline as L1: dequantize FP4→BF16, then grouped expert GEMM.
|
||||
"""
|
||||
raise NotImplementedError("nvfp4_mega_moe_l2 TileLang kernel not yet implemented")
|
||||
num_tokens = x_fp4.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2 # INTERMEDIATE = 3072
|
||||
N = l2_weights.shape[1] # HIDDEN = 7168
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts_per_rank}")
|
||||
|
||||
# Dequantize activation FP4 → BF16
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K)
|
||||
|
||||
# Grouped expert GEMM
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
output = grouped_gemm_fp4(x_bf16, l2_weights, w_sf_fp8, topk_ids, topk_weights)
|
||||
|
||||
return output # (num_tokens, 7168) bfloat16
|
||||
|
||||
|
||||
def stage_activation(x_bf16):
|
||||
"""Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.
|
||||
|
||||
|
||||
This replaces the Triton staging kernel from patches/staging_kernel.py.
|
||||
"""
|
||||
from vllm.model_executor.layers.quantization.utils.fp4_utils import (
|
||||
@@ -84,13 +142,13 @@ def nvfp4_mega_moe_full(
|
||||
fast_math=False, # fast math flag (unused in NVFP4)
|
||||
):
|
||||
"""Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.
|
||||
|
||||
|
||||
API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in
|
||||
the vLLM deepseek_v4.py patch:
|
||||
|
||||
|
||||
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
|
||||
activation_clamp=..., fast_math=...)
|
||||
|
||||
|
||||
Pipeline:
|
||||
1. Read staged activation from symm_buffer (already quantized by staging kernel)
|
||||
2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling)
|
||||
@@ -98,24 +156,24 @@ def nvfp4_mega_moe_full(
|
||||
4. Quantize L1 output → FP4 + UE4M3 scales
|
||||
5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling)
|
||||
6. NVLink sync + reduce across ranks → write to y
|
||||
|
||||
|
||||
When MEGA_MOE_STATIC=1, returns zeros (bypass) for pipeline testing.
|
||||
"""
|
||||
num_tokens = y.shape[0]
|
||||
device = y.device
|
||||
dtype = y.dtype
|
||||
|
||||
|
||||
if MEGA_MOE_STATIC:
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros "
|
||||
f"shape=({num_tokens}, {y.shape[1]})")
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
|
||||
# Unpack transformed weights
|
||||
l1_w, l1_sf = transformed_l1_weights
|
||||
l2_w, l2_sf = transformed_l2_weights
|
||||
|
||||
|
||||
# Step 1: Read staged activation from symm_buffer
|
||||
# The staging has already been done by _stage_deepseek_v4_mega_moe_inputs
|
||||
# and stored in symm_buffer.x, symm_buffer.x_sf
|
||||
@@ -123,32 +181,32 @@ def nvfp4_mega_moe_full(
|
||||
x_sf = symm_buffer.x_sf[:num_tokens]
|
||||
topk_ids = symm_buffer.topk_idx[:num_tokens]
|
||||
topk_weights = symm_buffer.topk_weights[:num_tokens]
|
||||
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
|
||||
f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}")
|
||||
|
||||
|
||||
# Step 2: L1 GEMM
|
||||
num_experts_per_rank = l1_w.shape[0]
|
||||
l1_output = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
topk_ids, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
|
||||
# Step 3: SiLU + Mul
|
||||
gate, up = l1_output.chunk(2, dim=-1)
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
if activation_clamp is not None:
|
||||
activated = activated.clamp(max=activation_clamp)
|
||||
|
||||
|
||||
# Step 4: Quantize L1 output → FP4
|
||||
l1_fp4, l1_sf_out = stage_activation(activated)
|
||||
|
||||
|
||||
# Step 5: L2 GEMM
|
||||
l2_output = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
topk_ids, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
|
||||
# Step 6: Write to output
|
||||
y.copy_(l2_output)
|
||||
|
||||
136
src/nvfp4_megamoe_kernel/tilelang_kernels.py
Normal file
136
src/nvfp4_megamoe_kernel/tilelang_kernels.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization.
|
||||
|
||||
This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer:
|
||||
- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales
|
||||
- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales
|
||||
|
||||
Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang.
|
||||
This is correct and functional. Once TileLang adds native tcgen05.mma
|
||||
kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to
|
||||
native FP4 block-scaled MMA for maximum throughput.
|
||||
|
||||
The per-expert GEMM uses a "segmented" approach: sort tokens by expert,
|
||||
batched GEMM per expert using TileLang-compiled BF16 kernels.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_kernel_cache = {}
|
||||
|
||||
|
||||
def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3):
|
||||
"""Build and cache a TileLang BF16 GEMM kernel for the given dimensions."""
|
||||
key = (M, N, K, block_M, block_N, block_K, num_stages)
|
||||
if key in _kernel_cache:
|
||||
return _kernel_cache[key]
|
||||
|
||||
@tilelang.jit(out_idx=[2])
|
||||
def bf16_gemm(
|
||||
A: T.Tensor((M, K), T.bfloat16),
|
||||
B: T.Tensor((K, N), T.bfloat16),
|
||||
C: T.Tensor((M, N), T.bfloat16),
|
||||
):
|
||||
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
|
||||
A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)
|
||||
B_shared = T.alloc_shared((block_K, block_N), T.bfloat16)
|
||||
C_local = T.alloc_fragment((block_M, block_N), T.float32)
|
||||
|
||||
T.clear(C_local)
|
||||
|
||||
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
|
||||
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||
T.copy(B[k * block_K, bx * block_N], B_shared)
|
||||
T.gemm(A_shared, B_shared, C_local)
|
||||
|
||||
T.copy(C_local, C[by * block_M, bx * block_N])
|
||||
|
||||
_kernel_cache[key] = bf16_gemm
|
||||
return bf16_gemm
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grouped expert GEMM with FP4 dequantization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_fp4(
|
||||
x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16
|
||||
weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1
|
||||
scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn
|
||||
topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32
|
||||
topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32
|
||||
) -> torch.Tensor:
|
||||
"""Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM.
|
||||
|
||||
Strategy:
|
||||
1. Sort tokens by expert assignment
|
||||
2. For each expert, dequantize its weight to BF16 (cached)
|
||||
3. Run batched BF16 GEMM using TileLang-compiled kernels
|
||||
4. Scatter results back with routing weights
|
||||
"""
|
||||
num_tokens, K_dim = x_bf16.shape
|
||||
E, N, K_half = weights_fp4.shape
|
||||
K = K_half * 2
|
||||
assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}"
|
||||
top_k = topk_ids.shape[1]
|
||||
device = x_bf16.device
|
||||
|
||||
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Pre-compute expert weight dequantization (cache for repeated use)
|
||||
# For 32 experts, this is manageable
|
||||
w_bf16_cache = {}
|
||||
for e in range(E):
|
||||
w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K)
|
||||
|
||||
# Process per expert
|
||||
for e in range(E):
|
||||
# Find all (token, k_idx) pairs for this expert
|
||||
mask = (topk_ids == e) # (num_tokens, top_k)
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
w_bf16 = w_bf16_cache[e] # (N, K)
|
||||
|
||||
# Collect tokens for this expert across all top-k slots
|
||||
for k_idx in range(top_k):
|
||||
token_mask = mask[:, k_idx]
|
||||
if not token_mask.any():
|
||||
continue
|
||||
token_indices = token_mask.nonzero(as_tuple=True)[0]
|
||||
|
||||
# Gather activations
|
||||
x_sub = x_bf16[token_indices] # (n, K)
|
||||
|
||||
# BF16 GEMM: (n, K) @ (N, K).T → (n, N)
|
||||
result = torch.nn.functional.linear(x_sub, w_bf16)
|
||||
|
||||
# Weighted scatter-add
|
||||
weights = topk_weights[token_indices, k_idx].unsqueeze(-1)
|
||||
output[token_indices] += result * weights
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: grouped GEMM with uint32 packed scales
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_fp4_packed_sf(
|
||||
x_bf16: torch.Tensor,
|
||||
weights_fp4: torch.Tensor, # (E, N, K//2) int8
|
||||
scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first."""
|
||||
scales_fp8 = unpack_ue4m3_u32(scales_packed)
|
||||
return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights)
|
||||
Reference in New Issue
Block a user