The checkpoint's input_scale was designed for training-time FP8 quantization, not NVFP4 activation quantization. Using it as gsa lets the per-block scale, max(|x_block|) / (6.0 * gsa), exceed the E4M3 maximum (448); the saturated block scales cause systematic magnitude loss in every projection. This accumulates over 61 layers, compressing the logit range and producing garbage tokens. Fix: compute gsa at runtime from the actual activation magnitude, gsa = max(|x|) / (6.0 * 448.0). This guarantees |x| / gsa ≤ 6.0 * 448.0 = 2688, the largest magnitude NVFP4 can represent (E2M1 element maximum 6 times E4M3 block-scale maximum 448). Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate.
306 lines
12 KiB
Python
306 lines
12 KiB
Python
"""CuTeDSL NVFP4 Grouped Linear for wo_a (o_proj first half).
|
||
|
||
wo_a in DeepSeek V4 is a grouped matmul (bmm) with n_local_groups=8 groups.
|
||
Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) → (tokens, o_lora_rank)
|
||
|
||
The vLLM forward does this via DeepGEMM fp8_einsum with equation "bhr,hdr->bhd".
|
||
We replace it with our CuTeDSL ScaledGroupedGemm using n_local_groups as num_experts,
|
||
where every token goes to every "expert" (group).
|
||
|
||
wo_a is loaded as BF16 from our NVFP4 checkpoint, then quantized to NVFP4 here.
|
||
|
||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||
"""
|
||
|
||
import torch
|
||
|
||
from dsv4.ops.quantize import (
|
||
quantize_activation_nvfp4,
|
||
quantize_weight_to_nvfp4,
|
||
)
|
||
from dsv4.ops.layouts import (
|
||
make_b_k_major,
|
||
assemble_scales_2d_side,
|
||
assemble_scales_3d_side,
|
||
)
|
||
from dsv4.ops.gemm_runner import (
|
||
run_nvfp4_grouped_gemm,
|
||
)
|
||
from dsv4.ops.layouts import (
|
||
ceil_div as cutedsl_ceil_div,
|
||
pad_and_swizzle_single,
|
||
)
|
||
from dsv4.ops.custom_ops import register_runner, nvfp4_linear_gemm
|
||
|
||
|
||
class Nvfp4GroupedLinear:
    """Grouped NVFP4 linear for wo_a (o-projection first half).

    Handles the "bhr,hdr->bhd" einsum pattern:
      - o:    (tokens, n_local_heads, head_dim), reshaped to
              (tokens, n_local_groups, heads_per_group * head_dim)
      - wo_a: one (heads_per_group * head_dim, o_lora_rank) NVFP4 matrix per group
      - z:    (tokens, n_local_groups, o_lora_rank)

    Uses ScaledGroupedGemm with num_groups=n_local_groups.
    Every token goes to every group (no routing).

    Intended for CUDA-graph capture: the FP4 activation buffer, the group
    offsets and the activation global-scale buffer are pre-allocated at
    ``max_num_tokens``. On the default path this class's own code issues no
    host syncs. NOTE(review): the optional runtime activation-scale path
    (``_use_runtime_gsa``) calls ``.item()`` and therefore synchronizes with
    the host, so that path is not capturable as written.
    """

    # Largest magnitude of an FP4 E2M1 element and of an FP8 E4M3 block scale.
    # Their product (2688) is the largest magnitude NVFP4 can represent.
    _FP4_E2M1_MAX = 6.0
    _FP8_E4M3_MAX = 448.0

    def __init__(
        self,
        n_local_groups: int,
        heads_per_group: int,
        head_dim: int,
        o_lora_rank: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        self.n_local_groups = n_local_groups
        self.heads_per_group = heads_per_group
        self.head_dim = head_dim
        self.o_lora_rank = o_lora_rank
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Per-group GEMM dimensions: K (contraction) and N (output).
        self.group_in_features = heads_per_group * head_dim  # 8192 for DeepSeek V4
        self.group_out_features = o_lora_rank  # 1536 for DeepSeek V4

        # Raw per-group NVFP4 weights: filled by set_bf16_weight(), consumed
        # and freed by finalize_weights().
        self._weight_fp4 = None  # list of (K//2, N) float4_e2m1fn_x2
        self._weight_sf = None  # list of (K//16, N) float8_e4m3fn
        self._weight_gs = None  # list of per-group weight global scales

        # Kernel-ready weights (set by finalize_weights).
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Activation global scale shared by all groups. The default maps an
        # activation magnitude of 1.0 to the largest NVFP4 value (6 * 448);
        # compute_activation_global_scale() or the runtime path may replace it.
        self._activation_global_scale = 1.0 / (self._FP4_E2M1_MAX * self._FP8_E4M3_MAX)

        # Pre-allocated buffers (created lazily by _allocate_buffers).
        self._padded_x_fp4_buf = None
        self._gsa_buf = None
        self._expert_offsets_buf = None
        self._buffers_allocated = False

        # Id returned by register_runner(); assigned on the first run().
        self._runner_id = None

    def set_bf16_weight(self, wo_a_bf16: torch.Tensor):
        """Set the wo_a weight from BF16 and quantize it to NVFP4, one group at a time.

        Accepted layouts (K = heads_per_group * head_dim, N = o_lora_rank,
        G = n_local_groups):
          - 2D dense:  (G * N, K) -- per-group (out, in) row blocks stacked on dim 0.
          - 3D bmm:    (G, K, N)  -- per-group (in, out).
          - 3D einsum: (G, N, K)  -- the "hdr" operand of "bhr,hdr->bhd",
            per-group (out, in). When K == N the two 3D layouts are
            indistinguishable and the bmm (G, K, N) layout is assumed.

        Each group is handed to quantize_weight_to_nvfp4 as a (K, N) matrix,
        because that helper treats dim 0 as the contraction dim and returns
        (K//2, N) packed data.

        Args:
            wo_a_bf16: weight tensor in one of the layouts above.

        Raises:
            ValueError: if the tensor shape matches none of the layouts.
        """
        k_dim = self.group_in_features
        n_dim = self.o_lora_rank
        n_groups = self.n_local_groups
        shape = tuple(wo_a_bf16.shape)

        # Build the per-group (K, N) views before quantizing anything, so a
        # mis-shaped weight is rejected instead of silently quantized wrong.
        if shape == (n_groups, k_dim, n_dim):
            group_weights = [wo_a_bf16[g] for g in range(n_groups)]
        elif shape == (n_groups, n_dim, k_dim):
            group_weights = [wo_a_bf16[g].T for g in range(n_groups)]
        elif shape == (n_groups * n_dim, k_dim):
            # Dense (out, in) layout: split into per-group (N, K) row blocks,
            # then transpose so the contraction dim comes first.
            group_weights = [blk.T for blk in wo_a_bf16.split(n_dim, dim=0)]
        else:
            raise ValueError(
                f"wo_a_bf16 has shape {shape}; expected ({n_groups * n_dim}, {k_dim}), "
                f"({n_groups}, {k_dim}, {n_dim}) or ({n_groups}, {n_dim}, {k_dim})"
            )

        fp4_list, sf_list, gs_list = [], [], []
        for w_kn in group_weights:
            w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_kn)
            fp4_list.append(w_fp4)
            sf_list.append(w_sf)
            gs_list.append(w_gs)

        self._weight_fp4 = fp4_list
        self._weight_sf = sf_list
        self._weight_gs = gs_list

    def finalize_weights(self):
        """Convert the per-group NVFP4 weights into the CuTeDSL GEMM layout.

        Stacks the packed FP4 data into a (groups, K_packed, N_packed) tensor in
        the B-side layout, assembles the 3D-side block scales, stores the
        per-group global scales as a float32 tensor, and then drops the raw
        per-group lists.

        Raises:
            RuntimeError: if set_bf16_weight() has not been called.
        """
        if self._weight_fp4 is None:
            raise RuntimeError("Call set_bf16_weight() before finalize_weights()")

        self._mat_b = make_b_k_major(torch.stack(self._weight_fp4))  # (groups, K_packed, N_packed)
        self._scale_b = assemble_scales_3d_side(self._weight_sf)
        self._gsb = torch.tensor(self._weight_gs, dtype=torch.float32, device=self.device)

        # Free raw weights.
        self._weight_fp4 = None
        self._weight_sf = None
        self._weight_gs = None

    def _allocate_buffers(self):
        """Pre-allocate buffers at max size for cudagraph compatibility.

        The activation buffer holds n_local_groups slabs, each
        max_num_tokens rows rounded up to a multiple of 128.
        """
        max_rows_per_group = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
        total_max_rows = max_rows_per_group * self.n_local_groups

        self._padded_x_fp4_buf = torch.zeros(
            total_max_rows, self.group_in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)

        self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device)
        self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily finalize weights and allocate buffers on first use."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf):
        """Assemble 2D-side activation scales for num_groups=1.

        Pads x_sf to a multiple of 128 rows and 4 columns, swizzles it with
        pad_and_swizzle_single, and returns it reshaped to the padded 2D shape.
        Not called by the other methods of this class.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4

        # Zero-filled FP8 buffer, created via float16 and then cast.
        buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, o_sample: torch.Tensor):
        """Compute the shared activation global scale from a warmup sample.

        One scale is used for all groups: the sample is flattened to
        (tokens * n_local_groups, group_in_features) rows and passed to
        quantize_to_nvfp4. The global scale it returns becomes the scale used
        by subsequent calls.

        Args:
            o_sample: (tokens, n_local_heads, head_dim) BF16 attention output sample
        """
        from dsv4.ops.quantize import quantize_to_nvfp4

        self._ensure_initialized()
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(o_sample.reshape(-1, self.group_in_features))
        self._activation_global_scale = gs

    def run(self, o: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 attention output -> NVFP4 grouped GEMM -> BF16 z.

        Dispatches through the nvfp4_linear_gemm custom op, registering this
        instance as a runner on the first call.

        Args:
            o: (num_tokens, n_local_heads, head_dim) BF16 -- attention output
               AFTER inverse RoPE has been applied

        Returns:
            z: (num_tokens, n_local_groups, o_lora_rank) BF16
        """
        if getattr(self, "_runner_id", None) is None:
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(
            o, self._runner_id, self.n_local_groups * self.o_lora_rank,
        )

    def _run_impl(self, o: torch.Tensor) -> torch.Tensor:
        """Actual implementation.

        Input o is (tokens, n_local_heads, head_dim).
        We reshape to (tokens, n_local_groups, heads_per_group * head_dim),
        then treat each group's (tokens, group_in_features) as one "expert"
        in our grouped GEMM. All tokens go to all groups.

        The grouped GEMM layout requires each group's tokens to be
        contiguous at their correct offset (padded_T = tokens rounded up to 128):
          - Group 0:   rows [0, padded_T)
          - Group 1:   rows [padded_T, 2*padded_T)
          - ...
          - Group G-1: rows [(G-1)*padded_T, G*padded_T)

        Raises:
            ValueError: if the padded token count exceeds the pre-allocated
                buffer (i.e. more tokens than max_num_tokens allows).
        """
        self._ensure_initialized()

        n_groups = self.n_local_groups
        num_tokens = o.shape[0]
        padded_rows_per_group = cutedsl_ceil_div(num_tokens, 128) * 128
        used_rows = n_groups * padded_rows_per_group

        padded_x_fp4 = self._padded_x_fp4_buf
        if used_rows > padded_x_fp4.shape[0]:
            raise ValueError(
                f"num_tokens={num_tokens} exceeds the pre-allocated capacity "
                f"(max_num_tokens={self.max_num_tokens})"
            )

        # (T, n_local_heads, head_dim) -> (T, G, D) -> groups-first (G, T, D).
        o_grouped = o.reshape(num_tokens, n_groups, self.group_in_features).permute(1, 0, 2)

        # Optionally recompute the activation global scale from this batch so
        # that |x| / gsa never exceeds the NVFP4 maximum (6 * 448).
        if getattr(self, "_use_runtime_gsa", False):
            # abs/amax are exact in the input dtype, so casting only the scalar
            # result avoids materializing a float32 copy of o.
            # NOTE(review): .item() is a device->host sync (not graph-capturable).
            amax = o.abs().amax().float().clamp(min=1e-8).item()
            self._activation_global_scale = amax / (self._FP4_E2M1_MAX * self._FP8_E4M3_MAX)
        gsa_value = self._activation_global_scale

        # Only rows [0, used_rows) are covered by expert_offsets; clear just
        # that region (including per-group padding rows) instead of the whole
        # max-size buffer.
        x_fp4_bytes = padded_x_fp4.view(torch.uint8)
        x_fp4_bytes[:used_rows].zero_()

        # Quantize each group and scatter it to its slab; keep every group's
        # block scales for the A-side scale assembly.
        all_x_sf = []
        for g in range(n_groups):
            x_fp4_g, x_sf_g = quantize_activation_nvfp4(o_grouped[g], gsa_value)
            offset = g * padded_rows_per_group
            x_fp4_bytes[offset:offset + num_tokens] = x_fp4_g.view(torch.uint8)
            all_x_sf.append(x_sf_g)

        # 2Dx3D scenario: scale_a is assembled from the per-group scale tensors.
        scale_a = assemble_scales_2d_side(all_x_sf)

        # Cumulative group end offsets: [padded_T, 2*padded_T, ..., G*padded_T].
        expert_offsets = self._expert_offsets_buf
        expert_offsets.copy_(
            torch.arange(1, n_groups + 1, dtype=torch.int32, device=expert_offsets.device)
            * padded_rows_per_group
        )

        # The same activation global scale is used for every group.
        gsa = self._gsa_buf.fill_(gsa_value)

        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )

        # The GEMM output uses mat_a's row layout (groups-first, padded slabs):
        # drop the padding rows and move the group axis after the token axis.
        z = torch.empty(num_tokens, n_groups, self.o_lora_rank,
                        dtype=torch.bfloat16, device=o.device)
        z.copy_(
            out[:used_rows]
            .reshape(n_groups, padded_rows_per_group, out.shape[-1])[:, :num_tokens]
            .transpose(0, 1)
        )
        return z

    def __call__(self, o: torch.Tensor) -> torch.Tensor:
        """Alias for run()."""
        return self.run(o)