Files
nvfp4-megamoe-kernel/vendored/blockscaled_layout.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

658 lines
20 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass, field
from typing import Optional
from cutlass.cutlass_dsl import dsl_user_op
import cutlass.cute as cute
from cutlass.cute.nvgpu import OperandMajorMode
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@dataclass(frozen=True)
class BlockScaledBasicChunk:
"""
The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops.
This class represents the fixed layout pattern for scale factors used in
tcgen05 BlockScaled MMA Ops. The layout is determined by the
instruction specification and cannot be modified.
See `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x>`.
"""
sf_vec_size: int
major_mode: OperandMajorMode = OperandMajorMode.K
_layout: cute.Layout = field(init=False, repr=False)
def __post_init__(self) -> None:
if self.major_mode == OperandMajorMode.K:
# K-major layout: (AtomMN, AtomK)
atom_shape = ((32, 4), (self.sf_vec_size, 4))
atom_stride = ((16, 4), (0, 1))
else:
# MN-major layout: (AtomK, AtomMN)
atom_shape = ((self.sf_vec_size, 4), (32, 4))
atom_stride = ((0, 1), (16, 4))
object.__setattr__(
self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
)
@property
def layout(self) -> cute.Layout:
"""
Get the layout for this block scaled chunk.
:return: The layout representing the scale factor atom
:rtype: cute.Layout
"""
return self._layout
@dsl_user_op
def tile_atom_to_shape_SF(
Shape: cute.Shape,
sf_vec_size: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.
:param Shape: The shape of the A/B tensor
:param sf_vec_size: Scale factor vector size
:return: The layout of the SFA/SFB tensor
:rtype: cute.Layout
"""
# ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
sf_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
)
return sf_layout
@dsl_user_op
def make_smem_layout_sf(
tile_shape: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.
:param Shape: The shape of the A/B tensor
:param sf_vec_size: Scale factor vector size
:param num_stages: Number of stages
:return: The layout of the SFA/SFB tensor
:rtype: cute.Layout
"""
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
tile_shape, # type: ignore[arg-type]
(2, 1),
loc=loc,
ip=ip,
)
smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages,
stride=cute.cosize(cute.filter_zeros(smem_layout)),
loc=loc,
ip=ip,
),
loc=loc,
ip=ip,
)
return smem_layout_staged
@dsl_user_op
def make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFA based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
sfa_tile_shape = (
mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), # type: ignore[index]
mma_tiler_mnk[2], # type: ignore[index]
)
# ((Atom_M, Rest_M),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfa_tile_shape, # type: ignore[arg-type]
(2, 1),
)
# Number of MMA instructions to cover all k-tiles
mma_tile_inst_m = mma_tiler_mnk[0] // cute.size(tiled_mma.shape_mnk, mode=[0]) # type: ignore[index]
mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index]
# (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
sfa_tile_shape = cute.shape_div(sfa_tile_shape, (mma_tile_inst_m, mma_tile_inst_k))
# ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K))
smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)
atom_m = 128
tiler_inst = ((atom_m, sf_vec_size),)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfa_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfa_smem_layout_staged
@dsl_user_op
def make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFB based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
sfb_tile_shape = (
cute.round_up(mma_tiler_mnk[1], 128), # type: ignore[index, arg-type]
mma_tiler_mnk[2], # type: ignore[index]
)
# ((Atom_N, Rest_N),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfb_tile_shape, # type: ignore[arg-type]
(2, 1),
)
# Number of MMA instructions to cover all k-tiles
mma_tile_inst_n = mma_tiler_mnk[1] // cute.size(tiled_mma.shape_mnk, mode=[1]) # type: ignore[index]
mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index]
# (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
sfb_tile_shape = cute.shape_div(sfb_tile_shape, (mma_tile_inst_n, mma_tile_inst_k))
# ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)
atom_n = 128
tiler_inst = ((atom_n, sf_vec_size),)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfb_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfb_smem_layout_staged
@dsl_user_op
def sm120_make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
tile_shape_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFA based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"
blk_mn = 128
blk_sf = 4
blk_elems = blk_mn * blk_sf
mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size
mn_basic_block_shape = (32, 4)
mn_basic_block_stride = (16, 4)
k_basic_block_shape = (sf_vec_size, mma_nsf)
k_basic_block_stride = (0, 1)
assert tile_shape_mnk[0] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[0] must be divisible by blk_mn"
)
sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn) # type: ignore[index, operator]
sSF_strideM = (mn_basic_block_stride, blk_elems)
assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index]
"tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
)
sSFA_shapeK = (
k_basic_block_shape,
blk_sf // mma_nsf,
tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator]
)
sSF_strideK = (
k_basic_block_stride,
mma_nsf,
tile_shape_mnk[0] // blk_mn * blk_elems, # type: ignore[index, operator]
)
sSFA_shape = (sSFA_shapeM, sSFA_shapeK)
sSFA_stride = (sSF_strideM, sSF_strideK)
smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfa_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfa_smem_layout_staged
@dsl_user_op
def sm120_make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
tile_shape_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFB based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix).
# 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row(col).
blk_mn = 128
blk_sf = 4
blk_elems = blk_mn * blk_sf
assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"
assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[1] must be divisible by blk_mn"
)
assert tile_shape_mnk[2] % sf_vec_size == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[2] must be divisible by sf_vec_size"
)
mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size
mn_basic_block_shape = (32, 4)
mn_basic_block_stride = (16, 4)
k_basic_block_shape = (sf_vec_size, mma_nsf)
k_basic_block_stride = (0, 1)
assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[1] must be divisible by blk_mn"
)
sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn) # type: ignore[index, operator]
sSF_strideN = (mn_basic_block_stride, blk_elems)
assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index]
"tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
)
sSFA_shapeK = (
k_basic_block_shape,
blk_sf // mma_nsf,
tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator]
)
sSF_strideK = (
k_basic_block_stride,
mma_nsf,
tile_shape_mnk[1] // blk_mn * blk_elems, # type: ignore[index, operator]
)
sSFA_shape = (sSFA_shapeN, sSFA_shapeK)
sSFA_stride = (sSF_strideN, sSF_strideK)
smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfb_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfb_smem_layout_staged
@dsl_user_op
def make_tmem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""Make tmem layout for SFA based on:
1. SFA smem layout per stage
2. Cta tile shape m
3. tiled MMA atom thr size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFA per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFA
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index]
sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)
@dsl_user_op
def make_tmem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""Make tmem layout for SFB based on:
1. SFB smem layout per stage
2. Cta tile shape m
3. tiled MMA atom thr size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFB per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFB
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index]
sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)
@dataclass(frozen=True)
class Sm103BlockScaledBasicChunk:
"""
Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103.
Represents the fixed layout pattern for scale factors used by tcgen05
BlockScaled MMA Ops on SM103. The layout is determined by the instruction
specification and is not configurable.
"""
sf_vec_size: int
major_mode: OperandMajorMode = OperandMajorMode.K
_layout: cute.Layout = field(init=False, repr=False)
def __post_init__(self) -> None:
atom_shape: cute.Shape
atom_stride: cute.Stride
if self.major_mode == OperandMajorMode.K:
atom_shape = ((8, 4, 4), (self.sf_vec_size, 4))
atom_stride = ((16, 128, 4), (0, 1))
else:
atom_shape = ((self.sf_vec_size, 4), (8, 4, 4))
atom_stride = ((0, 1), (16, 128, 4))
object.__setattr__(
self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride)
)
@property
def layout(self) -> cute.Layout:
return self._layout
@dsl_user_op
def sm103_make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make SMEM layout for SFA based on:
1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages.
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler: The mma tiler shape
:type mma_tiler: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2])) # type: ignore[index]
sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.a_major_mode).layout # type: ignore[attr-defined]
k_divisor = 4 if sf_vec_size == 16 else 2
mma_sfa_tiler = (
mma_shape_mk[0][0] * mma_shape_mk[1],
mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor,
)
sfa_smem_atom_layout = cute.tiled_product(
sf_atom,
cute.make_layout(
cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape))
),
)
sfa_smem_layout_staged = cute.make_layout(
shape=cute.append(sfa_smem_atom_layout.shape, num_stages),
stride=cute.append(
sfa_smem_atom_layout.stride,
cute.size(cute.filter_zeros(sfa_smem_atom_layout)),
),
)
return sfa_smem_layout_staged
@dsl_user_op
def sm103_make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages.
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler: The mma tiler shape
:type mma_tiler: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFB
:rtype: cute.Layout
"""
sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.b_major_mode).layout # type: ignore[attr-defined]
k_divisor = 4 if sf_vec_size == 16 else 2
mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor) # type: ignore[index, operator]
if mma_sfb_tiler[0] == 128:
sfb_smem_atom_layout = cute.tiled_product(
sf_atom,
cute.make_layout(
cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape))
),
)
else:
sf_k_major_atom256 = cute.make_layout(
shape=(
(32, 4, 2),
(sf_vec_size, 4),
),
stride=(
(16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512),
(0, 1),
),
)
sfb_smem_atom_layout = cute.tiled_product(
sf_k_major_atom256,
cute.make_layout(
cute.shape_div(
mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape)
)
),
)
sfb_smem_layout_staged = cute.make_layout(
shape=cute.append(sfb_smem_atom_layout.shape, num_stages),
stride=cute.append(
sfb_smem_atom_layout.stride,
cute.size(cute.filter_zeros(sfb_smem_atom_layout)),
),
)
return sfb_smem_layout_staged