Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
This commit is contained in:
@@ -1 +0,0 @@
|
||||
/root/dsv4-nvfp4-workspace/kernel/cutedsl
|
||||
2
dsv4/cache/block_table.py
vendored
Normal file
2
dsv4/cache/block_table.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Block table for paged KV cache."""
|
||||
# TODO: Phase 3
|
||||
2
dsv4/cache/paged_cache.py
vendored
Normal file
2
dsv4/cache/paged_cache.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Paged KV cache."""
|
||||
# TODO: Phase 3
|
||||
2
dsv4/cache/state_cache.py
vendored
Normal file
2
dsv4/cache/state_cache.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
"""State cache for KV."""
|
||||
# TODO: Phase 3
|
||||
2
dsv4/kernels/attention/fmha.py
Normal file
2
dsv4/kernels/attention/fmha.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Stage B+). Extracted from test_fmha_v3.py."""
|
||||
# TODO: Extract FmhaV3 kernel class here
|
||||
0
dsv4/kernels/compressor/__init__.py
Normal file
0
dsv4/kernels/compressor/__init__.py
Normal file
0
dsv4/kernels/cuda/__init__.py
Normal file
0
dsv4/kernels/cuda/__init__.py
Normal file
@@ -1,4 +1,4 @@
|
||||
"""
|
||||
|
||||
FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).
|
||||
|
||||
STATUS: NOT USABLE INSIDE CUTE KERNELS.
|
||||
@@ -23,4 +23,4 @@ or when we can properly construct vector<4xf8E4M3FN> inside kernel code,
|
||||
we can fuse the dequant into the attention kernel. The PTX instruction
|
||||
exists (cvt.rn.bf16x2.e4m3x2), but CuTeDSL's AST preprocessor currently
|
||||
prevents us from injecting the necessary MLIR ops.
|
||||
"""
|
||||
|
||||
0
dsv4/kernels/decode/__init__.py
Normal file
0
dsv4/kernels/decode/__init__.py
Normal file
0
dsv4/kernels/gemm/__init__.py
Normal file
0
dsv4/kernels/gemm/__init__.py
Normal file
@@ -60,15 +60,15 @@ if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../.."))
|
||||
|
||||
from cutedsl.kernel.moe.moe_utils import (
|
||||
from dsv4.kernels.gemm.utils import (
|
||||
MoEScaledGroupedGemmTensormapConstructor,
|
||||
)
|
||||
from cutedsl.kernel.moe.moe_persistent_scheduler import (
|
||||
from dsv4.kernels.gemm.scheduler import (
|
||||
MoEStaticSchedulerParams,
|
||||
MoEStaticPersistentTileScheduler,
|
||||
MoEWorkTileInfo,
|
||||
)
|
||||
from cutedsl.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension
|
||||
from dsv4.kernels.gemm.sched_extension import ScaledGroupedMmSchedExtension
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
@@ -3665,7 +3665,7 @@ class ScaledGroupedGemmTester:
|
||||
if _examples_root not in sys.path:
|
||||
sys.path.insert(0, _examples_root)
|
||||
|
||||
from cutedsl.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import (
|
||||
from dsv4.kernels.gemm.dense import (
|
||||
Sm100BlockScaledPersistentDenseGemmKernel,
|
||||
)
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
@@ -60,15 +60,15 @@ if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../.."))
|
||||
|
||||
from cutedsl.kernel.moe.moe_utils import (
|
||||
from dsv4.kernels.gemm.utils import (
|
||||
MoEScaledGroupedGemmTensormapConstructor,
|
||||
)
|
||||
from cutedsl.kernel.moe.moe_persistent_scheduler import (
|
||||
from dsv4.kernels.gemm.scheduler import (
|
||||
MoEStaticSchedulerParams,
|
||||
MoEStaticPersistentTileScheduler,
|
||||
MoEWorkTileInfo,
|
||||
)
|
||||
from cutedsl.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension
|
||||
from dsv4.kernels.gemm.sched_extension import ScaledGroupedMmSchedExtension
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
@@ -3608,7 +3608,7 @@ class ScaledGroupedGemmTester:
|
||||
if _examples_root not in sys.path:
|
||||
sys.path.insert(0, _examples_root)
|
||||
|
||||
from cutedsl.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import (
|
||||
from dsv4.kernels.gemm.dense import (
|
||||
Sm100BlockScaledPersistentDenseGemmKernel,
|
||||
)
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
@@ -73,14 +73,14 @@ from cutlass.cutlass_dsl import Int32
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
|
||||
from cutedsl.kernel.moe.moe_utils import (
|
||||
from dsv4.kernels.gemm.utils import (
|
||||
OnlineTensormapDescCreator,
|
||||
tensormap_ptr_for_copy,
|
||||
compute_expert_token_range,
|
||||
rewrite_tensor_shape,
|
||||
prefetch_tma_descriptor,
|
||||
)
|
||||
from cutedsl.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
|
||||
from dsv4.kernels.gemm.scheduler import MoEWorkTileInfo
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
0
dsv4/layers/__init__.py
Normal file
0
dsv4/layers/__init__.py
Normal file
2
dsv4/layers/attention.py
Normal file
2
dsv4/layers/attention.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""DSV4 attention sub-block."""
|
||||
# TODO: Phase 3+4
|
||||
2
dsv4/layers/embedding.py
Normal file
2
dsv4/layers/embedding.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Token embedding + mHC init wrapper."""
|
||||
# TODO: Implement
|
||||
2
dsv4/layers/ffn.py
Normal file
2
dsv4/layers/ffn.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""FFN: router + MoE + shared expert."""
|
||||
# TODO: Phase 2
|
||||
@@ -14,22 +14,26 @@ CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_2d_side,
|
||||
assemble_scales_3d_side,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
|
||||
class CuTeDSLNvfp4WoA:
|
||||
class Nvfp4GroupedLinear:
|
||||
"""Grouped NVFP4 linear for wo_a (o-projection first half).
|
||||
|
||||
Handles the "bhr,hdr->bhd" einsum pattern:
|
||||
@@ -181,7 +185,9 @@ class CuTeDSLNvfp4WoA:
|
||||
# Reshape to grouped format, then flatten to 2D for quantization
|
||||
o_grouped = o_sample.reshape(-1, self.n_local_groups, self.group_in_features)
|
||||
# We need a single gs for all groups — use the overall amax
|
||||
from cutedsl.bridge import quantize_to_nvfp4
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
o_flat = o_sample.reshape(-1, o_sample.shape[-1]) # (tokens, n_local_heads * head_dim) — not right
|
||||
# Actually, for grouped GEMM, each group's activation is (tokens, group_in_features)
|
||||
# The global scale should be computed per-group, but for simplicity use one scale
|
||||
@@ -256,7 +262,9 @@ class CuTeDSLNvfp4WoA:
|
||||
# Assemble A-side scales for all groups
|
||||
# The grouped GEMM expects scales for all groups assembled together
|
||||
# For 2Dx3D scenario, scale_a is assembled from per-group scale tensors
|
||||
from cutedsl.bridge import assemble_scales_2d_side
|
||||
from dsv4.ops.layouts import (
|
||||
assemble_scales_2d_side,
|
||||
)
|
||||
scale_a = assemble_scales_2d_side(all_x_sf)
|
||||
|
||||
# Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T]
|
||||
@@ -8,21 +8,25 @@ CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
|
||||
class CuTeDSLNvfp4Linear:
|
||||
class Nvfp4Linear:
|
||||
"""Single NVFP4 GEMM using CuTeDSL (num_groups=1).
|
||||
|
||||
Handles any (K, N) weight matrix in NVFP4 format.
|
||||
@@ -76,7 +80,6 @@ class CuTeDSLNvfp4Linear:
|
||||
|
||||
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
|
||||
# Uses num_groups=1 since this is a single linear layer.
|
||||
# from cutedsl.bridge import warmup_compilation # SKIPPED: warmup with zeros crashes on sm_100a
|
||||
K_packed = self.in_features // 2
|
||||
N_packed = self.out_features // 2
|
||||
# warmup_compilation(1, K_packed, N_packed, self.device) # Lazy compile on first real forward
|
||||
@@ -15,26 +15,30 @@ processes max_slots = budget * top_k rows; padding rows are zeros.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
run_fused_swiglu_grouped_gemm,
|
||||
warmup_fused_swiglu_compilation,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
|
||||
from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm
|
||||
|
||||
|
||||
class CuTeDSLMoERunner:
|
||||
class Nvfp4MoE:
|
||||
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
@@ -127,15 +131,15 @@ class CuTeDSLMoERunner:
|
||||
|
||||
# Initialize shared buffers dict (if not already)
|
||||
device_key = str(self.device)
|
||||
if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
|
||||
CuTeDSLMoERunner._shared_padded_bufs = {}
|
||||
if device_key not in CuTeDSLMoERunner._shared_padded_bufs:
|
||||
CuTeDSLMoERunner._shared_padded_bufs[device_key] = {}
|
||||
if not hasattr(Nvfp4MoE, '_shared_padded_bufs'):
|
||||
Nvfp4MoE._shared_padded_bufs = {}
|
||||
if device_key not in Nvfp4MoE._shared_padded_bufs:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key] = {}
|
||||
|
||||
# Padded x_sf buffers: SHARED across all runners (not per-layer)
|
||||
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
|
||||
if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
|
||||
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
|
||||
if 'xsf_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
||||
'xsf_l1': torch.zeros(
|
||||
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn),
|
||||
@@ -146,9 +150,9 @@ class CuTeDSLMoERunner:
|
||||
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
})
|
||||
self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1']
|
||||
self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2']
|
||||
self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output']
|
||||
self._padded_x_sf_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']
|
||||
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
|
||||
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
|
||||
|
||||
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
||||
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
@@ -162,8 +166,8 @@ class CuTeDSLMoERunner:
|
||||
# Padded hidden/activated: SHARED across all runners (not per-layer)
|
||||
max_rows_per_expert = self._max_chunks_per_expert * 128
|
||||
padded_max_slots = self.num_experts * max_rows_per_expert
|
||||
if 'hidden' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
|
||||
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
|
||||
if 'hidden' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
||||
'hidden': torch.zeros(
|
||||
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
@@ -177,7 +181,7 @@ class CuTeDSLMoERunner:
|
||||
padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2),
|
||||
})
|
||||
self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key]
|
||||
self._shared_bufs = Nvfp4MoE._shared_padded_bufs[device_key]
|
||||
|
||||
# Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed)
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
@@ -237,7 +241,7 @@ class CuTeDSLMoERunner:
|
||||
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
|
||||
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
|
||||
# the checkpoint! Skip the transpose by calling the assembly directly.
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
from dsv4.ops.layouts import (
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
|
||||
@@ -285,7 +289,13 @@ class CuTeDSLMoERunner:
|
||||
# This triggers cute.compile once per shape, caching the compiled
|
||||
# kernel + workspace. Subsequent run() calls hit the cache.
|
||||
# MUST happen before model forward pass to avoid OOM from lazy JIT.
|
||||
from cutedsl.bridge import warmup_compilation, warmup_fused_swiglu_compilation, ceil_div as bridge_ceil_div
|
||||
from dsv4.ops.layouts import (
|
||||
ceil_div as bridge_ceil_div,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
warmup_compilation,
|
||||
warmup_fused_swiglu_compilation,
|
||||
)
|
||||
K_packed = self.hidden_size // 2
|
||||
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
|
||||
N_packed_l2 = self.hidden_size // 2 # down
|
||||
2
dsv4/layers/norm.py
Normal file
2
dsv4/layers/norm.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""RMSNorm placeholder."""
|
||||
# TODO: Implement RMSNorm
|
||||
2
dsv4/layers/router.py
Normal file
2
dsv4/layers/router.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Router: sqrt(softplus) + topk + aux-free bias + hash routing."""
|
||||
# TODO: Phase 2
|
||||
@@ -20,14 +20,18 @@ no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output
|
||||
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
@@ -40,7 +44,7 @@ class _SharedExpertApply(torch.autograd.Function):
|
||||
return runner._run_impl(hidden_states)
|
||||
|
||||
|
||||
class CuTeDSLSharedExpertRunner:
|
||||
class Nvfp4SharedExpert:
|
||||
"""NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
0
dsv4/loader/__init__.py
Normal file
0
dsv4/loader/__init__.py
Normal file
2
dsv4/loader/hf_checkpoint.py
Normal file
2
dsv4/loader/hf_checkpoint.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""HuggingFace checkpoint reader."""
|
||||
# TODO
|
||||
2
dsv4/loader/layout_convert.py
Normal file
2
dsv4/loader/layout_convert.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Checkpoint layout conversion."""
|
||||
# TODO
|
||||
0
dsv4/model/__init__.py
Normal file
0
dsv4/model/__init__.py
Normal file
2
dsv4/model/config.py
Normal file
2
dsv4/model/config.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""DSV4Config (Flash + Pro)."""
|
||||
# TODO: Phase 1
|
||||
2
dsv4/model/dsv4.py
Normal file
2
dsv4/model/dsv4.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Full DSV4 model."""
|
||||
# TODO: Phase 1
|
||||
2
dsv4/model/layer.py
Normal file
2
dsv4/model/layer.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Single transformer layer."""
|
||||
# TODO: Phase 1
|
||||
2
dsv4/model/mtp.py
Normal file
2
dsv4/model/mtp.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Multi-token prediction."""
|
||||
# TODO
|
||||
2
dsv4/model/sampler.py
Normal file
2
dsv4/model/sampler.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Token sampler."""
|
||||
# TODO
|
||||
0
dsv4/ops/__init__.py
Normal file
0
dsv4/ops/__init__.py
Normal file
@@ -1,13 +1,4 @@
|
||||
"""
|
||||
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
|
||||
|
||||
Handles tensor layout conversion from our pipeline's format to what
|
||||
the ScaledGroupedGemmKernel expects:
|
||||
- BF16 → NVFP4 quantization (float4_e2m1fn_x2)
|
||||
- Scale factor assembly (padding + swizzle)
|
||||
- B tensor K-major stride conversion
|
||||
- Expert offset computation
|
||||
"""
|
||||
"""NVFP4 GEMM runner: warmup, compile, and execute grouped/fused GEMMs."""
|
||||
import math
|
||||
import torch
|
||||
import cutlass
|
||||
@@ -15,18 +6,24 @@ import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ScaledGroupedGemmKernel,
|
||||
pad_and_swizzle_single,
|
||||
assemble_raw_scales_2d3d_2d_side,
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
cat_byte_reinterpretable_tensors,
|
||||
stack_byte_reinterpretable_tensors,
|
||||
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
|
||||
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
deinterleave_quantize_nvfp4_cuda,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
assemble_scales_2d_side,
|
||||
assemble_scales_3d_side,
|
||||
make_b_k_major,
|
||||
compute_expert_offsets,
|
||||
ceil_div,
|
||||
round_up,
|
||||
)
|
||||
|
||||
# ── Constants ──────────────────────────────────────────────────────────
|
||||
|
||||
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
|
||||
# Cache compiled kernels + pre-allocated workspace by cache_key
|
||||
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
|
||||
@@ -42,326 +39,6 @@ E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
# Caching them would hold stale references to tensors that get freed.
|
||||
_compiled_kernel_cache = {}
|
||||
|
||||
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
|
||||
_NVFP4_STEP_LUT_CACHE = {}
|
||||
def _get_step_to_idx_lut(device):
|
||||
"""Get or create the E2M1 step-to-index LUT for the given device.
|
||||
|
||||
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
|
||||
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
|
||||
so the lock is never entered on the compiled path.
|
||||
"""
|
||||
# Fast path: already cached — no lock needed (torch.compile-safe)
|
||||
if device in _NVFP4_STEP_LUT_CACHE:
|
||||
return _NVFP4_STEP_LUT_CACHE[device]
|
||||
# Slow path: first call, create the LUT
|
||||
lut = torch.as_tensor(
|
||||
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
|
||||
dtype=torch.int8, device=device,
|
||||
)
|
||||
_NVFP4_STEP_LUT_CACHE[device] = lut
|
||||
return lut
|
||||
SF_VEC_SIZE = 16 # NVFP4 block size
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def round_up(a, b):
|
||||
return ceil_div(a, b) * b
|
||||
|
||||
|
||||
# ── Quantization ──────────────────────────────────────────────────────
|
||||
|
||||
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 tensor to NVFP4.
|
||||
|
||||
Args:
|
||||
x_bf16: (..., D) BF16 tensor
|
||||
|
||||
Returns:
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4
|
||||
x_sf: (..., D//16) float8_e4m3fn — block scales
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
amax = x_f32.abs().max().clamp(min=1e-8).float()
|
||||
global_scale = amax / (6.0 * 448.0)
|
||||
x_norm = x_f32 / global_scale
|
||||
|
||||
last_dim = x_norm.shape[-1]
|
||||
n_blocks = ceil_div(last_dim, block_size)
|
||||
|
||||
if last_dim % block_size != 0:
|
||||
pad_size = n_blocks * block_size - last_dim
|
||||
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
|
||||
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
|
||||
# the block scale underflows to 0, and dividing x by the clamped 1e-8
|
||||
# inflates values into nonzero FP4 buckets — producing wrong results.
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
|
||||
# Zero out x for zero/underflow blocks before division.
|
||||
# This ensures x_scaled = 0 → FP4 nibbles = 0.
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
# Nearest E2M1
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
odd = nibbles[..., 1::2]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
packed_shape = list(x_bf16.shape)
|
||||
packed_shape[-1] = last_dim // 2
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
|
||||
|
||||
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
|
||||
block_scale = block_scale.reshape(sf_shape)
|
||||
|
||||
return x_fp4, block_scale, global_scale
|
||||
|
||||
|
||||
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 activation tensor to NVFP4 (cudagraph-safe).
|
||||
|
||||
Unlike quantize_to_nvfp4(), this takes a pre-computed global_scale
|
||||
instead of computing it via .max() (which forces CPU-GPU sync).
|
||||
All operations are pure GPU with no CPU-GPU syncs.
|
||||
|
||||
Args:
|
||||
x_bf16: (..., D) BF16 tensor
|
||||
global_scale: float32 scalar (pre-computed, NOT from .max())
|
||||
block_size: NVFP4 block size
|
||||
|
||||
Returns:
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2
|
||||
x_sf: (..., D//16) float8_e4m3fn
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
x_norm = x_f32 / global_scale
|
||||
|
||||
last_dim = x_norm.shape[-1]
|
||||
n_blocks = ceil_div(last_dim, block_size)
|
||||
|
||||
if last_dim % block_size != 0:
|
||||
pad_size = n_blocks * block_size - last_dim
|
||||
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9)
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
odd = nibbles[..., 1::2]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
packed_shape = list(x_bf16.shape)
|
||||
packed_shape[-1] = last_dim // 2
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
|
||||
|
||||
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
|
||||
block_scale = block_scale.reshape(sf_shape)
|
||||
|
||||
return x_fp4, block_scale
|
||||
|
||||
|
||||
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 weight matrix to NVFP4.
|
||||
|
||||
The weight is (K, N) where K is the input dim (packed dimension).
|
||||
Block scales are computed along K (dim 0).
|
||||
|
||||
Args:
|
||||
w_bf16: (K, N) BF16 weight matrix
|
||||
|
||||
Returns:
|
||||
w_fp4: (K//2, N) float4_e2m1fn_x2 — K is the packed dim
|
||||
w_sf: (K//16, N) float8_e4m3fn — block scales along K
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
K, N = w_bf16.shape
|
||||
w_f32 = w_bf16.float()
|
||||
amax = w_f32.abs().max().clamp(min=1e-8).float()
|
||||
global_scale = amax / (6.0 * 448.0)
|
||||
w_norm = w_f32 / global_scale
|
||||
|
||||
k_blocks = ceil_div(K, block_size)
|
||||
if K % block_size != 0:
|
||||
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K))
|
||||
|
||||
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
|
||||
w_block_amax = w_reshaped.abs().amax(dim=1)
|
||||
# Detect zero blocks and underflow blocks (same threshold).
|
||||
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
|
||||
w_reshaped = torch.where(zero_block.unsqueeze(1),
|
||||
torch.zeros_like(w_reshaped), w_reshaped)
|
||||
w_block_amax = w_block_amax.clamp(min=1e-8)
|
||||
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
|
||||
|
||||
w_block_sf = w_sf.float().unsqueeze(1)
|
||||
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
|
||||
|
||||
signs = torch.sign(w_scaled)
|
||||
abs_scaled = w_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(w_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
even = nibbles[:, ::2, :]
|
||||
odd = nibbles[:, 1::2, :]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2)
|
||||
return w_fp4, w_sf, global_scale
|
||||
|
||||
|
||||
# ── Scale Factor Assembly ─────────────────────────────────────────────
|
||||
|
||||
def interleave_l1_weights(w_ekn, granularity_bf16=8):
|
||||
"""Interleave gate/up weights at granularity 8 in BF16 (4 in FP4).
|
||||
|
||||
The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
|
||||
MMA accumulator. With interleaved weights, the MMA tile produces
|
||||
gate[i*8..i*8+7] and up[i*8..i*8+7] next to each other in registers,
|
||||
enabling a single-register SwiGLU without SMEM round-trips.
|
||||
|
||||
Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
|
||||
After: [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]
|
||||
|
||||
The interleave operates along the N dimension, where each column = 1 BF16
|
||||
(FP4 packing is along K, not N). So g = granularity_bf16 directly.
|
||||
|
||||
Args:
|
||||
w_ekn: (E, K_packed, N_packed) FP4 weight tensor in K-major layout
|
||||
N_packed = 2*intermediate/2 = intermediate (gate+up fused)
|
||||
granularity_bf16: interleave group size in BF16 elements (default 8)
|
||||
|
||||
Returns:
|
||||
(E, K_packed, N_packed) FP4 weight tensor with interleaved gate/up
|
||||
"""
|
||||
E, K, N = w_ekn.shape
|
||||
N_half = N // 2 # gate and up each have N/2 FP4 columns
|
||||
g = granularity_bf16 # N-axis interleave: each N-col = 1 BF16 col (packing is along K)
|
||||
|
||||
gate = w_ekn[:, :, :N_half].reshape(E, K, N_half // g, g)
|
||||
up = w_ekn[:, :, N_half:].reshape(E, K, N_half // g, g)
|
||||
return torch.stack([gate, up], dim=3).reshape(E, K, N)
|
||||
|
||||
|
||||
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
|
||||
"""De-interleave gate/up weights (inverse of interleave_l1_weights).
|
||||
|
||||
Used for testing/verification only.
|
||||
"""
|
||||
g = granularity_bf16 # N-axis: each N-col = 1 BF16 col
|
||||
E, K, N = w_ekn.shape
|
||||
w_reshaped = w_ekn.reshape(E, K, N // (2 * g), 2, g)
|
||||
gate = w_reshaped[:, :, :, 0, :].reshape(E, K, N // 2)
|
||||
up = w_reshaped[:, :, :, 1, :].reshape(E, K, N // 2)
|
||||
return torch.cat([gate, up], dim=2)
|
||||
|
||||
|
||||
def assemble_scales_2d_side(raw_scales):
|
||||
"""Assemble activation scale factors for the 2Dx3D scenario.
|
||||
|
||||
Args:
|
||||
raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert
|
||||
|
||||
Returns:
|
||||
Assembled and swizzled scale tensor
|
||||
"""
|
||||
return assemble_raw_scales_2d3d_2d_side(raw_scales)
|
||||
|
||||
|
||||
def assemble_scales_3d_side(raw_scales):
|
||||
"""Assemble weight scale factors for the 2Dx3D scenario.
|
||||
|
||||
Args:
|
||||
raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert
|
||||
NOTE: These will be transposed to (N, K_sf) before swizzling,
|
||||
since the kernel expects N as the non-K dimension.
|
||||
|
||||
Returns:
|
||||
Assembled and swizzled scale tensor
|
||||
"""
|
||||
# Kernel expects (N, K_sf) — transpose before swizzling
|
||||
transposed = [sf.T.contiguous() for sf in raw_scales]
|
||||
return assemble_raw_scales_2d3d_3d_side(transposed)
|
||||
|
||||
|
||||
# ── Tensor Layout Conversion ──────────────────────────────────────────
|
||||
|
||||
def make_b_k_major(b_tensor):
|
||||
"""Convert B tensor from N-major to K-major layout.
|
||||
|
||||
The kernel expects B with stride (E*K*N, 1, K) — K is contiguous.
|
||||
torch.stack produces stride (E*K*N, N, 1) — N is contiguous.
|
||||
|
||||
Args:
|
||||
b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major
|
||||
|
||||
Returns:
|
||||
Same shape, K-major strides
|
||||
"""
|
||||
return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
||||
|
||||
|
||||
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
|
||||
"""Compute cumulative token offsets for the grouped GEMM.
|
||||
|
||||
Args:
|
||||
tokens_per_expert: list of int, one per expert
|
||||
|
||||
Returns:
|
||||
offs: (num_experts,) int32 — cumulative sum
|
||||
"""
|
||||
offs = torch.tensor(
|
||||
[sum(tokens_per_expert[:e+1]) for e in range(num_experts)],
|
||||
dtype=torch.int32, device=device,
|
||||
)
|
||||
return offs
|
||||
|
||||
|
||||
# ── Kernel Launch ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)):
|
||||
"""Eagerly JIT-compile the GEMM kernel for a specific shape.
|
||||
@@ -589,10 +266,7 @@ def run_nvfp4_grouped_gemm(
|
||||
|
||||
# ── Fused SwiGLU GEMM (Stage 1: SiLU in registers, BF16 output) ──────
|
||||
|
||||
# Cache for fused kernel (separate from standard GEMM cache)
|
||||
_fused_kernel_cache = {}
|
||||
|
||||
|
||||
def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
|
||||
swiglu_limit=0.0,
|
||||
mma_tiler_mn=(128, 128),
|
||||
@@ -602,7 +276,7 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
|
||||
Must be called during model initialization. See warmup_compilation()
|
||||
for the standard GEMM equivalent.
|
||||
"""
|
||||
from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
|
||||
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
|
||||
|
||||
cache_key = ('fused', num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
|
||||
K_packed, N_packed, swiglu_limit)
|
||||
@@ -697,7 +371,7 @@ def run_fused_swiglu_grouped_gemm(
|
||||
Stage 1: SiLU is applied to the full accumulator in registers,
|
||||
then written as BF16 to C. Gate/up pairing is not yet implemented.
|
||||
"""
|
||||
from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
|
||||
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
|
||||
|
||||
num_experts = mat_b.shape[0]
|
||||
n_dim = mat_b.shape[2]
|
||||
@@ -789,28 +463,3 @@ def run_fused_swiglu_grouped_gemm(
|
||||
|
||||
|
||||
|
||||
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
|
||||
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
|
||||
|
||||
Single kernel launch, no Python loop. 4x faster than the Python path.
|
||||
|
||||
Args:
|
||||
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
|
||||
intermediate: intermediate dimension (e.g., 3072)
|
||||
global_scale: pre-computed global scale for quantization
|
||||
granularity: interleave granularity in BF16 columns (default 8)
|
||||
|
||||
Returns:
|
||||
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
|
||||
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
|
||||
mod = load(
|
||||
name="deinterleave_quantize_nvfp4",
|
||||
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)
|
||||
123
dsv4/ops/layouts.py
Normal file
123
dsv4/ops/layouts.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tensor layout helpers: scale swizzle, gate/up interleave, K-major, offsets."""
|
||||
import torch
|
||||
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
pad_and_swizzle_single,
|
||||
assemble_raw_scales_2d3d_2d_side,
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def round_up(a, b):
|
||||
return ceil_div(a, b) * b
|
||||
|
||||
def interleave_l1_weights(w_ekn, granularity_bf16=8):
|
||||
"""Interleave gate/up weights at granularity 8 in BF16 (4 in FP4).
|
||||
|
||||
The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
|
||||
MMA accumulator. With interleaved weights, the MMA tile produces
|
||||
gate[i*8..i*8+7] and up[i*8..i*8+7] next to each other in registers,
|
||||
enabling a single-register SwiGLU without SMEM round-trips.
|
||||
|
||||
Before: [gate_0..gate_N/2-1 | up_0..up_N/2-1]
|
||||
After: [gate_0..gate_7, up_0..up_7, gate_8..gate_15, up_8..up_15, ...]
|
||||
|
||||
The interleave operates along the N dimension, where each column = 1 BF16
|
||||
(FP4 packing is along K, not N). So g = granularity_bf16 directly.
|
||||
|
||||
Args:
|
||||
w_ekn: (E, K_packed, N_packed) FP4 weight tensor in K-major layout
|
||||
N_packed = 2*intermediate/2 = intermediate (gate+up fused)
|
||||
granularity_bf16: interleave group size in BF16 elements (default 8)
|
||||
|
||||
Returns:
|
||||
(E, K_packed, N_packed) FP4 weight tensor with interleaved gate/up
|
||||
"""
|
||||
E, K, N = w_ekn.shape
|
||||
N_half = N // 2 # gate and up each have N/2 FP4 columns
|
||||
g = granularity_bf16 # N-axis interleave: each N-col = 1 BF16 col (packing is along K)
|
||||
|
||||
gate = w_ekn[:, :, :N_half].reshape(E, K, N_half // g, g)
|
||||
up = w_ekn[:, :, N_half:].reshape(E, K, N_half // g, g)
|
||||
return torch.stack([gate, up], dim=3).reshape(E, K, N)
|
||||
|
||||
|
||||
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
|
||||
"""De-interleave gate/up weights (inverse of interleave_l1_weights).
|
||||
|
||||
Used for testing/verification only.
|
||||
"""
|
||||
g = granularity_bf16 # N-axis: each N-col = 1 BF16 col
|
||||
E, K, N = w_ekn.shape
|
||||
w_reshaped = w_ekn.reshape(E, K, N // (2 * g), 2, g)
|
||||
gate = w_reshaped[:, :, :, 0, :].reshape(E, K, N // 2)
|
||||
up = w_reshaped[:, :, :, 1, :].reshape(E, K, N // 2)
|
||||
return torch.cat([gate, up], dim=2)
|
||||
|
||||
|
||||
def assemble_scales_2d_side(raw_scales):
|
||||
"""Assemble activation scale factors for the 2Dx3D scenario.
|
||||
|
||||
Args:
|
||||
raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert
|
||||
|
||||
Returns:
|
||||
Assembled and swizzled scale tensor
|
||||
"""
|
||||
return assemble_raw_scales_2d3d_2d_side(raw_scales)
|
||||
|
||||
|
||||
def assemble_scales_3d_side(raw_scales):
|
||||
"""Assemble weight scale factors for the 2Dx3D scenario.
|
||||
|
||||
Args:
|
||||
raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert
|
||||
NOTE: These will be transposed to (N, K_sf) before swizzling,
|
||||
since the kernel expects N as the non-K dimension.
|
||||
|
||||
Returns:
|
||||
Assembled and swizzled scale tensor
|
||||
"""
|
||||
# Kernel expects (N, K_sf) — transpose before swizzling
|
||||
transposed = [sf.T.contiguous() for sf in raw_scales]
|
||||
return assemble_raw_scales_2d3d_3d_side(transposed)
|
||||
|
||||
|
||||
# ── Tensor Layout Conversion ──────────────────────────────────────────
|
||||
|
||||
def make_b_k_major(b_tensor):
|
||||
"""Convert B tensor from N-major to K-major layout.
|
||||
|
||||
The kernel expects B with stride (E*K*N, 1, K) — K is contiguous.
|
||||
torch.stack produces stride (E*K*N, N, 1) — N is contiguous.
|
||||
|
||||
Args:
|
||||
b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major
|
||||
|
||||
Returns:
|
||||
Same shape, K-major strides
|
||||
"""
|
||||
return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
||||
|
||||
|
||||
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
|
||||
"""Compute cumulative token offsets for the grouped GEMM.
|
||||
|
||||
Args:
|
||||
tokens_per_expert: list of int, one per expert
|
||||
|
||||
Returns:
|
||||
offs: (num_experts,) int32 — cumulative sum
|
||||
"""
|
||||
offs = torch.tensor(
|
||||
[sum(tokens_per_expert[:e+1]) for e in range(num_experts)],
|
||||
dtype=torch.int32, device=device,
|
||||
)
|
||||
return offs
|
||||
|
||||
|
||||
# ── Kernel Launch ─────────────────────────────────────────────────────
|
||||
|
||||
253
dsv4/ops/quantize.py
Normal file
253
dsv4/ops/quantize.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""NVFP4 quantization: BF16 <-> NVFP4 conversion, scale factor computation."""
|
||||
import math
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
|
||||
from dsv4.kernels.gemm.grouped import (
|
||||
cat_byte_reinterpretable_tensors,
|
||||
stack_byte_reinterpretable_tensors,
|
||||
)
|
||||
|
||||
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
|
||||
# Cache compiled kernels + pre-allocated workspace by cache_key
|
||||
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
|
||||
#
|
||||
# Key design decisions (Bug #1 fix):
|
||||
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
|
||||
# The original _needs_token_refill hack was a misdiagnosis. The real bug
|
||||
# was elsewhere (likely OOB write or weight loading).
|
||||
# - Workspace is pre-allocated per cache entry during warmup_compilation()
|
||||
# and reused on subsequent calls. No torch.full() in the hot path.
|
||||
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
|
||||
# metadata wrappers. We re-create them per call from real tensors.
|
||||
# Caching them would hold stale references to tensors that get freed.
|
||||
|
||||
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
|
||||
_NVFP4_STEP_LUT_CACHE = {}
|
||||
def _get_step_to_idx_lut(device):
|
||||
"""Get or create the E2M1 step-to-index LUT for the given device.
|
||||
|
||||
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
|
||||
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
|
||||
so the lock is never entered on the compiled path.
|
||||
"""
|
||||
# Fast path: already cached — no lock needed (torch.compile-safe)
|
||||
if device in _NVFP4_STEP_LUT_CACHE:
|
||||
return _NVFP4_STEP_LUT_CACHE[device]
|
||||
# Slow path: first call, create the LUT
|
||||
lut = torch.as_tensor(
|
||||
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
|
||||
dtype=torch.int8, device=device,
|
||||
)
|
||||
_NVFP4_STEP_LUT_CACHE[device] = lut
|
||||
return lut
|
||||
SF_VEC_SIZE = 16 # NVFP4 block size
|
||||
|
||||
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 tensor to NVFP4.
|
||||
|
||||
Args:
|
||||
x_bf16: (..., D) BF16 tensor
|
||||
|
||||
Returns:
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4
|
||||
x_sf: (..., D//16) float8_e4m3fn — block scales
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
amax = x_f32.abs().max().clamp(min=1e-8).float()
|
||||
global_scale = amax / (6.0 * 448.0)
|
||||
x_norm = x_f32 / global_scale
|
||||
|
||||
last_dim = x_norm.shape[-1]
|
||||
n_blocks = ceil_div(last_dim, block_size)
|
||||
|
||||
if last_dim % block_size != 0:
|
||||
pad_size = n_blocks * block_size - last_dim
|
||||
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
|
||||
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
|
||||
# the block scale underflows to 0, and dividing x by the clamped 1e-8
|
||||
# inflates values into nonzero FP4 buckets — producing wrong results.
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
|
||||
# Zero out x for zero/underflow blocks before division.
|
||||
# This ensures x_scaled = 0 → FP4 nibbles = 0.
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
# Nearest E2M1
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
odd = nibbles[..., 1::2]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
packed_shape = list(x_bf16.shape)
|
||||
packed_shape[-1] = last_dim // 2
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
|
||||
|
||||
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
|
||||
block_scale = block_scale.reshape(sf_shape)
|
||||
|
||||
return x_fp4, block_scale, global_scale
|
||||
|
||||
|
||||
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 activation tensor to NVFP4 (cudagraph-safe).
|
||||
|
||||
Unlike quantize_to_nvfp4(), this takes a pre-computed global_scale
|
||||
instead of computing it via .max() (which forces CPU-GPU sync).
|
||||
All operations are pure GPU with no CPU-GPU syncs.
|
||||
|
||||
Args:
|
||||
x_bf16: (..., D) BF16 tensor
|
||||
global_scale: float32 scalar (pre-computed, NOT from .max())
|
||||
block_size: NVFP4 block size
|
||||
|
||||
Returns:
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2
|
||||
x_sf: (..., D//16) float8_e4m3fn
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
x_norm = x_f32 / global_scale
|
||||
|
||||
last_dim = x_norm.shape[-1]
|
||||
n_blocks = ceil_div(last_dim, block_size)
|
||||
|
||||
if last_dim % block_size != 0:
|
||||
pad_size = n_blocks * block_size - last_dim
|
||||
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9)
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
odd = nibbles[..., 1::2]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
packed_shape = list(x_bf16.shape)
|
||||
packed_shape[-1] = last_dim // 2
|
||||
x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
|
||||
|
||||
sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
|
||||
block_scale = block_scale.reshape(sf_shape)
|
||||
|
||||
return x_fp4, block_scale
|
||||
|
||||
|
||||
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 weight matrix to NVFP4.
|
||||
|
||||
The weight is (K, N) where K is the input dim (packed dimension).
|
||||
Block scales are computed along K (dim 0).
|
||||
|
||||
Args:
|
||||
w_bf16: (K, N) BF16 weight matrix
|
||||
|
||||
Returns:
|
||||
w_fp4: (K//2, N) float4_e2m1fn_x2 — K is the packed dim
|
||||
w_sf: (K//16, N) float8_e4m3fn — block scales along K
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
K, N = w_bf16.shape
|
||||
w_f32 = w_bf16.float()
|
||||
amax = w_f32.abs().max().clamp(min=1e-8).float()
|
||||
global_scale = amax / (6.0 * 448.0)
|
||||
w_norm = w_f32 / global_scale
|
||||
|
||||
k_blocks = ceil_div(K, block_size)
|
||||
if K % block_size != 0:
|
||||
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K))
|
||||
|
||||
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
|
||||
w_block_amax = w_reshaped.abs().amax(dim=1)
|
||||
# Detect zero blocks and underflow blocks (same threshold).
|
||||
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
|
||||
w_reshaped = torch.where(zero_block.unsqueeze(1),
|
||||
torch.zeros_like(w_reshaped), w_reshaped)
|
||||
w_block_amax = w_block_amax.clamp(min=1e-8)
|
||||
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
|
||||
|
||||
w_block_sf = w_sf.float().unsqueeze(1)
|
||||
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
|
||||
|
||||
signs = torch.sign(w_scaled)
|
||||
abs_scaled = w_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(w_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
even = nibbles[:, ::2, :]
|
||||
odd = nibbles[:, 1::2, :]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2)
|
||||
return w_fp4, w_sf, global_scale
|
||||
|
||||
|
||||
# ── Scale Factor Assembly ─────────────────────────────────────────────
|
||||
def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
|
||||
"""De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.
|
||||
|
||||
Single kernel launch, no Python loop. 4x faster than the Python path.
|
||||
|
||||
Args:
|
||||
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
|
||||
intermediate: intermediate dimension (e.g., 3072)
|
||||
global_scale: pre-computed global scale for quantization
|
||||
granularity: interleave granularity in BF16 columns (default 8)
|
||||
|
||||
Returns:
|
||||
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
|
||||
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
|
||||
mod = load(
|
||||
name="deinterleave_quantize_nvfp4",
|
||||
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)
|
||||
0
dsv4/reference/__init__.py
Normal file
0
dsv4/reference/__init__.py
Normal file
@@ -14,15 +14,19 @@ block scales in float8_e4m3fn, global scales in float32.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
from dsv4.ops.quantize import (
|
||||
quantize_to_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
)
|
||||
from dsv4.ops.layouts import (
|
||||
assemble_scales_2d_side,
|
||||
assemble_scales_3d_side,
|
||||
make_b_k_major,
|
||||
compute_expert_offsets,
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
run_fused_swiglu_grouped_gemm,
|
||||
warmup_fused_swiglu_compilation,
|
||||
@@ -198,7 +202,7 @@ def run_nvfp4_moe(
|
||||
sf_ekn = sf.unsqueeze(0) # (1, K_sf, N)
|
||||
sf_ekn = interleave_l1_weights(sf_ekn) # interleaved along N
|
||||
l1_sf_il.append(sf_ekn[0].T.contiguous()) # (N, K_sf) for assembly
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side as _assemble_3d
|
||||
from dsv4.kernels.gemm.grouped import assemble_raw_scales_2d3d_3d_side as _assemble_3d
|
||||
l1_scale_b = _assemble_3d(l1_sf_il)
|
||||
|
||||
# Global scales: alpha = igs * weight_gs for each expert
|
||||
@@ -347,7 +351,7 @@ def run_nvfp4_moe_fused(
|
||||
sf_ekn = sf.unsqueeze(0)
|
||||
sf_ekn = interleave_l1_weights(sf_ekn)
|
||||
l1_sf_il.append(sf_ekn[0].T.contiguous())
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import assemble_raw_scales_2d3d_3d_side as _assemble_3d
|
||||
from dsv4.kernels.gemm.grouped import assemble_raw_scales_2d3d_3d_side as _assemble_3d
|
||||
l1_scale_b = _assemble_3d(l1_sf_il)
|
||||
|
||||
l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
|
||||
@@ -368,7 +372,10 @@ def run_nvfp4_moe_fused(
|
||||
intermediate_size = l1_fused_out.shape[1] // 2
|
||||
# Use pre-computed L2 activation gs, or compute from amax (fallback)
|
||||
l2_gs = l2_activation_gs if l2_activation_gs is not None else l1_fused_out.abs().amax().float().item() / 2688.0
|
||||
from cutedsl.bridge import deinterleave_quantize_nvfp4_cuda, quantize_activation_nvfp4
|
||||
from dsv4.ops.quantize import (
|
||||
deinterleave_quantize_nvfp4_cuda,
|
||||
quantize_activation_nvfp4,
|
||||
)
|
||||
l2_x_fp4, l2_x_sf = deinterleave_quantize_nvfp4_cuda(l1_fused_out, intermediate_size, l2_gs)
|
||||
# Skip the separate L2 quantize step below — we already have FP4+SF
|
||||
# Set activated to None to signal we already quantized
|
||||
@@ -3,7 +3,7 @@ requires = ["setuptools>=68.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "nvfp4-megamoe-kernel"
|
||||
name = "dsv4-inference"
|
||||
version = "0.1.0"
|
||||
description = "NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang)"
|
||||
requires-python = ">=3.10"
|
||||
@@ -13,3 +13,4 @@ dependencies = [
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["dsv4*"]
|
||||
|
||||
@@ -6,7 +6,7 @@ sys.path.insert(0, '/root/nvfp4-megamoe-kernel/cutedsl')
|
||||
sys.path.insert(0, '/root/nvfp4-megamoe-kernel/vllm')
|
||||
|
||||
from cutedsl.reference.moe_pipeline import moe_pipeline
|
||||
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
||||
from vllm.nvfp4_cutedsl import Nvfp4MoE
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
@@ -33,7 +33,7 @@ ref_out = moe_pipeline(
|
||||
print(f"Reference output: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
|
||||
|
||||
# Run runner with warmup gs
|
||||
runner = CuTeDSLMoERunner(
|
||||
runner = Nvfp4MoE(
|
||||
num_experts=3, hidden_size=256, intermediate_size=512,
|
||||
max_num_tokens=4, top_k=2, device='cuda'
|
||||
)
|
||||
@@ -51,14 +51,14 @@ def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
|
||||
def test_projection(name, weight, weight_sf, weight_gs, hidden_states, in_features, out_features):
|
||||
"""Test a single NVFP4 projection."""
|
||||
sys.path.insert(0, "/root/nvfp4-megamoe-kernel")
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
|
||||
# Convert weight to CuTeDSL format: (out, in_packed) uint8 → (in_packed, out) float4
|
||||
fp4 = [weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
|
||||
sf = [weight_sf.permute(1, 0).contiguous()]
|
||||
gs = [weight_gs]
|
||||
|
||||
runner = CuTeDSLNvfp4Linear(
|
||||
runner = Nvfp4Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
max_num_tokens=8192,
|
||||
@@ -55,7 +55,7 @@ def rms(x, w, eps=1e-6):
|
||||
return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)
|
||||
|
||||
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
|
||||
s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
|
||||
s = s.permute(1,0).contiguous()
|
||||
@@ -66,7 +66,7 @@ def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
|
||||
s32[:sp] *= g1/gs; s32[sp:] *= g2/gs; s = s32.to(torch.float8_e4m3fn)
|
||||
else:
|
||||
gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
|
||||
r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
|
||||
r = Nvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
|
||||
r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
|
||||
r.finalize_weights(); r._ensure_initialized()
|
||||
return r
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user