Files
nvfp4-megamoe-kernel/dsv4/layers/grouped_linear.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

301 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half).
wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups.
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank)
The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd".
We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts,
where every token goes to every "expert" (group).
wo_a is loaded as BF16 from our NVFP4 checkpoint, then quantized to NVFP4 here.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""
import torch
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
)
from dsv4.ops.layouts import (
make_b_k_major,
assemble_scales_2d_side,
assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
)
from dsv4.ops.layouts import (
ceil_div as cutedsl_ceil_div,
pad_and_swizzle_single,
)
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
class Nvfp4GroupedLinear:
    """Grouped NVFP4 linear for wo_a (o-projection first half).

    Handles the "bhr,hdr->bhd" einsum pattern:
      - o:    (tokens, n_local_heads, head_dim), reshaped to
              (tokens, n_local_groups, heads_per_group * head_dim)
      - wo_a: (n_local_groups, heads_per_group * head_dim, o_lora_rank),
              quantized to NVFP4 per group
      - z:    (tokens, n_local_groups, o_lora_rank)

    Uses the grouped GEMM with num_groups = n_local_groups; every token goes
    to every group (no routing).

    Intended to be CUDA-graph-compatible: the activation, global-scale and
    offset buffers are allocated once at ``max_num_tokens`` capacity and
    reused on every forward.

    Typical call order:
        layer.set_bf16_weight(w)   # quantize weights
        layer.finalize_weights()   # optional; done lazily on first forward
        z = layer(o)
    """

    def __init__(
        self,
        n_local_groups: int,
        heads_per_group: int,
        head_dim: int,
        o_lora_rank: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        """Store the layer geometry; no tensors are allocated here.

        Args:
            n_local_groups: Number of groups (used as the grouped GEMM's
                expert count).
            heads_per_group: Attention heads folded into each group.
            head_dim: Per-head feature dimension.
            o_lora_rank: Output features produced by each group.
            max_num_tokens: Largest token count the pre-allocated buffers
                are sized for.
            device: Device on which buffers and global scales are created.
        """
        self.n_local_groups = n_local_groups
        self.heads_per_group = heads_per_group
        self.head_dim = head_dim
        self.o_lora_rank = o_lora_rank
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Per-group GEMM dimensions: K (contraction) and N (output).
        self.group_in_features = heads_per_group * head_dim
        self.group_out_features = o_lora_rank

        # Raw per-group NVFP4 weights, filled by set_bf16_weight() and
        # released again by finalize_weights().
        self._weight_fp4 = None  # list of (K//2, N) float4_e2m1fn_x2
        self._weight_sf = None   # list of (K//16, N) float8_e4m3fn
        self._weight_gs = None   # list of float32 global scales

        # GEMM-ready weights (set by finalize_weights).
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Default activation global scale; replaced by
        # compute_activation_global_scale().
        # NOTE(review): presumably 1 / (FP4-E2M1 max 6.0 * FP8-E4M3 max 448.0)
        # -- confirm against the quantizer.
        self._activation_global_scale = 1.0 / (6.0 * 448.0)

        # Buffers created once by _allocate_buffers().
        self._padded_x_fp4_buf = None
        self._gsa_buf = None
        self._expert_offsets_buf = None
        self._group_ordinals_buf = None
        self._buffers_allocated = False

        # Handle returned by register_runner(); assigned on the first run().
        self._runner_id = None

    def set_bf16_weight(self, wo_a_bf16: torch.Tensor):
        """Set the wo_a weight from BF16 and quantize it to NVFP4 per group.

        Args:
            wo_a_bf16: Either
                (n_local_groups, heads_per_group * head_dim, o_lora_rank)
                ("bmm" layout, already (K, N) per group), or
                (n_local_groups * o_lora_rank, heads_per_group * head_dim)
                (dense (out, in) layout; each group's block is transposed
                to (K, N) before quantization).

        Raises:
            ValueError: If the tensor is neither 2-D nor 3-D, or its shape
                does not match the dimensions given to the constructor.
        """
        groups = self.n_local_groups
        k_dim = self.group_in_features
        n_dim = self.o_lora_rank
        shape = tuple(wo_a_bf16.shape)

        if wo_a_bf16.ndim == 3:
            expected = (groups, k_dim, n_dim)
            if shape != expected:
                raise ValueError(
                    f"wo_a (bmm layout) must have shape {expected}, got {shape}"
                )
            # Each slice is already (in_features, out_features) = (K, N).
            per_group = [wo_a_bf16[g] for g in range(groups)]
        elif wo_a_bf16.ndim == 2:
            expected = (groups * n_dim, k_dim)
            if shape != expected:
                raise ValueError(
                    f"wo_a (dense layout) must have shape {expected}, got {shape}"
                )
            # Dense weights are (out, in); the quantizer treats dim 0 as the
            # contraction dim K, so every group's block is transposed to
            # (in_features, o_lora_rank) = (K, N).
            per_group = [
                wo_a_bf16[g * n_dim:(g + 1) * n_dim, :].T for g in range(groups)
            ]
        else:
            raise ValueError(
                f"wo_a must be a 2-D or 3-D tensor, got {wo_a_bf16.ndim}-D "
                f"with shape {shape}"
            )

        fp4_list = []
        sf_list = []
        gs_list = []
        for w_g in per_group:
            w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g)
            fp4_list.append(w_fp4)
            sf_list.append(w_sf)
            gs_list.append(w_gs)

        self._weight_fp4 = fp4_list
        self._weight_sf = sf_list
        self._weight_gs = gs_list

    def finalize_weights(self):
        """Convert the quantized per-group weights into GEMM-ready tensors.

        Stacks the per-group FP4 weights, assembles the B-side scales,
        builds the per-group global-scale vector, then drops the raw lists.

        Raises:
            RuntimeError: If set_bf16_weight() has not been called.
        """
        if self._weight_fp4 is None:
            raise RuntimeError("Call set_bf16_weight() before finalize_weights()")
        # Stacked to (groups, K_packed, N_packed) before the K-major conversion.
        self._mat_b = make_b_k_major(torch.stack(self._weight_fp4))
        self._scale_b = assemble_scales_3d_side(self._weight_sf)
        self._gsb = torch.tensor(
            self._weight_gs, dtype=torch.float32, device=self.device
        )
        # Free raw weights.
        self._weight_fp4 = None
        self._weight_sf = None
        self._weight_gs = None

    def _allocate_buffers(self):
        """Pre-allocate all forward-pass buffers at maximum size."""
        groups = self.n_local_groups
        max_rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
        total_max_rows = max_rows_per_group * groups

        # Allocated as bytes, then reinterpreted as packed FP4 (two values
        # per byte, hence K // 2 columns).
        self._padded_x_fp4_buf = torch.zeros(
            total_max_rows,
            self.group_in_features // 2,
            dtype=torch.uint8,
            device=self.device,
        ).view(torch.float4_e2m1fn_x2)
        self._gsa_buf = torch.zeros(groups, dtype=torch.float32, device=self.device)
        self._expert_offsets_buf = torch.zeros(
            groups, dtype=torch.int32, device=self.device
        )
        # [1, 2, ..., G]; scaled by the padded row count on every forward to
        # produce the cumulative group end offsets in a single op.
        self._group_ordinals_buf = torch.arange(
            1, groups + 1, dtype=torch.int32, device=self.device
        )
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily finalize weights and allocate buffers on first use."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf):
        """Assemble 2D-side activation scales for num_groups=1.

        Pads ``x_sf`` with zeros to a multiple of 128 rows and 4 columns,
        swizzles it with ``pad_and_swizzle_single`` and returns the result
        reshaped to the padded 2-D shape.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
        # Zeros are created in float16 and then cast to FP8.
        buf = torch.zeros(
            padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device
        ).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, o_sample: torch.Tensor):
        """Compute the activation global scale from a warmup sample.

        A single scale is shared by all groups: the sample is viewed as
        (tokens * n_local_groups, group_in_features) and quantized as one
        matrix, and the resulting global scale is stored for later forwards.

        Args:
            o_sample: (tokens, n_local_heads, head_dim) BF16 attention
                output sample.
        """
        self._ensure_initialized()
        # Imported here because it is not among the module-level imports.
        from dsv4.ops.quantize import (
            quantize_to_nvfp4,
        )

        # A per-group scale would be more precise; for simplicity one scale
        # is derived from all groups' rows together.
        o_rows = o_sample.reshape(
            -1, self.n_local_groups, self.group_in_features
        ).reshape(-1, self.group_in_features)
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(o_rows)
        self._activation_global_scale = gs

    def run(self, o: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 attention output -> NVFP4 grouped GEMM -> BF16 z.

        Registers this instance with ``register_runner`` on the first call
        and dispatches through ``nvfp4_linear_gemm`` with that handle.

        Args:
            o: (num_tokens, n_local_heads, head_dim) BF16 attention output,
                AFTER inverse RoPE has been applied.

        Returns:
            z: (num_tokens, n_local_groups, o_lora_rank) BF16.
        """
        # getattr keeps this working for instances built without __init__.
        if getattr(self, "_runner_id", None) is None:
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            o, self._runner_id, self.n_local_groups * self.o_lora_rank,
        )

    def _run_impl(self, o: torch.Tensor) -> torch.Tensor:
        """Actual implementation of the forward pass.

        Input ``o`` is (tokens, n_local_heads, head_dim). It is reshaped to
        (tokens, n_local_groups, heads_per_group * head_dim), and each
        group's (tokens, group_in_features) slab is treated as one "expert"
        of the grouped GEMM. All tokens go to all groups.

        The grouped GEMM layout requires each group's tokens to be
        contiguous at their padded offset (G = n_local_groups):
          - Group 0:   rows [0, padded_T)
          - Group 1:   rows [padded_T, 2*padded_T)
          - ...
          - Group G-1: rows [(G-1)*padded_T, G*padded_T)

        Args:
            o: (num_tokens, n_local_heads, head_dim) BF16.

        Returns:
            z: (num_tokens, n_local_groups, o_lora_rank) BF16.

        Raises:
            ValueError: If the padded token count does not fit into the
                buffers pre-allocated for ``max_num_tokens``.
        """
        groups = self.n_local_groups
        num_tokens = o.shape[0]
        if num_tokens == 0:
            # Nothing to compute; skip quantization and the GEMM entirely.
            return torch.empty(
                0, groups, self.o_lora_rank, dtype=torch.bfloat16, device=o.device
            )

        self._ensure_initialized()
        padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128
        padded_x_fp4 = self._padded_x_fp4_buf
        if padded_rows_per_group * groups > padded_x_fp4.shape[0]:
            raise ValueError(
                f"num_tokens={num_tokens} exceeds the pre-allocated capacity "
                f"(max_num_tokens={self.max_num_tokens})"
            )

        # (tokens, heads, head_dim) -> (tokens, G, K) -> groups-first (G, T, K).
        o_grouped = o.reshape(
            num_tokens, groups, self.group_in_features
        ).permute(1, 0, 2)

        # Byte view of the FP4 buffer; cleared so padding rows are zero.
        x_bytes = padded_x_fp4.view(torch.uint8)
        x_bytes.zero_()

        # Quantize each group's activation and place it at its padded offset,
        # collecting the per-group scale factors for the A-side assembly.
        all_x_sf = []
        for g in range(groups):
            x_fp4_g, x_sf_g = quantize_activation_nvfp4(
                o_grouped[g], self._activation_global_scale
            )
            offset = g * padded_rows_per_group
            x_bytes[offset:offset + num_tokens] = x_fp4_g.view(torch.uint8)
            all_x_sf.append(x_sf_g)

        # A-side scales for all groups, assembled together (2D x 3D scenario).
        scale_a = assemble_scales_2d_side(all_x_sf)

        # Cumulative end offsets [padded_T, 2*padded_T, ..., G*padded_T],
        # written in place into the pre-allocated int32 buffer.
        expert_offsets = self._expert_offsets_buf
        torch.mul(self._group_ordinals_buf, padded_rows_per_group, out=expert_offsets)

        # Global scales (same for all groups).
        gsa = self._gsa_buf.fill_(self._activation_global_scale)

        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )

        # The GEMM output uses the same groups-first padded row layout as
        # mat_a; gather the real rows back into (tokens, G, o_lora_rank).
        z = torch.empty(
            num_tokens, groups, self.o_lora_rank,
            dtype=torch.bfloat16, device=o.device,
        )
        for g in range(groups):
            offset = g * padded_rows_per_group
            z[:, g, :] = out[offset:offset + num_tokens, :]
        return z

    def __call__(self, o: torch.Tensor) -> torch.Tensor:
        """Alias for run()."""
        return self.run(o)