Files
nvfp4-megamoe-kernel/dsv4/layers/grouped_linear.py

420 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half).
wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups.
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank)
The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd".
We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts,
where every token goes to every "expert" (group).
wo_a is loaded as BF16 from our NVFP4 checkpoint, then quantized to NVFP4 here.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_nvfp4_gpu_fused,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class Nvfp4GroupedLinear:
    """Grouped NVFP4 linear for wo_a (o-projection first half).

    Handles the "bhr,hdr->bhd" einsum pattern:
      - o: (tokens, n_local_heads, head_dim), reshaped to
        (tokens, n_local_groups, heads_per_group * head_dim)
      - wo_a: one (K, N) weight block per group, each quantized to NVFP4
        separately (K = heads_per_group * head_dim, N = o_lora_rank)
      - z: (tokens, n_local_groups, o_lora_rank)

    Uses the grouped GEMM with num_groups=n_local_groups. Every token goes to
    every group (no routing).

    CUDA-graph-compatible: buffers are pre-allocated at max size on first use,
    and the forward path performs no CPU-GPU syncs.
    """

    def __init__(
        self,
        n_local_groups: int,
        heads_per_group: int,
        head_dim: int,
        o_lora_rank: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        self.n_local_groups = n_local_groups
        self.heads_per_group = heads_per_group
        self.head_dim = head_dim
        self.o_lora_rank = o_lora_rank
        self.max_num_tokens = max_num_tokens
        self.device = device
        # Per-group GEMM dimensions: K (contraction) and N (output).
        self.group_in_features = heads_per_group * head_dim
        self.group_out_features = o_lora_rank
        # Raw per-group NVFP4 weights. Filled by set_bf16_weight() or
        # load_nvfp4_weight(), then released by finalize_weights().
        self._weight_fp4 = None  # list of (K // 2, N) packed FP4 tensors
        self._weight_sf = None  # list of (K // 16, N) block-scale tensors
        self._weight_gs = None  # list of per-group global scales
        # GEMM-ready weights, set by finalize_weights().
        self._mat_b = None
        self._scale_b = None
        self._gsb = None
        # Static activation global scale. compute_activation_global_scale()
        # may overwrite it. It is not used when _use_runtime_gsa is True.
        self._activation_global_scale = 1.0 / (6.0 * 448.0)
        # When True, _run_impl derives the activation global scale on the GPU
        # on every call instead of using _activation_global_scale.
        self._use_runtime_gsa = False
        # Id from register_runner(), assigned lazily by run().
        self._runner_id = None
        # Pre-allocated buffers, created by _allocate_buffers().
        self._padded_x_fp4_buf = None
        self._gsa_buf = None
        self._expert_offsets_buf = None
        self._expert_offsets_range_buf = None
        self._group_offset_buf = None
        self._output_buf = None
        self._output_buf_padded = None
        self._scale_a_buf = None
        self._buffers_allocated = False

    def set_bf16_weight(self, wo_a_bf16: torch.Tensor):
        """Quantize a BF16 wo_a weight to NVFP4, one group at a time.

        Args:
            wo_a_bf16: one of
                - 3-D bmm layout (n_local_groups, heads_per_group * head_dim,
                  o_lora_rank): each group is already (K, N).
                - 2-D dense layout (n_local_groups * o_lora_rank,
                  heads_per_group * head_dim): each group is (N, K) and is
                  transposed to (K, N) before quantization.

        quantize_weight_to_nvfp4 treats dim 0 as the contraction dim K, so each
        group is passed in as (K, N) and comes back as (K // 2, N) packed FP4.
        """
        n_out = self.o_lora_rank
        if wo_a_bf16.ndim == 3:
            group_weights = [wo_a_bf16[g] for g in range(self.n_local_groups)]
        else:
            # Rows [g * N, (g + 1) * N) belong to group g; transpose (N, K) -> (K, N).
            group_weights = [
                wo_a_bf16[g * n_out:(g + 1) * n_out, :].T
                for g in range(self.n_local_groups)
            ]

        fp4_list, sf_list, gs_list = [], [], []
        for w_kn in group_weights:
            w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_kn)
            fp4_list.append(w_fp4)
            sf_list.append(w_sf)
            gs_list.append(w_gs)
        self._weight_fp4 = fp4_list
        self._weight_sf = sf_list
        self._weight_gs = gs_list

    def load_nvfp4_weight(self, weight, weight_scale, weight_scale_2=None, input_scale=None):
        """Load NVFP4 weights directly from a checkpoint, with no dequant/re-quant.

        The checkpoint stores each group's chunk in (out_features, in_features)
        order, i.e. (N, K_packed) row-major. The GEMM expects (K_packed, N) per
        group, so each group's weight and block scales are transposed.

        Args:
            weight: (n_groups * o_rank, group_in_features // 2) uint8 packed FP4.
            weight_scale: (n_groups * o_rank, group_in_features // 16) block
                scales.
            weight_scale_2: optional global scale. It is either a single value
                or one value per row. With per-row values, the group's rows are
                averaged, which is an approximation if they differ, because the
                GEMM takes only one global scale per group. Defaults to 1.0
                when omitted.
            input_scale: accepted for checkpoint-API compatibility; unused.
        """
        n_out = self.o_lora_rank
        # A single global scale is shared by every group; read it once.
        scalar_gs = None
        if weight_scale_2 is not None and weight_scale_2.numel() == 1:
            scalar_gs = weight_scale_2.float().item()

        fp4_list, sf_list, gs_list = [], [], []
        for g in range(self.n_local_groups):
            rows = slice(g * n_out, (g + 1) * n_out)
            w_g = weight[rows]  # (N, K_packed)
            ws_g = weight_scale[rows]  # (N, K // 16)
            # Transpose to (K_packed, N), the layout quantize_weight_to_nvfp4 produces.
            fp4_list.append(w_g.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            sf_list.append(ws_g.permute(1, 0).contiguous())

            if weight_scale_2 is None:
                gs_list.append(1.0)
            elif scalar_gs is not None:
                gs_list.append(scalar_gs)
            else:
                gs_list.append(weight_scale_2[rows].float().mean().item())

        self._weight_fp4 = fp4_list
        self._weight_sf = sf_list
        self._weight_gs = gs_list

    def finalize_weights(self):
        """Convert the raw per-group NVFP4 weights into the GEMM's B-side inputs.

        Stacks the per-group FP4 tensors and passes them through
        make_b_k_major, assembles the per-group block scales with
        assemble_scales_3d_side, and stores the global scales as a float32
        tensor on self.device. The raw per-group lists are released afterwards.

        Raises:
            RuntimeError: if no weights have been set.
        """
        if self._weight_fp4 is None:
            raise RuntimeError(
                "Call set_bf16_weight() or load_nvfp4_weight() before finalize_weights()"
            )
        self._mat_b = make_b_k_major(torch.stack(self._weight_fp4))  # (groups, K_packed, N_packed)
        self._scale_b = assemble_scales_3d_side(self._weight_sf)
        self._gsb = torch.tensor(self._weight_gs, dtype=torch.float32, device=self.device)
        # Free the raw weights.
        self._weight_fp4 = None
        self._weight_sf = None
        self._weight_gs = None

    def _allocate_buffers(self):
        """Pre-allocate every per-call buffer at its maximum size (CUDA-graph safe)."""
        max_rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
        total_max_rows = max_rows_per_group * self.n_local_groups

        # Packed FP4 A operand: groups stacked vertically, each padded to 128 rows.
        self._padded_x_fp4_buf = torch.zeros(
            total_max_rows, self.group_in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device)
        self._expert_offsets_buf = torch.zeros(
            self.n_local_groups, dtype=torch.int32, device=self.device
        )
        # [1, 2, ..., n_groups]. Scaled by the padded group size on each call to
        # produce cumulative expert offsets without a Python loop or host sync.
        self._expert_offsets_range_buf = torch.arange(
            1, self.n_local_groups + 1, dtype=torch.int32, device=self.device
        )
        # Not read by this class; kept so the attribute still exists.
        self._group_offset_buf = torch.zeros(
            self.n_local_groups, dtype=torch.int32, device=self.device
        )
        # Final (tokens, n_groups, o_lora_rank) output handed back to callers.
        self._output_buf = torch.zeros(
            self.max_num_tokens, self.n_local_groups, self.o_lora_rank,
            dtype=torch.bfloat16, device=self.device,
        )
        # Flat GEMM output of shape (n_groups * padded_rows_per_group, o_lora_rank).
        # Padding rounds each group up to a multiple of 128 rows, so this must be
        # sized like the A buffer (total_max_rows), not max_num_tokens * n_groups.
        self._output_buf_padded = torch.zeros(
            total_max_rows, self.o_lora_rank, dtype=torch.bfloat16, device=self.device
        )
        # Single-group scale_a staging buffer used by _assemble_scales_single_group.
        k_sf = cutedsl_ceil_div(self.group_in_features, 16)
        max_padded_cols = cutedsl_ceil_div(k_sf, 4) * 4
        self._scale_a_buf = torch.zeros(
            max_rows_per_group, max_padded_cols, dtype=torch.float16, device=self.device
        ).to(torch.float8_e4m3fn)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily finalize the weights and allocate buffers on first use."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf):
        """Assemble 2D-side activation scales for num_groups=1.

        CUDA-graph-safe: stages the scales in the pre-allocated _scale_a_buf.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
        buf = self._scale_a_buf
        assert buf.shape[0] >= padded_rows and buf.shape[1] >= padded_cols, \
            f"scale_a_buf too small: {buf.shape} < ({padded_rows}, {padded_cols})"
        # Zero + scatter into the existing buffer (no new allocation).
        buf.view(torch.uint8).zero_()
        buf[:num_rows, :num_cols] = x_sf
        view = buf[:padded_rows, :padded_cols]
        swizzled_flat = pad_and_swizzle_single(view)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, o_sample: torch.Tensor):
        """Set the static activation global scale from a warm-up sample.

        All groups share one scale. It is the global scale that
        quantize_to_nvfp4 reports for the sample viewed as
        (tokens * n_local_groups, group_in_features).

        Args:
            o_sample: (tokens, n_local_heads, head_dim) BF16 attention-output sample.
        """
        self._ensure_initialized()
        from dsv4.ops.quantize import quantize_to_nvfp4

        # Going through (tokens, G, D) first also checks that the sample splits
        # evenly into n_local_groups groups.
        o_rows = o_sample.reshape(-1, self.n_local_groups, self.group_in_features).reshape(
            -1, self.group_in_features
        )
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(o_rows)
        self._activation_global_scale = gs

    def run(self, o: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 attention output -> NVFP4 grouped GEMM -> BF16 z.

        Dispatches through the registered custom op, which in turn calls
        _run_impl on this runner.

        Args:
            o: (num_tokens, n_local_heads, head_dim) BF16 attention output,
               taken AFTER inverse RoPE has been applied.

        Returns:
            z: (num_tokens, n_local_groups, o_lora_rank) BF16.
        """
        if self._runner_id is None:
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            o, self._runner_id, self.n_local_groups * self.o_lora_rank,
        )

    def _run_impl(self, o: torch.Tensor) -> torch.Tensor:
        """Actual implementation of the grouped GEMM forward.

        o is (T, n_local_heads, head_dim). It is reshaped to
        (T, n_local_groups, group_in_features), and each group's
        (T, group_in_features) slice becomes one "expert" of the grouped GEMM.
        All tokens go to all groups.

        The grouped GEMM needs each group's rows to be contiguous at a padded
        offset. With P = ceil(T / 128) * 128:
          - Group 0: rows [0, P)
          - Group 1: rows [P, 2P)
          - ...
          - Group g: rows [g*P, (g+1)*P), of which only the first T are real.

        Returns:
            A (T, n_local_groups, o_lora_rank) view into a persistent output
            buffer. The next call overwrites it.

        Raises:
            ValueError: if T exceeds max_num_tokens (the buffers are too small).
        """
        num_tokens = o.shape[0]
        if num_tokens > self.max_num_tokens:
            raise ValueError(
                f"Nvfp4GroupedLinear got {num_tokens} tokens but buffers are sized "
                f"for max_num_tokens={self.max_num_tokens}"
            )
        self._ensure_initialized()
        if num_tokens == 0:
            return self._output_buf[:0]

        n_groups = self.n_local_groups
        k_packed = self.group_in_features // 2
        padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128
        rows_used = n_groups * padded_rows_per_group

        # (T, H, d) -> (T, G, D) -> (G, T, D) -> (G*T, D): one fused quantize
        # launch covers all groups.
        o_flat = (
            o.reshape(num_tokens, n_groups, self.group_in_features)
            .permute(1, 0, 2)
            .reshape(n_groups * num_tokens, self.group_in_features)
        )

        if self._use_runtime_gsa:
            # Fused amax + quantize with zero CPU-GPU syncs.
            x_fp4_flat, x_sf_flat, gsa_gpu = quantize_nvfp4_gpu_fused(o_flat)
            # All rows share the same scale; copy it GPU->GPU (no .item()).
            self._gsa_buf[0] = gsa_gpu[0]
            if n_groups > 1:
                # Scalar broadcast assignment instead of copy_ from an expanded
                # view (expanded views can cause cudaErrorInvalidValue in copy_).
                self._gsa_buf[1:] = self._gsa_buf[0]
        else:
            self._gsa_buf.fill_(self._activation_global_scale)
            x_fp4_flat, x_sf_flat = quantize_activation_nvfp4(
                o_flat, self._activation_global_scale
            )

        # Scatter group g's T rows to [g*P, g*P + T) with one strided copy.
        # Only the per-group padding rows are zeroed. Rows past G*P belong to
        # no group, so they need not be cleared on every call.
        a_3d = self._padded_x_fp4_buf.view(torch.uint8)[:rows_used].view(
            n_groups, padded_rows_per_group, k_packed
        )
        a_3d[:, :num_tokens].copy_(
            x_fp4_flat.reshape(n_groups, num_tokens, k_packed).view(torch.uint8)
        )
        if padded_rows_per_group > num_tokens:
            a_3d[:, num_tokens:].zero_()

        # A-side block scales, one (T, D // 16) slice per group.
        x_sf_grouped = x_sf_flat.reshape(n_groups, num_tokens, self.group_in_features // 16)
        scale_a = assemble_scales_2d_side(list(x_sf_grouped.unbind(0)))

        # Cumulative expert offsets [P, 2P, ..., G*P]. This is GPU-only work,
        # written into the pre-allocated buffer.
        torch.mul(
            self._expert_offsets_range_buf, padded_rows_per_group, out=self._expert_offsets_buf
        )

        z_gem = run_nvfp4_grouped_gemm(
            mat_a=self._padded_x_fp4_buf,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=self._expert_offsets_buf,
            global_scale_a=self._gsa_buf,
            global_scale_b=self._gsb,
            out=self._output_buf_padded,
        )
        if z_gem is None:
            z_gem = self._output_buf_padded

        # GEMM output is (G*P, N) with groups stacked vertically. Gather each
        # group's first T rows into z[:, g, :] with a single strided copy.
        z = self._output_buf[:num_tokens]
        z_3d = z_gem[:rows_used].view(n_groups, padded_rows_per_group, self.o_lora_rank)
        z.copy_(z_3d[:, :num_tokens].transpose(0, 1))
        return z

    def __call__(self, o: torch.Tensor) -> torch.Tensor:
        return self.run(o)