Files
deepseek-v4-quant/patches/nvfp4_linear.py

134 lines
4.9 KiB
Python

"""
NVFP4 Linear Method — runs BF16 input through DeepGEMM fp8_fp4_gemm natively.
Weight format: NVFP4 (E2M1 packed int8 + UE4M3 block16 scales + float32 global scale)
Activation: BF16 → FP8 e4m3fn with UE8M0 per-token scales
GEMM: deep_gemm.fp8_fp4_gemm_nn(a=(fp8, ue8m0_scale), b=(nvfp4_packed, float32_scale))
Output: BF16
"""
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import LinearMethodBase
class NVFP4LinearMethod(LinearMethodBase):
    """Linear method that runs BF16 x NVFP4 via DeepGEMM fp8_fp4_gemm.

    The layer must have:
    - weight: E2M1 packed int8 (2 values per byte), shape (N, K//2)
    - weight_scale: float8_e4m3fn UE4M3 block scales, shape (N, K//16)
    - weight_scale_2: float32 global scale, shape (num_logical_weights,)
    - input_scale: float32 activation scale (unused, dynamic quant)

    After ``process_weights_after_loading`` has run, the layer carries
    ``weight`` (contiguous int8) and ``weight_scale_inv`` (float32 block
    scales with the global scale folded in); the source scale attributes
    listed above are removed from the layer.
    """

    def create_weights(
        self,
        layer: nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register nothing on ``layer``.

        This method intentionally creates no parameters.
        NOTE(review): presumably the NVFP4 tensors are attached to the layer
        by another code path before loading -- confirm against the caller.
        """

    @staticmethod
    def _nvfp4_block_scales(layer: nn.Module) -> torch.Tensor:
        """Return the block-scale tensor stored on ``layer``.

        ``weight_scale`` is preferred; ``weight_scale_inv`` is the fallback.

        Raises:
            ValueError: if the layer has neither attribute.
        """
        for attr in ("weight_scale", "weight_scale_inv"):
            if hasattr(layer, attr):
                return getattr(layer, attr).data
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            "NVFP4LinearMethod: layer has neither 'weight_scale' nor "
            "'weight_scale_inv'; cannot prepare NVFP4 block scales"
        )

    @staticmethod
    def _nvfp4_global_scale(layer: nn.Module, device: torch.device) -> torch.Tensor:
        """Return the float32 global scale to fold into the block scales.

        Lookup order: ``weight_global_scale``, then ``weight_scale_2``,
        then a constant 1.0 when neither attribute exists.
        """
        if hasattr(layer, "weight_global_scale"):
            return layer.weight_global_scale.data.to(torch.float32)

        if hasattr(layer, "weight_scale_2"):
            ws2 = layer.weight_scale_2.data
            logical_widths = getattr(layer, "logical_widths", None)
            if (
                ws2.numel() > 1
                and logical_widths is not None
                and len(ws2) == len(logical_widths)
            ):
                # One global scale per logical (fused) weight: repeat each
                # scale across the rows of its slice so it broadcasts
                # row-wise against the (N, K//16) block scales.
                per_row = torch.cat(
                    [ws2[i:i + 1].expand(w) for i, w in enumerate(logical_widths)]
                )
                return per_row.to(torch.float32).unsqueeze(1)
            # Single scale, or a multi-entry scale that cannot be mapped to
            # logical widths: collapse to the maximum entry.
            # NOTE(review): for a fused layer whose scales differ this is
            # lossy -- confirm that such layers always expose logical_widths.
            return ws2.max().to(torch.float32)

        return torch.tensor(1.0, dtype=torch.float32, device=device)

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Fold global scale into block scales and prepare for DeepGEMM consumption.

        Does nothing when ``layer.weight`` is not uint8/int8 (i.e. the weight
        is not in the packed form this method handles).

        Side effects on ``layer``:
        - ``weight_scale_inv`` is (re)assigned as a non-trainable float32
          parameter holding ``block_scale * global_scale``, zero-padded along
          the last dimension to a multiple of 16 columns.
        - ``weight.data`` is made contiguous int8 (uint8 is reinterpreted).
        - The source scale attributes are deleted.

        Raises:
            ValueError: if the layer has no block-scale attribute.
        """
        w_data = layer.weight.data
        device = w_data.device
        if w_data.dtype not in (torch.uint8, torch.int8):
            return

        N = w_data.shape[0]

        sf_e4m3 = self._nvfp4_block_scales(layer)
        global_scale = self._nvfp4_global_scale(layer, device)

        # Fold global scale into block scales and store as float32
        # (DeepGEMM fp8_fp4_gemm_nn expects float32 scales, NOT float8_e4m3fn)
        sf_f32 = sf_e4m3.to(torch.float32) * global_scale

        # Pad the scale columns up to a multiple of 16 for DeepGEMM.
        # NOTE(review): this aligns the number of scale columns (K//16), not
        # K itself -- confirm against DeepGEMM's scale-layout requirements.
        sf_k = sf_f32.shape[1]  # K//16
        gran_k = 16
        aligned_k = (sf_k + gran_k - 1) // gran_k * gran_k
        if aligned_k > sf_k:
            sf_padded = torch.zeros(N, aligned_k, dtype=torch.float32, device=device)
            sf_padded[:, :sf_k] = sf_f32
            sf_f32 = sf_padded

        layer.weight_scale_inv = nn.Parameter(sf_f32.contiguous(), requires_grad=False)
        del sf_f32, sf_e4m3

        # Ensure weight is contiguous int8, K-major (required by DeepGEMM)
        if w_data.dtype == torch.uint8:
            layer.weight.data = w_data.view(torch.int8).contiguous()
        else:
            layer.weight.data = w_data.contiguous()

        # Free source attributes; their content now lives in weight_scale_inv.
        for attr in ("weight_scale", "weight_scale_2", "input_scale",
                     "weight_global_scale", "input_global_scale",
                     "alpha", "input_global_scale_inv"):
            if hasattr(layer, attr):
                delattr(layer, attr)

    def apply(
        self,
        layer: nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with an FP8 x FP4 DeepGEMM kernel.

        Args:
            layer: Module prepared by ``process_weights_after_loading``; its
                ``weight`` and ``weight_scale_inv`` are read.
            x: Activation of shape ``(..., K)``. Leading dimensions are
                flattened for the GEMM and restored on the output.
            bias: Optional bias added to the GEMM output.

        Returns:
            BF16 tensor of shape ``(..., N)`` where ``N = layer.weight.shape[0]``.
        """
        import deep_gemm

        # The kernel works on a 2-D (M, K) activation; flatten any extra
        # leading dimensions and remember them for the output.
        lead_shape = x.shape[:-1]
        is_2d = x.dim() == 2
        x_2d = x if is_2d else x.reshape(-1, x.shape[-1])
        M = x_2d.shape[0]

        # Quantize activation to FP8 with UE8M0 per-token scales
        x_fp8, x_sf = deep_gemm.per_token_cast_to_fp8(
            x_2d, use_ue8m0=True, use_packed_ue8m0=True)

        # Weight: E2M1 packed int8 + folded float32 block scales
        b_weight = layer.weight.data  # (N, K//2) int8
        b_sf = layer.weight_scale_inv.data  # (N, K//16) float32
        N = b_weight.shape[0]

        d = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

        # DeepGEMM fp8_fp4_gemm: A is FP8 (M, K), B is FP4 (N, K//2 packed)
        # B scales are float32 with gran_k=16 (NVFP4 block size)
        deep_gemm.fp8_fp4_gemm_nn(
            a=(x_fp8, x_sf),
            b=(b_weight, b_sf),
            d=d,
            recipe_b=(1, 16),  # NVFP4: gran_mn=1, gran_k=16
        )

        # The kernel call above takes no bias, so apply it here. In-place add
        # keeps the output in BF16 and avoids a second (M, N) allocation.
        if bias is not None:
            d.add_(bias)

        return d if is_2d else d.view(*lead_shape, N)