diff --git a/dsv4/kernels/cuda/blackwell_swizzle.cu b/dsv4/kernels/cuda/blackwell_swizzle.cu
new file mode 100644
index 00000000..e3ad09e0
--- /dev/null
+++ b/dsv4/kernels/cuda/blackwell_swizzle.cu
@@ -0,0 +1,120 @@
+/**
+ * Blackwell 32_4_4 scale swizzle kernel.
+ *
+ * Rearranges FP8 scale factors from row-major layout to Blackwell tensor-core
+ * compatible layout. This is the GPU equivalent of the Python:
+ *   blocks = x.view(R, 128, C, 4).permute(0, 2, 1, 3)
+ *   out = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()
+ * where R = rows / 128 and C = cols / 4.
+ *
+ * The kernel writes to a pre-allocated output buffer — no per-step allocations.
+ * CUDA-graph-capturable: no host-device syncs, no dynamic shapes.
+ */
+
+#include <cuda_runtime.h>
+
+#include <cstdint>
+#include <cstdio>
+
+// Geometry of one swizzle tile: 128 rows x 4 cols = 512 one-byte scale factors.
+constexpr int32_t kTileRows = 128;
+constexpr int32_t kTileCols = 4;
+constexpr int32_t kTileElems = kTileRows * kTileCols;  // 512
+constexpr int32_t kBandRows = 32;    // a tile is 4 bands of 32 rows each
+constexpr int32_t kGroupElems = 16;  // 4 bands x 4 cols emitted per band-row
+
+// Maps a flat index into the swizzled output to the flat row-major input index.
+//
+// Unrolling the Python reference for tile (r, c), band a in [0, 4),
+// band-row s in [0, 32) and column j in [0, 4) gives:
+//   out[(r * C + c) * 512 + s * 16 + a * 4 + j] = in[r * 128 + a * 32 + s][c * 4 + j]
+// i.e. the column j is the fastest-moving output coordinate, then the band a,
+// then the row s inside the band.
+//
+// Callable from host code as well so the mapping can be checked without a GPU.
+// Preconditions: cols is a positive multiple of 4 and out_idx is in [0, rows * cols).
+static __host__ __device__ __forceinline__ int64_t blackwell_swizzle_src_index(
+    const int64_t out_idx, const int32_t cols) {
+  const int64_t col_groups = cols / kTileCols;  // C
+  const int64_t tile = out_idx / kTileElems;
+  const int32_t within_tile = static_cast<int32_t>(out_idx % kTileElems);
+
+  const int64_t r = tile / col_groups;  // 128-row block index
+  const int64_t c = tile % col_groups;  // 4-col group index
+
+  const int32_t s = within_tile / kGroupElems;  // row inside the band
+  const int32_t within_group = within_tile % kGroupElems;
+  const int32_t a = within_group / kTileCols;  // band index
+  const int32_t j = within_group % kTileCols;  // column inside the group
+
+  const int64_t src_row = r * kTileRows + a * kBandRows + s;
+  const int64_t src_col = c * kTileCols + j;
+  return src_row * cols + src_col;
+}
+
+// Blackwell 32_4_4 swizzle: one output byte per loop iteration.
+//   input:  (rows, cols) FP8 scale bytes, read as a contiguous row-major array
+//   output: rows * cols bytes written in swizzled order
+// Launch layout: 1-D grid of 1-D blocks; any grid size is valid because each
+// thread strides over the output by the total number of launched threads.
+// No shared memory. Preconditions: rows % 128 == 0, cols % 4 == 0, and
+// input/output must not alias (both are __restrict__).
+__global__ void blackwell_swizzle_32_4_4_kernel(
+    const uint8_t* __restrict__ input,   // (rows, cols) in FP8
+    uint8_t* __restrict__ output,        // (rows, cols) swizzled FP8
+    const int32_t rows,
+    const int32_t cols                   // must be multiple of 4
+) {
+  // 64-bit indices: rows * cols and blockIdx.x * blockDim.x can exceed 2^31.
+  const int64_t total = static_cast<int64_t>(rows) * cols;
+  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
+  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+       idx < total; idx += stride) {
+    output[idx] = input[blackwell_swizzle_src_index(idx, cols)];
+  }
+}
+
+extern "C" {
+
+// Enqueues the swizzle of a (rows, cols) scale matrix on `stream`.
+// `input` and `output` are device pointers to rows * cols bytes each.
+// Does not synchronise. Invalid shapes and launch failures are reported on
+// stderr because the C interface has no return value.
+void launch_blackwell_swizzle(
+    const uint8_t* input,
+    uint8_t* output,
+    int32_t rows,
+    int32_t cols,
+    cudaStream_t stream
+) {
+  // Empty input: nothing to do, and a zero-sized grid is an invalid launch.
+  if (rows <= 0 || cols <= 0) return;
+
+  // A partial tile would make the kernel read past the end of `input`.
+  if (rows % kTileRows != 0 || cols % kTileCols != 0) {
+    fprintf(stderr,
+            "launch_blackwell_swizzle: rows (%d) must be a multiple of 128 and "
+            "cols (%d) a multiple of 4\n",
+            static_cast<int>(rows), static_cast<int>(cols));
+    return;
+  }
+
+  const int64_t total = static_cast<int64_t>(rows) * cols;
+  const int32_t block_size = 256;
+  const int64_t max_grid = 2147483647;  // upper limit of gridDim.x
+  int64_t grid_size = (total + block_size - 1) / block_size;
+  if (grid_size > max_grid) grid_size = max_grid;  // kernel strides over the rest
+
+  blackwell_swizzle_32_4_4_kernel<<<static_cast<unsigned int>(grid_size), block_size, 0, stream>>>(
+      input, output, rows, cols
+  );
+
+  // Peek (not Get) so the error stays visible to the caller's own CUDA checks.
+  const cudaError_t err = cudaPeekAtLastError();
+  if (err != cudaSuccess) {
+    fprintf(stderr, "launch_blackwell_swizzle: CUDA error after launch: %s\n",
+            cudaGetErrorString(err));
+  }
+}
+
+} // extern "C"
diff --git 
a/dsv4/kernels/gemm/grouped.py b/dsv4/kernels/gemm/grouped.py index c1f7c1d8..25dedf8c 100644 --- a/dsv4/kernels/gemm/grouped.py +++ b/dsv4/kernels/gemm/grouped.py @@ -2374,8 +2374,15 @@ def compute_scale_shape( return (padded_N, total_cols) -def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor: - """Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor.""" +def to_blocked(scale_2d: torch.Tensor, out_buf: torch.Tensor = None) -> torch.Tensor: + """Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor. + + During CUDA graph capture, uses a custom CUDA kernel because Python + view operations (reshape, transpose, permute) are not graph-capturable. + The out_buf must be provided during graph capture (pre-allocated output). + + During eager mode, uses the faster Python view path. + """ if scale_2d.dim() != 2: raise ValueError(f"Expected 2D scale tensor, got {scale_2d.dim()}D.") rows, cols = scale_2d.shape @@ -2394,6 +2401,19 @@ def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor: ) padded[:rows, :cols] = scale_2d + # Use CUDA kernel during graph capture — Python view ops are not capturable + if torch.cuda.is_current_stream_capturing(): + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"]) + if out_buf is None: + out_buf = torch.empty_like(padded) + mod.blackwell_swizzle_32_4_4( + padded.view(torch.uint8), out_buf.view(torch.uint8), + padded_rows, padded_cols + ) + return out_buf.view(torch.float8_e4m3fn).flatten() + + # Eager path: Python view operations (fast, no kernel launch overhead) blocks = padded.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3) rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) return rearranged.flatten() diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index ba3edc9f..4bc24ec9 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -140,6 +140,9 @@ class Nvfp4Linear: 
self._gemm_out_buf = torch.zeros( max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device ) + + # Pre-allocated swizzled scale output buffer (for CUDA graph capture) + self._padded_x_sf_swizzled_buf = torch.zeros_like(self._scale_a_buf) def _ensure_initialized(self): if self._mat_b is None: @@ -165,6 +168,18 @@ class Nvfp4Linear: # Pass correctly-sized VIEW to swizzle — the swizzle operates on # (padded_rows, padded_cols) not the full max-size buffer. view = buf[:padded_rows, :padded_cols] + + # During graph capture, use CUDA swizzle kernel (Python view ops not capturable) + if torch.cuda.is_current_stream_capturing(): + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"]) + swizzled_buf = self._padded_x_sf_swizzled_buf + mod.blackwell_swizzle_32_4_4( + view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8), + padded_rows, padded_cols + ) + return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols) + swizzled_flat = pad_and_swizzle_single(view) return swizzled_flat.reshape(padded_rows, padded_cols) diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 59c48369..0c682805 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -161,6 +161,16 @@ class Nvfp4MoE: self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2'] self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output'] + # Pre-allocated swizzled scale output buffers (same size as padded_x_sf) + # Required for CUDA graph capture — Python view ops (reshape, transpose) not capturable + if 'xsf_swizzled_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]: + Nvfp4MoE._shared_padded_bufs[device_key].update({ + 'xsf_swizzled_l1': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']), + 'xsf_swizzled_l2': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']), + }) + self._padded_x_sf_swizzled_buf_l1 = 
Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l1'] + self._padded_x_sf_swizzled_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l2'] + # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) @@ -444,11 +454,18 @@ class Nvfp4MoE: padded_x_sf[dst_rows, :K_sf] = x_sf # Phase 2: Full-buffer swizzle (no CPU sync, no Python loops) - # padded_x_sf is 128-row aligned per expert and 4-col aligned. - # to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3) - # → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten - rows = padded_x_sf.shape[0] - cols = padded_x_sf.shape[1] + # During graph capture, Python view ops (reshape, transpose) are not allowed. + # Use CUDA swizzle kernel instead. + if torch.cuda.is_current_stream_capturing(): + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"]) + out_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2 + mod.blackwell_swizzle_32_4_4( + padded_x_sf.view(torch.uint8), out_buf.view(torch.uint8), + rows, cols + ) + return out_buf.view(torch.float8_e4m3fn).reshape(rows, cols) + # Eager path: Python view operations R = rows // 128 C = cols // 4 blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3) diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index c3d36ed3..f00aabb4 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -178,11 +178,19 @@ class Nvfp4SharedExpert: self._padded_x_sf_buf_l2 = torch.zeros( max_rows, padded_cols_l2, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) + + # Swizzled scale output buffers (for CUDA graph capture) + 
self._padded_x_sf_swizzled_buf_l1 = torch.zeros_like(self._padded_x_sf_buf_l1) + self._padded_x_sf_swizzled_buf_l2 = torch.zeros_like(self._padded_x_sf_buf_l2) # Global scale buffers self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) + # Pre-allocated swizzled scale output buffers (for CUDA graph capture) + self._padded_x_sf_swizzled_buf_l1 = None + self._padded_x_sf_swizzled_buf_l2 = None + # Pre-allocated L1 output buffer for graph capture # L1 produces gate+up combined: 2 * intermediate_size BF16 columns self._l1_out_buf = torch.zeros( @@ -234,6 +242,19 @@ class Nvfp4SharedExpert: buf[:num_rows, :num_cols] = x_sf # Pass correctly-sized VIEW to swizzle — avoids processing the full max-size buffer view = buf[:padded_rows, :padded_cols] + + # During graph capture, use CUDA swizzle kernel (Python view ops not capturable) + if torch.cuda.is_current_stream_capturing(): + from dsv4.kernels.cuda.loader import get_cuda_module + swizzled_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf_buf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2 + mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"]) + mod.blackwell_swizzle_32_4_4( + view.view(torch.uint8), swizzled_buf[:padded_rows, :padded_cols].view(torch.uint8), + padded_rows, padded_cols + ) + return swizzled_buf[:padded_rows, :padded_cols].reshape(padded_rows, padded_cols) + + # Eager path: Python view operations swizzled_flat = pad_and_swizzle_single(view) return swizzled_flat.reshape(padded_rows, padded_cols)