Restructure: cutedsl/ -> dsv4/ with proper layering

- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
This commit is contained in:
2026-05-21 17:30:44 +00:00
parent 99e143dd0e
commit 3fb3c925af
274 changed files with 715 additions and 556 deletions

View File

@@ -1 +0,0 @@
/root/dsv4-nvfp4-workspace/kernel/cutedsl

2
dsv4/cache/block_table.py vendored Normal file
View File

@@ -0,0 +1,2 @@
"""Block table for paged KV cache."""
# TODO: Phase 3

2
dsv4/cache/paged_cache.py vendored Normal file
View File

@@ -0,0 +1,2 @@
"""Paged KV cache."""
# TODO: Phase 3

2
dsv4/cache/state_cache.py vendored Normal file
View File

@@ -0,0 +1,2 @@
"""State cache for KV."""
# TODO: Phase 3

View File

@@ -0,0 +1,2 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Stage B+). Extracted from test_fmha_v3.py."""
# TODO: Extract FmhaV3 kernel class here

View File

View File

View File

@@ -1,4 +1,4 @@
"""
FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).
STATUS: NOT USABLE INSIDE CUTE KERNELS.
@@ -23,4 +23,4 @@ or when we can properly construct vector<4xf8E4M3FN> inside kernel code,
we can fuse the dequant into the attention kernel. The PTX instruction
exists (cvt.rn.bf16x2.e4m3x2), but CuTeDSL's AST preprocessor currently
prevents us from injecting the necessary MLIR ops.
"""

View File

View File

View File

@@ -60,15 +60,15 @@ if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../../.."))
from cutedsl.kernel.moe.moe_utils import (
from dsv4.kernels.gemm.utils import (
MoEScaledGroupedGemmTensormapConstructor,
)
from cutedsl.kernel.moe.moe_persistent_scheduler import (
from dsv4.kernels.gemm.scheduler import (
MoEStaticSchedulerParams,
MoEStaticPersistentTileScheduler,
MoEWorkTileInfo,
)
from cutedsl.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension
from dsv4.kernels.gemm.sched_extension import ScaledGroupedMmSchedExtension
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.utils.gemm.sm100 import (
@@ -3665,7 +3665,7 @@ class ScaledGroupedGemmTester:
if _examples_root not in sys.path:
sys.path.insert(0, _examples_root)
from cutedsl.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import (
from dsv4.kernels.gemm.dense import (
Sm100BlockScaledPersistentDenseGemmKernel,
)
from cutlass.cute.nvgpu import OperandMajorMode

View File

@@ -60,15 +60,15 @@ if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../../.."))
from cutedsl.kernel.moe.moe_utils import (
from dsv4.kernels.gemm.utils import (
MoEScaledGroupedGemmTensormapConstructor,
)
from cutedsl.kernel.moe.moe_persistent_scheduler import (
from dsv4.kernels.gemm.scheduler import (
MoEStaticSchedulerParams,
MoEStaticPersistentTileScheduler,
MoEWorkTileInfo,
)
from cutedsl.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension
from dsv4.kernels.gemm.sched_extension import ScaledGroupedMmSchedExtension
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.utils.gemm.sm100 import (
@@ -3608,7 +3608,7 @@ class ScaledGroupedGemmTester:
if _examples_root not in sys.path:
sys.path.insert(0, _examples_root)
from cutedsl.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import (
from dsv4.kernels.gemm.dense import (
Sm100BlockScaledPersistentDenseGemmKernel,
)
from cutlass.cute.nvgpu import OperandMajorMode

View File

@@ -73,14 +73,14 @@ from cutlass.cutlass_dsl import Int32
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
from cutedsl.kernel.moe.moe_utils import (
from dsv4.kernels.gemm.utils import (
OnlineTensormapDescCreator,
tensormap_ptr_for_copy,
compute_expert_token_range,
rewrite_tensor_shape,
prefetch_tma_descriptor,
)
from cutedsl.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
from dsv4.kernels.gemm.scheduler import MoEWorkTileInfo
@dataclass(frozen=True)

0
dsv4/layers/__init__.py Normal file
View File

2
dsv4/layers/attention.py Normal file
View File

@@ -0,0 +1,2 @@
"""DSV4 attention sub-block."""
# TODO: Phase 3+4

2
dsv4/layers/embedding.py Normal file
View File

@@ -0,0 +1,2 @@
"""Token embedding + mHC init wrapper."""
# TODO: Implement

2
dsv4/layers/ffn.py Normal file
View File

@@ -0,0 +1,2 @@
"""FFN: router + MoE + shared expert."""
# TODO: Phase 2

View File

@@ -14,22 +14,26 @@ CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
import torch
from cutedsl.bridge import (
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class CuTeDSLNvfp4WoA:
class Nvfp4GroupedLinear:
"""Grouped NVFP4 linear for wo_a (o-projection first half).
Handles the "bhr,hdr->bhd" einsum pattern:
@@ -181,7 +185,9 @@ class CuTeDSLNvfp4WoA:
# Reshape to grouped format, then flatten to 2D for quantization
o_grouped = o_sample.reshape(-1, self.n_local_groups, self.group_in_features)
# We need a single gs for all groups — use the overall amax
from cutedsl.bridge import quantize_to_nvfp4
from dsv4.ops.quantize import (
quantize_to_nvfp4,
)
o_flat = o_sample.reshape(-1, o_sample.shape[-1]) # (tokens, n_local_heads * head_dim) — not right
# Actually, for grouped GEMM, each group's activation is (tokens, group_in_features)
# The global scale should be computed per-group, but for simplicity use one scale
@@ -256,7 +262,9 @@ class CuTeDSLNvfp4WoA:
# Assemble A-side scales for all groups
# The grouped GEMM expects scales for all groups assembled together
# For 2Dx3D scenario, scale_a is assembled from per-group scale tensors
from cutedsl.bridge import assemble_scales_2d_side
from dsv4.ops.layouts import (
assemble_scales_2d_side,
)
scale_a = assemble_scales_2d_side(all_x_sf)
# Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T]

View File

@@ -8,21 +8,25 @@ CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
import torch
from cutedsl.bridge import (
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
from dsv4.kernels.gemm.grouped import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class CuTeDSLNvfp4Linear:
class Nvfp4Linear:
"""Single NVFP4 GEMM using CuTeDSL (num_groups=1).
Handles any (K, N) weight matrix in NVFP4 format.
@@ -76,7 +80,6 @@ class CuTeDSLNvfp4Linear:
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
# Uses num_groups=1 since this is a single linear layer.
# from cutedsl.bridge import warmup_compilation # SKIPPED: warmup with zeros crashes on sm_100a
K_packed = self.in_features // 2
N_packed = self.out_features // 2
# warmup_compilation(1, K_packed, N_packed, self.device) # Lazy compile on first real forward

View File

@@ -15,26 +15,30 @@ processes max_slots = budget * top_k rows; padding rows are zeros.
"""
import torch
from cutedsl.bridge import (
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_3d_side,
interleave_l1_weights,
deinterleave_l1_weights,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
warmup_fused_swiglu_compilation,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm
class CuTeDSLMoERunner:
class Nvfp4MoE:
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
@@ -127,15 +131,15 @@ class CuTeDSLMoERunner:
# Initialize shared buffers dict (if not already)
device_key = str(self.device)
if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
CuTeDSLMoERunner._shared_padded_bufs = {}
if device_key not in CuTeDSLMoERunner._shared_padded_bufs:
CuTeDSLMoERunner._shared_padded_bufs[device_key] = {}
if not hasattr(Nvfp4MoE, '_shared_padded_bufs'):
Nvfp4MoE._shared_padded_bufs = {}
if device_key not in Nvfp4MoE._shared_padded_bufs:
Nvfp4MoE._shared_padded_bufs[device_key] = {}
# Padded x_sf buffers: SHARED across all runners (not per-layer)
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
if 'xsf_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
Nvfp4MoE._shared_padded_bufs[device_key].update({
'xsf_l1': torch.zeros(
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn),
@@ -146,9 +150,9 @@ class CuTeDSLMoERunner:
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
})
self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1']
self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2']
self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output']
self._padded_x_sf_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
@@ -162,8 +166,8 @@ class CuTeDSLMoERunner:
# Padded hidden/activated: SHARED across all runners (not per-layer)
max_rows_per_expert = self._max_chunks_per_expert * 128
padded_max_slots = self.num_experts * max_rows_per_expert
if 'hidden' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
if 'hidden' not in Nvfp4MoE._shared_padded_bufs[device_key]:
Nvfp4MoE._shared_padded_bufs[device_key].update({
'hidden': torch.zeros(
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
@@ -177,7 +181,7 @@ class CuTeDSLMoERunner:
padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
).view(torch.float4_e2m1fn_x2),
})
self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key]
self._shared_bufs = Nvfp4MoE._shared_padded_bufs[device_key]
# Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed)
self._padded_expert_offsets_buf = torch.zeros(
@@ -237,7 +241,7 @@ class CuTeDSLMoERunner:
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
# the checkpoint! Skip the transpose by calling the assembly directly.
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
from dsv4.ops.layouts import (
assemble_raw_scales_2d3d_3d_side,
)
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
@@ -285,7 +289,13 @@ class CuTeDSLMoERunner:
# This triggers cute.compile once per shape, caching the compiled
# kernel + workspace. Subsequent run() calls hit the cache.
# MUST happen before model forward pass to avoid OOM from lazy JIT.
from cutedsl.bridge import warmup_compilation, warmup_fused_swiglu_compilation, ceil_div as bridge_ceil_div
from dsv4.ops.layouts import (
ceil_div as bridge_ceil_div,
)
from dsv4.ops.gemm_runner import (
warmup_compilation,
warmup_fused_swiglu_compilation,
)
K_packed = self.hidden_size // 2
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
N_packed_l2 = self.hidden_size // 2 # down

2
dsv4/layers/norm.py Normal file
View File

@@ -0,0 +1,2 @@
"""RMSNorm placeholder."""
# TODO: Implement RMSNorm

2
dsv4/layers/router.py Normal file
View File

@@ -0,0 +1,2 @@
"""Router: sqrt(softplus) + topk + aux-free bias + hash routing."""
# TODO: Phase 2

View File

@@ -20,14 +20,18 @@ no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output
import torch
from cutedsl.bridge import (
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
from dsv4.kernels.gemm.grouped import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
@@ -40,7 +44,7 @@ class _SharedExpertApply(torch.autograd.Function):
return runner._run_impl(hidden_states)
class CuTeDSLSharedExpertRunner:
class Nvfp4SharedExpert:
"""NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.

0
dsv4/loader/__init__.py Normal file
View File

View File

@@ -0,0 +1,2 @@
"""HuggingFace checkpoint reader."""
# TODO

View File

@@ -0,0 +1,2 @@
"""Checkpoint layout conversion."""
# TODO

0
dsv4/model/__init__.py Normal file
View File

2
dsv4/model/config.py Normal file
View File

@@ -0,0 +1,2 @@
"""DSV4Config (Flash + Pro)."""
# TODO: Phase 1

2
dsv4/model/dsv4.py Normal file
View File

@@ -0,0 +1,2 @@
"""Full DSV4 model."""
# TODO: Phase 1

2
dsv4/model/layer.py Normal file
View File

@@ -0,0 +1,2 @@
"""Single transformer layer."""
# TODO: Phase 1

2
dsv4/model/mtp.py Normal file
View File

@@ -0,0 +1,2 @@
"""Multi-token prediction."""
# TODO

2
dsv4/model/sampler.py Normal file
View File

@@ -0,0 +1,2 @@
"""Token sampler."""
# TODO

0
dsv4/ops/__init__.py Normal file
View File

View File

@@ -1,13 +1,4 @@
"""
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
Handles tensor layout conversion from our pipeline's format to what
the ScaledGroupedGemmKernel expects:
- BF16 NVFP4 quantization (float4_e2m1fn_x2)
- Scale factor assembly (padding + swizzle)
- B tensor K-major stride conversion
- Expert offset computation
"""
"""NVFP4 GEMM runner: warmup, compile, and execute grouped/fused GEMMs."""
import math
import torch
import cutlass
@@ -15,18 +6,24 @@ import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ScaledGroupedGemmKernel,
pad_and_swizzle_single,
assemble_raw_scales_2d3d_2d_side,
assemble_raw_scales_2d3d_3d_side,
cat_byte_reinterpretable_tensors,
stack_byte_reinterpretable_tensors,
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_to_nvfp4,
deinterleave_quantize_nvfp4_cuda,
)
from dsv4.ops.layouts import (
interleave_l1_weights,
deinterleave_l1_weights,
assemble_scales_2d_side,
assemble_scales_3d_side,
make_b_k_major,
compute_expert_offsets,
ceil_div,
round_up,
)
# ── Constants ──────────────────────────────────────────────────────────
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Cache compiled kernels + pre-allocated workspace by cache_key
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
@@ -42,326 +39,6 @@ E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Caching them would hold stale references to tensors that get freed.
_compiled_kernel_cache = {}
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
_NVFP4_STEP_LUT_CACHE = {}
def _get_step_to_idx_lut(device):
"""Get or create the E2M1 step-to-index LUT for the given device.
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
so the lock is never entered on the compiled path.
"""
# Fast path: already cached — no lock needed (torch.compile-safe)
if device in _NVFP4_STEP_LUT_CACHE:
return _NVFP4_STEP_LUT_CACHE[device]
# Slow path: first call, create the LUT
lut = torch.as_tensor(
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
dtype=torch.int8, device=device,
)
_NVFP4_STEP_LUT_CACHE[device] = lut
return lut
SF_VEC_SIZE = 16 # NVFP4 block size
def ceil_div(a, b):
return (a + b - 1) // b
def round_up(a, b):
return ceil_div(a, b) * b
# ── Quantization ──────────────────────────────────────────────────────
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
"""Quantize BF16 tensor to NVFP4.
Args:
x_bf16: (..., D) BF16 tensor
Returns:
x_fp4: (..., D//2) float4_e2m1fn_x2 native PyTorch FP4
x_sf: (..., D//16) float8_e4m3fn block scales
global_scale: float32 scalar
"""
x_f32 = x_bf16.float()
amax = x_f32.abs().max().clamp(min=1e-8).float()
global_scale = amax / (6.0 * 448.0)
x_norm = x_f32 / global_scale
last_dim = x_norm.shape[-1]
n_blocks = ceil_div(last_dim, block_size)
if last_dim % block_size != 0:
pad_size = n_blocks * block_size - last_dim
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
# the block scale underflows to 0, and dividing x by the clamped 1e-8
# inflates values into nonzero FP4 buckets — producing wrong results.
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
# Zero out x for zero/underflow blocks before division.
# This ensures x_scaled = 0 → FP4 nibbles = 0.
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
# Nearest E2M1
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[..., ::2]
odd = nibbles[..., 1::2]
packed = (odd << 4) | even
packed_shape = list(x_bf16.shape)
packed_shape[-1] = last_dim // 2
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
block_scale = block_scale.reshape(sf_shape)
return x_fp4, block_scale, global_scale
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
"""Quantize BF16 activation tensor to NVFP4 (cudagraph-safe).
Unlike quantize_to_nvfp4(), this takes a pre-computed global_scale
instead of computing it via .max() (which forces CPU-GPU sync).
All operations are pure GPU with no CPU-GPU syncs.
Args:
x_bf16: (..., D) BF16 tensor
global_scale: float32 scalar (pre-computed, NOT from .max())
block_size: NVFP4 block size
Returns:
x_fp4: (..., D//2) float4_e2m1fn_x2
x_sf: (..., D//16) float8_e4m3fn
"""
x_f32 = x_bf16.float()
x_norm = x_f32 / global_scale
last_dim = x_norm.shape[-1]
n_blocks = ceil_div(last_dim, block_size)
if last_dim % block_size != 0:
pad_size = n_blocks * block_size - last_dim
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
zero_block = block_amax < (6.0 * 2.0 ** -9)
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[..., ::2]
odd = nibbles[..., 1::2]
packed = (odd << 4) | even
packed_shape = list(x_bf16.shape)
packed_shape[-1] = last_dim // 2
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
block_scale = block_scale.reshape(sf_shape)
return x_fp4, block_scale
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
"""Quantize BF16 weight matrix to NVFP4.
The weight is (K, N) where K is the input dim (packed dimension).
Block scales are computed along K (dim 0).
Args:
w_bf16: (K, N) BF16 weight matrix
Returns:
w_fp4: (K//2, N) float4_e2m1fn_x2 K is the packed dim
w_sf: (K//16, N) float8_e4m3fn block scales along K
global_scale: float32 scalar
"""
K, N = w_bf16.shape
w_f32 = w_bf16.float()
amax = w_f32.abs().max().clamp(min=1e-8).float()
global_scale = amax / (6.0 * 448.0)
w_norm = w_f32 / global_scale
k_blocks = ceil_div(K, block_size)
if K % block_size != 0:
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K))
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
w_block_amax = w_reshaped.abs().amax(dim=1)
# Detect zero blocks and underflow blocks (same threshold).
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
w_reshaped = torch.where(zero_block.unsqueeze(1),
torch.zeros_like(w_reshaped), w_reshaped)
w_block_amax = w_block_amax.clamp(min=1e-8)
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
w_block_sf = w_sf.float().unsqueeze(1)
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
signs = torch.sign(w_scaled)
abs_scaled = w_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(w_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[:, ::2, :]
odd = nibbles[:, 1::2, :]
packed = (odd << 4) | even
w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2)
return w_fp4, w_sf, global_scale
# ── Scale Factor Assembly ─────────────────────────────────────────────
def interleave_l1_weights(w_ekn, granularity_bf16=8):
"""Interleave gate/up weights at granularity 8 in BF16 (4 in FP4).
The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
MMA accumulator. With interleaved weights, the MMA tile produces
gate[i*8..i*8+7] and up[i*8..i*8+7] next to each other in registers,
enabling a single-register SwiGLU without SMEM round-trips.
Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
After: [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]
The interleave operates along the N dimension, where each column = 1 BF16
(FP4 packing is along K, not N). So g = granularity_bf16 directly.
Args:
w_ekn: (E, K_packed, N_packed) FP4 weight tensor in K-major layout
N_packed = 2*intermediate/2 = intermediate (gate+up fused)
granularity_bf16: interleave group size in BF16 elements (default 8)
Returns:
(E, K_packed, N_packed) FP4 weight tensor with interleaved gate/up
"""
E, K, N = w_ekn.shape
N_half = N // 2 # gate and up each have N/2 FP4 columns
g = granularity_bf16 # N-axis interleave: each N-col = 1 BF16 col (packing is along K)
gate = w_ekn[:, :, :N_half].reshape(E, K, N_half // g, g)
up = w_ekn[:, :, N_half:].reshape(E, K, N_half // g, g)
return torch.stack([gate, up], dim=3).reshape(E, K, N)
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
"""De-interleave gate/up weights (inverse of interleave_l1_weights).
Used for testing/verification only.
"""
g = granularity_bf16 # N-axis: each N-col = 1 BF16 col
E, K, N = w_ekn.shape
w_reshaped = w_ekn.reshape(E, K, N // (2 * g), 2, g)
gate = w_reshaped[:, :, :, 0, :].reshape(E, K, N // 2)
up = w_reshaped[:, :, :, 1, :].reshape(E, K, N // 2)
return torch.cat([gate, up], dim=2)
def assemble_scales_2d_side(raw_scales):
"""Assemble activation scale factors for the 2Dx3D scenario.
Args:
raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert
Returns:
Assembled and swizzled scale tensor
"""
return assemble_raw_scales_2d3d_2d_side(raw_scales)
def assemble_scales_3d_side(raw_scales):
"""Assemble weight scale factors for the 2Dx3D scenario.
Args:
raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert
NOTE: These will be transposed to (N, K_sf) before swizzling,
since the kernel expects N as the non-K dimension.
Returns:
Assembled and swizzled scale tensor
"""
# Kernel expects (N, K_sf) — transpose before swizzling
transposed = [sf.T.contiguous() for sf in raw_scales]
return assemble_raw_scales_2d3d_3d_side(transposed)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
"""Convert B tensor from N-major to K-major layout.
The kernel expects B with stride (E*K*N, 1, K) K is contiguous.
torch.stack produces stride (E*K*N, N, 1) N is contiguous.
Args:
b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major
Returns:
Same shape, K-major strides
"""
return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
"""Compute cumulative token offsets for the grouped GEMM.
Args:
tokens_per_expert: list of int, one per expert
Returns:
offs: (num_experts,) int32 cumulative sum
"""
offs = torch.tensor(
[sum(tokens_per_expert[:e+1]) for e in range(num_experts)],
dtype=torch.int32, device=device,
)
return offs
# ── Kernel Launch ─────────────────────────────────────────────────────
def warmup_compilation(num_experts, K_packed, N_packed, device,
mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)):
"""Eagerly JIT-compile the GEMM kernel for a specific shape.
@@ -589,10 +266,7 @@ def run_nvfp4_grouped_gemm(
# ── Fused SwiGLU GEMM (Stage 1: SiLU in registers, BF16 output) ──────
# Cache for fused kernel (separate from standard GEMM cache)
_fused_kernel_cache = {}
def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
swiglu_limit=0.0,
mma_tiler_mn=(128, 128),
@@ -602,7 +276,7 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
Must be called during model initialization. See warmup_compilation()
for the standard GEMM equivalent.
"""
from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
cache_key = ('fused', num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
K_packed, N_packed, swiglu_limit)
@@ -697,7 +371,7 @@ def run_fused_swiglu_grouped_gemm(
Stage 1: SiLU is applied to the full accumulator in registers,
then written as BF16 to C. Gate/up pairing is not yet implemented.
"""
from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
num_experts = mat_b.shape[0]
n_dim = mat_b.shape[2]
@@ -789,28 +463,3 @@ def run_fused_swiglu_grouped_gemm(
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
Single kernel launch, no Python loop. 4x faster than the Python path.
Args:
fused_bf16: (M, 2*intermediate) BF16 fused L1 output with interleaved gate/up
intermediate: intermediate dimension (e.g., 3072)
global_scale: pre-computed global scale for quantization
granularity: interleave granularity in BF16 columns (default 8)
Returns:
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 quantized SwiGLU
x_sf: (M, intermediate//16) float8_e4m3fn block scales
"""
from torch.utils.cpp_extension import load
import os
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
mod = load(
name="deinterleave_quantize_nvfp4",
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)

123
dsv4/ops/layouts.py Normal file
View File

@@ -0,0 +1,123 @@
"""Tensor layout helpers: scale swizzle, gate/up interleave, K-major, offsets."""
import torch
from dsv4.kernels.gemm.grouped import (
pad_and_swizzle_single,
assemble_raw_scales_2d3d_2d_side,
assemble_raw_scales_2d3d_3d_side,
)
def ceil_div(a, b):
return (a + b - 1) // b
def round_up(a, b):
return ceil_div(a, b) * b
def interleave_l1_weights(w_ekn, granularity_bf16=8):
"""Interleave gate/up weights at granularity 8 in BF16 (4 in FP4).
The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
MMA accumulator. With interleaved weights, the MMA tile produces
gate[i*8..i*8+7] and up[i*8..i*8+7] next to each other in registers,
enabling a single-register SwiGLU without SMEM round-trips.
Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
After: [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]
The interleave operates along the N dimension, where each column = 1 BF16
(FP4 packing is along K, not N). So g = granularity_bf16 directly.
Args:
w_ekn: (E, K_packed, N_packed) FP4 weight tensor in K-major layout
N_packed = 2*intermediate/2 = intermediate (gate+up fused)
granularity_bf16: interleave group size in BF16 elements (default 8)
Returns:
(E, K_packed, N_packed) FP4 weight tensor with interleaved gate/up
"""
E, K, N = w_ekn.shape
N_half = N // 2 # gate and up each have N/2 FP4 columns
g = granularity_bf16 # N-axis interleave: each N-col = 1 BF16 col (packing is along K)
gate = w_ekn[:, :, :N_half].reshape(E, K, N_half // g, g)
up = w_ekn[:, :, N_half:].reshape(E, K, N_half // g, g)
return torch.stack([gate, up], dim=3).reshape(E, K, N)
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
"""De-interleave gate/up weights (inverse of interleave_l1_weights).
Used for testing/verification only.
"""
g = granularity_bf16 # N-axis: each N-col = 1 BF16 col
E, K, N = w_ekn.shape
w_reshaped = w_ekn.reshape(E, K, N // (2 * g), 2, g)
gate = w_reshaped[:, :, :, 0, :].reshape(E, K, N // 2)
up = w_reshaped[:, :, :, 1, :].reshape(E, K, N // 2)
return torch.cat([gate, up], dim=2)
def assemble_scales_2d_side(raw_scales):
"""Assemble activation scale factors for the 2Dx3D scenario.
Args:
raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert
Returns:
Assembled and swizzled scale tensor
"""
return assemble_raw_scales_2d3d_2d_side(raw_scales)
def assemble_scales_3d_side(raw_scales):
"""Assemble weight scale factors for the 2Dx3D scenario.
Args:
raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert
NOTE: These will be transposed to (N, K_sf) before swizzling,
since the kernel expects N as the non-K dimension.
Returns:
Assembled and swizzled scale tensor
"""
# Kernel expects (N, K_sf) — transpose before swizzling
transposed = [sf.T.contiguous() for sf in raw_scales]
return assemble_raw_scales_2d3d_3d_side(transposed)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
"""Convert B tensor from N-major to K-major layout.
The kernel expects B with stride (E*K*N, 1, K) — K is contiguous.
torch.stack produces stride (E*K*N, N, 1) — N is contiguous.
Args:
b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major
Returns:
Same shape, K-major strides
"""
return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
"""Compute cumulative token offsets for the grouped GEMM.
Args:
tokens_per_expert: list of int, one per expert
Returns:
offs: (num_experts,) int32 — cumulative sum
"""
offs = torch.tensor(
[sum(tokens_per_expert[:e+1]) for e in range(num_experts)],
dtype=torch.int32, device=device,
)
return offs
# ── Kernel Launch ─────────────────────────────────────────────────────

253
dsv4/ops/quantize.py Normal file
View File

@@ -0,0 +1,253 @@
"""NVFP4 quantization: BF16 <-> NVFP4 conversion, scale factor computation."""
import math
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from dsv4.kernels.gemm.grouped import (
cat_byte_reinterpretable_tensors,
stack_byte_reinterpretable_tensors,
)
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Cache compiled kernels + pre-allocated workspace by cache_key
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
#
# Key design decisions (Bug #1 fix):
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
# The original _needs_token_refill hack was a misdiagnosis. The real bug
# was elsewhere (likely OOB write or weight loading).
# - Workspace is pre-allocated per cache entry during warmup_compilation()
# and reused on subsequent calls. No torch.full() in the hot path.
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
# metadata wrappers. We re-create them per call from real tensors.
# Caching them would hold stale references to tensors that get freed.
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
_NVFP4_STEP_LUT_CACHE = {}
def _get_step_to_idx_lut(device):
"""Get or create the E2M1 step-to-index LUT for the given device.
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
so the lock is never entered on the compiled path.
"""
# Fast path: already cached — no lock needed (torch.compile-safe)
if device in _NVFP4_STEP_LUT_CACHE:
return _NVFP4_STEP_LUT_CACHE[device]
# Slow path: first call, create the LUT
lut = torch.as_tensor(
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
dtype=torch.int8, device=device,
)
_NVFP4_STEP_LUT_CACHE[device] = lut
return lut
SF_VEC_SIZE = 16 # NVFP4 block size
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
"""Quantize BF16 tensor to NVFP4.
Args:
x_bf16: (..., D) BF16 tensor
Returns:
x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4
x_sf: (..., D//16) float8_e4m3fn — block scales
global_scale: float32 scalar
"""
x_f32 = x_bf16.float()
amax = x_f32.abs().max().clamp(min=1e-8).float()
global_scale = amax / (6.0 * 448.0)
x_norm = x_f32 / global_scale
last_dim = x_norm.shape[-1]
n_blocks = ceil_div(last_dim, block_size)
if last_dim % block_size != 0:
pad_size = n_blocks * block_size - last_dim
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
# the block scale underflows to 0, and dividing x by the clamped 1e-8
# inflates values into nonzero FP4 buckets — producing wrong results.
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
# Zero out x for zero/underflow blocks before division.
# This ensures x_scaled = 0 → FP4 nibbles = 0.
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
# Nearest E2M1
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[..., ::2]
odd = nibbles[..., 1::2]
packed = (odd << 4) | even
packed_shape = list(x_bf16.shape)
packed_shape[-1] = last_dim // 2
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
block_scale = block_scale.reshape(sf_shape)
return x_fp4, block_scale, global_scale
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
"""Quantize BF16 activation tensor to NVFP4 (cudagraph-safe).
Unlike quantize_to_nvfp4(), this takes a pre-computed global_scale
instead of computing it via .max() (which forces CPU-GPU sync).
All operations are pure GPU with no CPU-GPU syncs.
Args:
x_bf16: (..., D) BF16 tensor
global_scale: float32 scalar (pre-computed, NOT from .max())
block_size: NVFP4 block size
Returns:
x_fp4: (..., D//2) float4_e2m1fn_x2
x_sf: (..., D//16) float8_e4m3fn
"""
x_f32 = x_bf16.float()
x_norm = x_f32 / global_scale
last_dim = x_norm.shape[-1]
n_blocks = ceil_div(last_dim, block_size)
if last_dim % block_size != 0:
pad_size = n_blocks * block_size - last_dim
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
zero_block = block_amax < (6.0 * 2.0 ** -9)
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[..., ::2]
odd = nibbles[..., 1::2]
packed = (odd << 4) | even
packed_shape = list(x_bf16.shape)
packed_shape[-1] = last_dim // 2
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
block_scale = block_scale.reshape(sf_shape)
return x_fp4, block_scale
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
"""Quantize BF16 weight matrix to NVFP4.
The weight is (K, N) where K is the input dim (packed dimension).
Block scales are computed along K (dim 0).
Args:
w_bf16: (K, N) BF16 weight matrix
Returns:
w_fp4: (K//2, N) float4_e2m1fn_x2 — K is the packed dim
w_sf: (K//16, N) float8_e4m3fn — block scales along K
global_scale: float32 scalar
"""
K, N = w_bf16.shape
w_f32 = w_bf16.float()
amax = w_f32.abs().max().clamp(min=1e-8).float()
global_scale = amax / (6.0 * 448.0)
w_norm = w_f32 / global_scale
k_blocks = ceil_div(K, block_size)
if K % block_size != 0:
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K))
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
w_block_amax = w_reshaped.abs().amax(dim=1)
# Detect zero blocks and underflow blocks (same threshold).
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
w_reshaped = torch.where(zero_block.unsqueeze(1),
torch.zeros_like(w_reshaped), w_reshaped)
w_block_amax = w_block_amax.clamp(min=1e-8)
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
w_block_sf = w_sf.float().unsqueeze(1)
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
signs = torch.sign(w_scaled)
abs_scaled = w_scaled.abs().clamp(max=6.0)
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
step_to_idx = _get_step_to_idx_lut(w_bf16.device)
indices = step_to_idx[half_steps.long()]
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
even = nibbles[:, ::2, :]
odd = nibbles[:, 1::2, :]
packed = (odd << 4) | even
w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2)
return w_fp4, w_sf, global_scale
# ── Scale Factor Assembly ─────────────────────────────────────────────
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
Single kernel launch, no Python loop. 4x faster than the Python path.
Args:
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
intermediate: intermediate dimension (e.g., 3072)
global_scale: pre-computed global scale for quantization
granularity: interleave granularity in BF16 columns (default 8)
Returns:
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
"""
from torch.utils.cpp_extension import load
import os
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
mod = load(
name="deinterleave_quantize_nvfp4",
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)

View File

View File

@@ -14,15 +14,19 @@ block scales in float8_e4m3fn, global scales in float32.
"""
import torch
from cutedsl.bridge import (
from dsv4.ops.quantize import (
quantize_to_nvfp4,
quantize_weight_to_nvfp4,
)
from dsv4.ops.layouts import (
assemble_scales_2d_side,
assemble_scales_3d_side,
make_b_k_major,
compute_expert_offsets,
interleave_l1_weights,
deinterleave_l1_weights,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
run_fused_swiglu_grouped_gemm,
warmup_fused_swiglu_compilation,
@@ -198,7 +202,7 @@ def run_nvfp4_moe(
sf_ekn = sf.unsqueeze(0) # (1, K_sf, N)
sf_ekn = interleave_l1_weights(sf_ekn) # interleaved along N
l1_sf_il.append(sf_ekn[0].T.contiguous()) # (N, K_sf) for assembly
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side as _assemble_3d
from dsv4.kernels.gemm.grouped import assemble_raw_scales_2d3d_3d_side as _assemble_3d
l1_scale_b = _assemble_3d(l1_sf_il)
# Global scales: alpha = igs * weight_gs for each expert
@@ -347,7 +351,7 @@ def run_nvfp4_moe_fused(
sf_ekn = sf.unsqueeze(0)
sf_ekn = interleave_l1_weights(sf_ekn)
l1_sf_il.append(sf_ekn[0].T.contiguous())
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side as _assemble_3d
from dsv4.kernels.gemm.grouped import assemble_raw_scales_2d3d_3d_side as _assemble_3d
l1_scale_b = _assemble_3d(l1_sf_il)
l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
@@ -368,7 +372,10 @@ def run_nvfp4_moe_fused(
intermediate_size = l1_fused_out.shape[1] // 2
# Use pre-computed L2 activation gs, or compute from amax (fallback)
l2_gs = l2_activation_gs if l2_activation_gs is not None else l1_fused_out.abs().amax().float().item() / 2688.0
from cutedsl.bridge import deinterleave_quantize_nvfp4_cuda, quantize_activation_nvfp4
from dsv4.ops.quantize import (
deinterleave_quantize_nvfp4_cuda,
quantize_activation_nvfp4,
)
l2_x_fp4, l2_x_sf = deinterleave_quantize_nvfp4_cuda(l1_fused_out, intermediate_size, l2_gs)
# Skip the separate L2 quantize step below — we already have FP4+SF
# Set activated to None to signal we already quantized

View File

@@ -3,7 +3,7 @@ requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "nvfp4-megamoe-kernel"
name = "dsv4-inference"
version = "0.1.0"
description = "NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang)"
requires-python = ">=3.10"
@@ -13,3 +13,4 @@ dependencies = [
[tool.setuptools.packages.find]
where = ["."]
include = ["dsv4*"]

View File

@@ -6,7 +6,7 @@ sys.path.insert(0, '/root/nvfp4-megamoe-kernel/cutedsl')
sys.path.insert(0, '/root/nvfp4-megamoe-kernel/vllm')
from cutedsl.reference.moe_pipeline import moe_pipeline
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
from vllm.nvfp4_cutedsl import Nvfp4MoE
torch.cuda.set_device(0)
@@ -33,7 +33,7 @@ ref_out = moe_pipeline(
print(f"Reference output: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
# Run runner with warmup gs
runner = CuTeDSLMoERunner(
runner = Nvfp4MoE(
num_experts=3, hidden_size=256, intermediate_size=512,
max_num_tokens=4, top_k=2, device='cuda'
)

View File

@@ -51,14 +51,14 @@ def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
def test_projection(name, weight, weight_sf, weight_gs, hidden_states, in_features, out_features):
"""Test a single NVFP4 projection."""
sys.path.insert(0, "/root/nvfp4-megamoe-kernel")
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
from dsv4.layers.linear import Nvfp4Linear
# Convert weight to CuTeDSL format: (out, in_packed) uint8 → (in_packed, out) float4
fp4 = [weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
sf = [weight_sf.permute(1, 0).contiguous()]
gs = [weight_gs]
runner = CuTeDSLNvfp4Linear(
runner = Nvfp4Linear(
in_features=in_features,
out_features=out_features,
max_num_tokens=8192,

View File

@@ -55,7 +55,7 @@ def rms(x, w, eps=1e-6):
return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
from dsv4.layers.linear import Nvfp4Linear
fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
s = s.permute(1,0).contiguous()
@@ -66,7 +66,7 @@ def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
s32[:sp] *= g1/gs; s32[sp:] *= g2/gs; s = s32.to(torch.float8_e4m3fn)
else:
gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
r = Nvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
r.finalize_weights(); r._ensure_initialized()
return r

Some files were not shown because too many files have changed in this diff Show More