Blackwell swizzle CUDA kernel for CUDA graph capture

Python view operations (reshape, transpose, permute) are not
graph-capturable — they cause cudaErrorStreamCaptureUnsupported.

Added:
- dsv4/kernels/cuda/blackwell_swizzle.cu: custom CUDA kernel for 32_4_4 swizzle
- to_blocked(): detects graph capture, uses CUDA kernel instead of Python views
- MoE _assemble_scales_cudagraph_safe: same treatment
- Shared expert _assemble_scales_single_group: same treatment
- Linear _assemble_scales_single_group: same treatment
- Pre-allocated swizzled output buffers for all layers (avoids torch.empty_like)

The CUDA kernel writes to a pre-allocated buffer — no per-step allocations.
Eager path unchanged (still uses fast Python view operations).
This commit is contained in:
2026-06-04 03:03:02 +00:00
parent e7766254b7
commit a434545d12
5 changed files with 181 additions and 7 deletions

View File

@@ -0,0 +1,101 @@
/**
* Blackwell 32_4_4 scale swizzle kernel.
*
* Rearranges FP8 scale factors from row-major layout to Blackwell tensor-core
* compatible layout. This is the GPU equivalent of the Python:
* blocks = x.view(R, 128, C, 4).permute(0, 2, 1, 3)
* out = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()
*
* The kernel writes to a pre-allocated output buffer — no per-step allocations.
* CUDA-graph-capturable: no host-device syncs, no dynamic shapes.
*/
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
// Blackwell 32_4_4 swizzle.
//
// Input:  (rows, cols) row-major, one byte per element (FP8 scale factors).
//         rows must be a multiple of 128 and cols a multiple of 4.
// Output: rows * cols bytes in swizzled order.
//
// The matrix is cut into tiles of 128 rows x 4 cols, visited row-major over
// (row tile, col tile). Every tile emits 512 contiguous output bytes.
// Writing the tile-local row as
//     row = row_hi * 32 + row_lo      (row_hi in [0, 4), row_lo in [0, 32))
// the byte at tile-local (row, col) lands at tile-local output offset
//     row_lo * 16 + row_hi * 4 + col
// which is exactly what the reference expression
//     x.view(R, 128, C, 4).permute(0, 2, 1, 3)
//      .reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()
// produces: reshape(-1, 4, 32, 4) splits the 128 rows into (row_hi, row_lo),
// and transpose(1, 2) makes row_lo the slowest-varying index within the tile.

/**
 * Maps a flat index of the swizzled output to the flat row-major index of the
 * input element that belongs there.
 *
 * Pure integer arithmetic, callable from host and device so the mapping can be
 * checked on the CPU against the reference expression.
 *
 * @param out_idx flat output index in [0, rows * cols)
 * @param cols    number of input columns (multiple of 4, > 0)
 * @return flat input index  src_row * cols + src_col
 */
static __host__ __device__ __forceinline__ int64_t blackwell_swizzle_src_index(
    const int64_t out_idx,
    const int64_t cols
) {
    constexpr int64_t kTileRows = 128;
    constexpr int64_t kTileCols = 4;
    constexpr int64_t kTileElems = kTileRows * kTileCols;  // 512
    constexpr int64_t kRowLoCount = 32;                    // rows per row_hi group

    const int64_t col_tiles = cols / kTileCols;

    // Which 128x4 tile, and where inside its 512 output bytes.
    const int64_t tile = out_idx / kTileElems;
    const int64_t in_tile = out_idx % kTileElems;
    const int64_t tile_r = tile / col_tiles;
    const int64_t tile_c = tile % col_tiles;

    // Tile-local output offset = row_lo * 16 + row_hi * 4 + col.
    const int64_t row_lo = in_tile / 16;
    const int64_t row_hi = (in_tile % 16) / kTileCols;
    const int64_t col = in_tile % kTileCols;

    const int64_t src_row = tile_r * kTileRows + row_hi * kRowLoCount + row_lo;
    const int64_t src_col = tile_c * kTileCols + col;
    return src_row * cols + src_col;
}

/**
 * Gathers `input` into `output` in Blackwell 32_4_4 order.
 *
 * Launch layout: 1-D grid of 1-D blocks. Any grid size is valid because the
 * kernel strides over the output; with ceil(rows * cols / blockDim.x) blocks
 * each thread handles exactly one output byte. No shared memory is used.
 *
 * Preconditions: rows % 128 == 0, cols % 4 == 0, both > 0. Shapes that violate
 * them would index outside `input`, so the kernel leaves `output` untouched.
 * `input` and `output` must not alias (both are declared __restrict__).
 *
 * @param input  rows * cols bytes, dense row-major
 * @param output rows * cols bytes, receives the swizzled layout
 * @param rows   number of input rows
 * @param cols   number of input columns
 */
__global__ void blackwell_swizzle_32_4_4_kernel(
    const uint8_t* __restrict__ input,   // (rows, cols) in FP8
    uint8_t* __restrict__ output,        // rows * cols bytes, swizzled FP8
    const int32_t rows,
    const int32_t cols                   // must be multiple of 4
) {
    // Uniform across the grid, so no divergence: refuse shapes whose source
    // indices would fall outside the input (also avoids dividing by zero
    // column tiles when cols < 4).
    if (rows <= 0 || cols <= 0 || rows % 128 != 0 || cols % 4 != 0) return;

    // 64-bit indices: rows * cols and blockIdx.x * blockDim.x may exceed 2^31.
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t out_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         out_idx < total;
         out_idx += stride) {
        // Consecutive threads write consecutive output bytes.
        output[out_idx] = input[blackwell_swizzle_src_index(out_idx, cols)];
    }
}
extern "C" {
/**
 * Enqueues the Blackwell 32_4_4 swizzle kernel on `stream`.
 *
 * Only a kernel launch is issued: no allocation and no synchronization.
 * One thread is launched per output byte, in blocks of 256 threads.
 *
 * Preconditions: rows is a multiple of 128, cols is a multiple of 4, and
 * `input` / `output` each address rows * cols dense bytes that do not overlap.
 * An empty shape (rows or cols <= 0) is a no-op. A misaligned or oversized
 * shape is rejected with a message on stderr and nothing is launched, because
 * the function has no return value to carry an error.
 *
 * @param input  device pointer, rows * cols bytes, row-major
 * @param output device pointer, rows * cols bytes, receives the swizzled data
 * @param rows   number of rows
 * @param cols   number of columns
 * @param stream stream to launch on
 */
void launch_blackwell_swizzle(
    const uint8_t* input,
    uint8_t* output,
    int32_t rows,
    int32_t cols,
    cudaStream_t stream
) {
    // Nothing to move; a grid of zero blocks would be an invalid launch.
    if (rows <= 0 || cols <= 0) return;

    // The kernel's index math assumes whole 128x4 tiles; anything else would
    // read outside the input.
    if (rows % 128 != 0 || cols % 4 != 0) {
        std::fprintf(stderr,
                     "launch_blackwell_swizzle: rows (%d) must be a multiple of 128 "
                     "and cols (%d) a multiple of 4; kernel not launched\n",
                     static_cast<int>(rows), static_cast<int>(cols));
        return;
    }

    // 64-bit so rows * cols cannot wrap.
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int32_t block_size = 256;
    const int64_t grid_size = (total + block_size - 1) / block_size;

    // gridDim.x is limited to 2^31 - 1 blocks.
    if (grid_size > INT32_MAX) {
        std::fprintf(stderr,
                     "launch_blackwell_swizzle: %lld elements exceed the 1-D grid limit; "
                     "kernel not launched\n",
                     static_cast<long long>(total));
        return;
    }

    blackwell_swizzle_32_4_4_kernel<<<static_cast<unsigned int>(grid_size), block_size, 0, stream>>>(
        input, output, rows, cols
    );
}
} // extern "C"

View File

@@ -2374,8 +2374,15 @@ def compute_scale_shape(
return (padded_N, total_cols)
def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor:
"""Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor."""
def to_blocked(scale_2d: torch.Tensor, out_buf: torch.Tensor = None) -> torch.Tensor:
"""Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor.
During CUDA graph capture, uses a custom CUDA kernel because Python
view operations (reshape, transpose, permute) are not graph-capturable.
The out_buf must be provided during graph capture (pre-allocated output).
During eager mode, uses the faster Python view path.
"""
if scale_2d.dim() != 2:
raise ValueError(f"Expected 2D scale tensor, got {scale_2d.dim()}D.")
rows, cols = scale_2d.shape
@@ -2394,6 +2401,19 @@ def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor:
)
padded[:rows, :cols] = scale_2d
# Use CUDA kernel during graph capture — Python view ops are not capturable
if torch.cuda.is_current_stream_capturing():
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
if out_buf is None:
out_buf = torch.empty_like(padded)
mod.blackwell_swizzle_32_4_4(
padded.view(torch.uint8), out_buf.view(torch.uint8),
padded_rows, padded_cols
)
return out_buf.view(torch.float8_e4m3fn).flatten()
# Eager path: Python view operations (fast, no kernel launch overhead)
blocks = padded.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
return rearranged.flatten()

View File

@@ -140,6 +140,9 @@ class Nvfp4Linear:
self._gemm_out_buf = torch.zeros(
max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device
)
# Pre-allocated swizzled scale output buffer (for CUDA graph capture)
self._padded_x_sf_swizzled_buf = torch.zeros_like(self._scale_a_buf)
def _ensure_initialized(self):
if self._mat_b is None:
@@ -165,6 +168,18 @@ class Nvfp4Linear:
# Pass correctly-sized VIEW to swizzle — the swizzle operates on
# (padded_rows, padded_cols) not the full max-size buffer.
view = buf[:padded_rows, :padded_cols]
# During graph capture, use CUDA swizzle kernel (Python view ops not capturable)
if torch.cuda.is_current_stream_capturing():
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
swizzled_buf = self._padded_x_sf_swizzled_buf
mod.blackwell_swizzle_32_4_4(
view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8),
padded_rows, padded_cols
)
return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols)
swizzled_flat = pad_and_swizzle_single(view)
return swizzled_flat.reshape(padded_rows, padded_cols)

View File

@@ -161,6 +161,16 @@ class Nvfp4MoE:
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
# Pre-allocated swizzled scale output buffers (same size as padded_x_sf)
# Required for CUDA graph capture — Python view ops (reshape, transpose) not capturable
if 'xsf_swizzled_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
Nvfp4MoE._shared_padded_bufs[device_key].update({
'xsf_swizzled_l1': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']),
'xsf_swizzled_l2': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']),
})
self._padded_x_sf_swizzled_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l1']
self._padded_x_sf_swizzled_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l2']
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
@@ -444,11 +454,18 @@ class Nvfp4MoE:
padded_x_sf[dst_rows, :K_sf] = x_sf
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
rows = padded_x_sf.shape[0]
cols = padded_x_sf.shape[1]
# During graph capture, Python view ops (reshape, transpose) are not allowed.
# Use CUDA swizzle kernel instead.
if torch.cuda.is_current_stream_capturing():
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
out_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2
mod.blackwell_swizzle_32_4_4(
padded_x_sf.view(torch.uint8), out_buf.view(torch.uint8),
rows, cols
)
return out_buf.view(torch.float8_e4m3fn).reshape(rows, cols)
# Eager path: Python view operations
R = rows // 128
C = cols // 4
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)

View File

@@ -178,11 +178,19 @@ class Nvfp4SharedExpert:
self._padded_x_sf_buf_l2 = torch.zeros(
max_rows, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
# Swizzled scale output buffers (for CUDA graph capture)
self._padded_x_sf_swizzled_buf_l1 = torch.zeros_like(self._padded_x_sf_buf_l1)
self._padded_x_sf_swizzled_buf_l2 = torch.zeros_like(self._padded_x_sf_buf_l2)
# Global scale buffers
self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
# Pre-allocated swizzled scale output buffers (for CUDA graph capture)
self._padded_x_sf_swizzled_buf_l1 = None
self._padded_x_sf_swizzled_buf_l2 = None
# Pre-allocated L1 output buffer for graph capture
# L1 produces gate+up combined: 2 * intermediate_size BF16 columns
self._l1_out_buf = torch.zeros(
@@ -234,6 +242,19 @@ class Nvfp4SharedExpert:
buf[:num_rows, :num_cols] = x_sf
# Pass correctly-sized VIEW to swizzle — avoids processing the full max-size buffer
view = buf[:padded_rows, :padded_cols]
# During graph capture, use CUDA swizzle kernel (Python view ops not capturable)
if torch.cuda.is_current_stream_capturing():
from dsv4.kernels.cuda.loader import get_cuda_module
swizzled_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf_buf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
mod.blackwell_swizzle_32_4_4(
view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8),
padded_rows, padded_cols
)
return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols)
# Eager path: Python view operations
swizzled_flat = pad_and_swizzle_single(view)
return swizzled_flat.reshape(padded_rows, padded_cols)