Blackwell swizzle CUDA kernel for CUDA graph capture
Python view operations (reshape, transpose, permute) are not graph-capturable — they cause cudaErrorStreamCaptureUnsupported.

Added:
- dsv4/kernels/cuda/blackwell_swizzle.cu: custom CUDA kernel for 32_4_4 swizzle
- to_blocked(): detects graph capture, uses CUDA kernel instead of Python views
- MoE _assemble_scales_cudagraph_safe: same treatment
- Shared expert _assemble_scales_single_group: same treatment
- Linear _assemble_scales_single_group: same treatment
- Pre-allocated swizzled output buffers for all layers (avoids torch.empty_like)

The CUDA kernel writes to a pre-allocated buffer — no per-step allocations. Eager path unchanged (still uses fast Python view operations).
This commit is contained in:
101
dsv4/kernels/cuda/blackwell_swizzle.cu
Normal file
101
dsv4/kernels/cuda/blackwell_swizzle.cu
Normal file
@@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Blackwell 32_4_4 scale swizzle kernel.
|
||||
*
|
||||
* Rearranges FP8 scale factors from row-major layout to Blackwell tensor-core
|
||||
* compatible layout. This is the GPU equivalent of the Python:
|
||||
* blocks = x.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
||||
* out = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()
|
||||
*
|
||||
* The kernel writes to a pre-allocated output buffer — no per-step allocations.
|
||||
* CUDA-graph-capturable: no host-device syncs, no dynamic shapes.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
// Blackwell 32_4_4 scale swizzle: copies one byte per output element.
//
// Input:  (rows, cols) row-major, one byte per element (the Python callers
//         pass FP8 scale tensors reinterpreted as uint8). The kernel indexes
//         it as input[row * cols + col], so the row stride must equal `cols`
//         (i.e. the buffer must be densely packed).
// Output: densely packed buffer of rows * cols bytes in swizzled order.
//
// Preconditions: rows is a multiple of 128 and cols is a multiple of 4.
// If they are not, only the complete 128x4 tiles are processed and the
// remaining output bytes are left untouched (no out-of-bounds access).
//
// Launch layout: 1-D grid of 1-D blocks of any size (grid-stride loop).
// No shared memory, no host/device synchronization.
//
// The mapping mirrors the PyTorch reference
//
//   blocks = x.view(R, 128, C, 4).permute(0, 2, 1, 3)        // (R, C, 128, 4)
//   out    = blocks.reshape(-1, 4, 32, 4)                     // (tile, row_hi, row_lo, col)
//                  .transpose(1, 2)                           // (tile, row_lo, row_hi, col)
//                  .reshape(-1, 32, 16).flatten()
//
// where a row inside a 128x4 tile is split as tile_row = row_hi * 32 + row_lo
// (row_hi in [0, 4), row_lo in [0, 32)). Hence, for tile (r, c):
//
//   out[(r * C + c) * 512 + row_lo * 16 + row_hi * 4 + col]
//       = in[r * 128 + row_hi * 32 + row_lo][c * 4 + col]
__global__ void blackwell_swizzle_32_4_4_kernel(
    const uint8_t* __restrict__ input,   // (rows, cols), densely packed
    uint8_t* __restrict__ output,        // rows * cols bytes, swizzled
    const int32_t rows,                  // expected multiple of 128
    const int32_t cols                   // expected multiple of 4
) {
    constexpr int32_t kTileRows  = 128;
    constexpr int32_t kTileCols  = 4;
    constexpr int32_t kTileElems = kTileRows * kTileCols;  // 512
    constexpr int32_t kRowLoSpan = 32;                     // rows per row_hi group
    constexpr int32_t kChunk     = 16;                     // 4 row_hi x 4 cols

    const int32_t R = rows / kTileRows;  // number of 128-row tiles
    const int32_t C = cols / kTileCols;  // number of 4-col tiles
    if (R <= 0 || C <= 0) return;

    // 64-bit flat indices: rows * cols may not fit in int32.
    const int64_t total  = static_cast<int64_t>(R) * C * kTileElems;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total;
         idx += stride) {
        // Which 128x4 tile, and where inside it (in output order).
        const int64_t tile   = idx / kTileElems;
        const int32_t within = static_cast<int32_t>(idx % kTileElems);

        const int64_t r = tile / C;  // tile row index
        const int64_t c = tile % C;  // tile column index

        // Output order inside a tile is (row_lo, row_hi, col).
        const int32_t row_lo = within / kChunk;                // 0..31
        const int32_t row_hi = (within % kChunk) / kTileCols;  // 0..3
        const int32_t col    = within % kTileCols;             // 0..3

        // Map back to row-major input coordinates.
        const int64_t input_row = r * kTileRows + row_hi * kRowLoSpan + row_lo;
        const int64_t input_col = c * kTileCols + col;

        output[idx] = input[input_row * cols + input_col];
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host-side launcher for blackwell_swizzle_32_4_4_kernel.
//
// Enqueues the swizzle of a (rows, cols) byte matrix on `stream`; it performs
// no allocation and no synchronization itself.
//
//   input   device pointer to rows * cols bytes, row-major
//   output  device pointer to rows * cols bytes, receives the swizzled data
//   rows    number of rows (the kernel expects a multiple of 128)
//   cols    number of columns (the kernel expects a multiple of 4)
//   stream  CUDA stream the kernel is launched on
//
// Empty or invalid shapes and null pointers are a no-op. The function returns
// void and does not query the launch status, so callers that need to detect
// launch failures should check cudaGetLastError() afterwards.
void launch_blackwell_swizzle(
    const uint8_t* input,
    uint8_t* output,
    int32_t rows,
    int32_t cols,
    cudaStream_t stream
) {
    // A zero-sized grid is an invalid launch configuration, so bail out early.
    if (input == nullptr || output == nullptr || rows <= 0 || cols <= 0) {
        return;
    }

    // 64-bit arithmetic: rows * cols and the ceil-div can overflow int32.
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int32_t block_size = 256;
    const int64_t blocks_needed = (total + block_size - 1) / block_size;

    // Clamp to the largest value representable in the 1-D grid dimension.
    const int64_t max_grid = INT32_MAX;
    const int32_t grid_size =
        static_cast<int32_t>(blocks_needed < max_grid ? blocks_needed : max_grid);

    blackwell_swizzle_32_4_4_kernel<<<grid_size, block_size, 0, stream>>>(
        input, output, rows, cols
    );
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -2374,8 +2374,15 @@ def compute_scale_shape(
|
||||
return (padded_N, total_cols)
|
||||
|
||||
|
||||
def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor."""
|
||||
def to_blocked(scale_2d: torch.Tensor, out_buf: torch.Tensor = None) -> torch.Tensor:
|
||||
"""Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor.
|
||||
|
||||
During CUDA graph capture, uses a custom CUDA kernel because Python
|
||||
view operations (reshape, transpose, permute) are not graph-capturable.
|
||||
The out_buf must be provided during graph capture (pre-allocated output).
|
||||
|
||||
During eager mode, uses the faster Python view path.
|
||||
"""
|
||||
if scale_2d.dim() != 2:
|
||||
raise ValueError(f"Expected 2D scale tensor, got {scale_2d.dim()}D.")
|
||||
rows, cols = scale_2d.shape
|
||||
@@ -2394,6 +2401,19 @@ def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor:
|
||||
)
|
||||
padded[:rows, :cols] = scale_2d
|
||||
|
||||
# Use CUDA kernel during graph capture — Python view ops are not capturable
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
|
||||
if out_buf is None:
|
||||
out_buf = torch.empty_like(padded)
|
||||
mod.blackwell_swizzle_32_4_4(
|
||||
padded.view(torch.uint8), out_buf.view(torch.uint8),
|
||||
padded_rows, padded_cols
|
||||
)
|
||||
return out_buf.view(torch.float8_e4m3fn).flatten()
|
||||
|
||||
# Eager path: Python view operations (fast, no kernel launch overhead)
|
||||
blocks = padded.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
return rearranged.flatten()
|
||||
|
||||
@@ -140,6 +140,9 @@ class Nvfp4Linear:
|
||||
self._gemm_out_buf = torch.zeros(
|
||||
max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Pre-allocated swizzled scale output buffer (for CUDA graph capture)
|
||||
self._padded_x_sf_swizzled_buf = torch.zeros_like(self._scale_a_buf)
|
||||
|
||||
def _ensure_initialized(self):
|
||||
if self._mat_b is None:
|
||||
@@ -165,6 +168,18 @@ class Nvfp4Linear:
|
||||
# Pass correctly-sized VIEW to swizzle — the swizzle operates on
|
||||
# (padded_rows, padded_cols) not the full max-size buffer.
|
||||
view = buf[:padded_rows, :padded_cols]
|
||||
|
||||
# During graph capture, use CUDA swizzle kernel (Python view ops not capturable)
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
|
||||
swizzled_buf = self._padded_x_sf_swizzled_buf
|
||||
mod.blackwell_swizzle_32_4_4(
|
||||
view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8),
|
||||
padded_rows, padded_cols
|
||||
)
|
||||
return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols)
|
||||
|
||||
swizzled_flat = pad_and_swizzle_single(view)
|
||||
return swizzled_flat.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
@@ -161,6 +161,16 @@ class Nvfp4MoE:
|
||||
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
|
||||
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
|
||||
|
||||
# Pre-allocated swizzled scale output buffers (same size as padded_x_sf)
|
||||
# Required for CUDA graph capture — Python view ops (reshape, transpose) not capturable
|
||||
if 'xsf_swizzled_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
||||
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
||||
'xsf_swizzled_l1': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']),
|
||||
'xsf_swizzled_l2': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']),
|
||||
})
|
||||
self._padded_x_sf_swizzled_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l1']
|
||||
self._padded_x_sf_swizzled_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l2']
|
||||
|
||||
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
||||
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
@@ -444,11 +454,18 @@ class Nvfp4MoE:
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
||||
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
|
||||
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
|
||||
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
|
||||
rows = padded_x_sf.shape[0]
|
||||
cols = padded_x_sf.shape[1]
|
||||
# During graph capture, Python view ops (reshape, transpose) are not allowed.
|
||||
# Use CUDA swizzle kernel instead.
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
|
||||
out_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2
|
||||
mod.blackwell_swizzle_32_4_4(
|
||||
padded_x_sf.view(torch.uint8), out_buf.view(torch.uint8),
|
||||
rows, cols
|
||||
)
|
||||
return out_buf.view(torch.float8_e4m3fn).reshape(rows, cols)
|
||||
# Eager path: Python view operations
|
||||
R = rows // 128
|
||||
C = cols // 4
|
||||
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
||||
|
||||
@@ -178,11 +178,19 @@ class Nvfp4SharedExpert:
|
||||
self._padded_x_sf_buf_l2 = torch.zeros(
|
||||
max_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
# Swizzled scale output buffers (for CUDA graph capture)
|
||||
self._padded_x_sf_swizzled_buf_l1 = torch.zeros_like(self._padded_x_sf_buf_l1)
|
||||
self._padded_x_sf_swizzled_buf_l2 = torch.zeros_like(self._padded_x_sf_buf_l2)
|
||||
|
||||
# Global scale buffers
|
||||
self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Pre-allocated swizzled scale output buffers (for CUDA graph capture)
|
||||
self._padded_x_sf_swizzled_buf_l1 = None
|
||||
self._padded_x_sf_swizzled_buf_l2 = None
|
||||
|
||||
# Pre-allocated L1 output buffer for graph capture
|
||||
# L1 produces gate+up combined: 2 * intermediate_size BF16 columns
|
||||
self._l1_out_buf = torch.zeros(
|
||||
@@ -234,6 +242,19 @@ class Nvfp4SharedExpert:
|
||||
buf[:num_rows, :num_cols] = x_sf
|
||||
# Pass correctly-sized VIEW to swizzle — avoids processing the full max-size buffer
|
||||
view = buf[:padded_rows, :padded_cols]
|
||||
|
||||
# During graph capture, use CUDA swizzle kernel (Python view ops not capturable)
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
swizzled_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf_buf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2
|
||||
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
|
||||
mod.blackwell_swizzle_32_4_4(
|
||||
view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8),
|
||||
padded_rows, padded_cols
|
||||
)
|
||||
return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols)
|
||||
|
||||
# Eager path: Python view operations
|
||||
swizzled_flat = pad_and_swizzle_single(view)
|
||||
return swizzled_flat.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user