134 lines
4.9 KiB
Python
134 lines
4.9 KiB
Python
"""
NVFP4 linear method: runs BF16 activations against NVFP4 weights natively via
DeepGEMM's fp8_fp4_gemm.

Weight format: NVFP4 (E2M1 values packed two per int8 byte, UE4M3 scales per
               16-element block, plus a float32 global scale)
Activation:    BF16 → FP8 e4m3fn with UE8M0 per-token scales, quantized
               dynamically on every forward pass
GEMM:          deep_gemm.fp8_fp4_gemm_nn(a=(fp8, ue8m0_scale),
                                         b=(nvfp4_packed, float32_scale))
Output:        BF16
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
from vllm.model_executor.layers.linear import LinearMethodBase
|
|
|
|
|
|
class NVFP4LinearMethod(LinearMethodBase):
    """Linear method that runs BF16 x NVFP4 via DeepGEMM fp8_fp4_gemm.

    When ``process_weights_after_loading`` runs, the layer must have:

    - weight: E2M1 packed uint8/int8 (2 values per byte), shape (N, K//2)
    - weight_scale (or weight_scale_inv): float8_e4m3fn UE4M3 block scales,
      shape (N, K//16)

    and may have one global scale:

    - weight_scale_2 (ModelOpt naming): float32 multiplier, either a single
      value or one value per logical (fused) sub-weight
    - weight_global_scale (compressed-tensors naming): float32 divisor,
      i.e. dequant = fp4 * block_scale / weight_global_scale

    If neither global scale is present, 1.0 is used. ``input_scale`` and the
    related activation scales are unused (activations are quantized
    dynamically in ``apply``) and are deleted after loading.

    After loading, ``layer.weight_scale_inv`` holds float32 block scales with
    the global scale folded in; ``apply`` consumes it together with
    ``layer.weight``.
    """

    def create_weights(
        self,
        layer: nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register no parameters (deliberate no-op).

        This method never creates ``weight`` or any scale tensor.
        NOTE(review): presumably the parameters listed in the class docstring
        are registered by another quantization method before this one takes
        over; confirm against the caller.
        """

    @staticmethod
    def _global_scale_multiplier(
        layer: nn.Module, n_rows: int, device: torch.device
    ) -> torch.Tensor:
        """Return the global dequant multiplier to fold into the block scales.

        The result is a 0-dim tensor, or an (n_rows, 1) tensor holding one
        value per output row when the layer stores one global scale per
        fused sub-weight and ``layer.logical_widths`` maps them onto rows.

        ``weight_scale_2`` is used as-is (it is already a multiplier).
        ``weight_global_scale`` is a divisor, so its reciprocal is returned.
        When per-partition values cannot be mapped onto rows they are
        collapsed with ``max()`` *before* any reciprocal is taken.
        """
        if hasattr(layer, "weight_global_scale"):
            raw = layer.weight_global_scale.data
            reciprocal = True
        elif hasattr(layer, "weight_scale_2"):
            raw = layer.weight_scale_2.data
            reciprocal = False
        else:
            return torch.tensor(1.0, dtype=torch.float32, device=device)

        raw = raw.to(device=device, dtype=torch.float32).reshape(-1)
        logical_widths = getattr(layer, "logical_widths", None)
        per_partition = (
            raw.numel() > 1
            and logical_widths is not None
            and len(logical_widths) == raw.numel()
            and sum(logical_widths) == n_rows
        )
        if per_partition:
            # One scale per fused sub-weight: repeat each over its rows.
            repeats = torch.tensor(list(logical_widths), device=device)
            scale = torch.repeat_interleave(raw, repeats).unsqueeze(1)
        else:
            # Single scale, or partitions that cannot be mapped onto rows.
            scale = raw.max()
        return scale.reciprocal() if reciprocal else scale

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Fold the global scale into the block scales and prepare for DeepGEMM.

        On a layer whose ``weight`` is packed uint8/int8 data this:

        - stores float32 block scales (global scale folded in, scale columns
          zero-padded to a multiple of 16) as ``layer.weight_scale_inv``;
        - stores the packed weight as contiguous int8;
        - deletes the source scale attributes from the layer.

        Any other weight dtype is left untouched (early return).

        Raises:
            ValueError: if the layer has neither ``weight_scale`` nor
                ``weight_scale_inv``.
        """
        w_data = layer.weight.data
        # Not packed FP4 bytes (e.g. already processed): nothing to do.
        if w_data.dtype not in (torch.uint8, torch.int8):
            return

        n_rows = w_data.shape[0]

        # Per-block UE4M3 scales; the attribute name depends on the loader.
        sf_e4m3 = None
        for attr in ("weight_scale", "weight_scale_inv"):
            if hasattr(layer, attr):
                sf_e4m3 = getattr(layer, attr).data
                break
        if sf_e4m3 is None:
            raise ValueError(
                "NVFP4LinearMethod: layer has neither 'weight_scale' nor "
                "'weight_scale_inv'; cannot build DeepGEMM block scales."
            )

        global_scale = self._global_scale_multiplier(
            layer, n_rows, sf_e4m3.device)

        # Fold the global scale in and keep float32
        # (DeepGEMM fp8_fp4_gemm_nn expects float32 scales, NOT float8_e4m3fn).
        sf_f32 = sf_e4m3.to(torch.float32) * global_scale

        # Zero-pad the scale columns to a multiple of gran_k=16 for DeepGEMM.
        # NOTE(review): this alignment rule is carried over unchanged; confirm
        # it against DeepGEMM's scale-layout checks.
        gran_k = 16
        sf_k = sf_f32.shape[1]  # K // 16
        aligned_k = (sf_k + gran_k - 1) // gran_k * gran_k
        if aligned_k > sf_k:
            sf_f32 = nn.functional.pad(sf_f32, (0, aligned_k - sf_k))

        layer.weight_scale_inv = nn.Parameter(
            sf_f32.contiguous(), requires_grad=False)
        del sf_f32, sf_e4m3

        # DeepGEMM consumes the packed weight as contiguous int8, K-major;
        # view() reinterprets uint8 bytes without changing them.
        if w_data.dtype == torch.uint8:
            w_data = w_data.view(torch.int8)
        layer.weight.data = w_data.contiguous()

        # Free source tensors that were folded above or are unused by apply().
        for attr in ("weight_scale", "weight_scale_2", "input_scale",
                     "weight_global_scale", "input_global_scale",
                     "alpha", "input_global_scale_inv"):
            if hasattr(layer, attr):
                delattr(layer, attr)

    def apply(
        self,
        layer: nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with DeepGEMM's FP8 x FP4 GEMM.

        Args:
            layer: layer prepared by ``process_weights_after_loading``.
            x: activations of shape (..., K); leading dims are flattened into
                the GEMM's M dimension and restored on the output.
            bias: optional bias, broadcast-added to the output when given.

        Returns:
            BF16 tensor of shape (..., N).
        """
        import deep_gemm  # optional dependency, imported lazily

        b_weight = layer.weight.data         # (N, K//2) packed int8
        b_sf = layer.weight_scale_inv.data   # (N, padded K//16) float32
        n_out = b_weight.shape[0]

        lead_shape = x.shape[:-1]
        x_2d = x.reshape(-1, x.shape[-1])
        d = torch.empty((x_2d.shape[0], n_out), dtype=torch.bfloat16,
                        device=x.device)

        # Skip the kernel entirely when there are no tokens.
        if x_2d.shape[0] > 0:
            # Quantize activations to FP8 with UE8M0 per-token scales.
            x_fp8, x_sf = deep_gemm.per_token_cast_to_fp8(
                x_2d, use_ue8m0=True, use_packed_ue8m0=True)

            # A is FP8 (M, K); B is FP4 (N, K//2 packed) with float32 scales
            # at gran_k=16 (the NVFP4 block size).
            # NOTE(review): B is stored (N, K)-shaped and K-major; in
            # DeepGEMM's FP8 API that layout is what the `_nt` variants take.
            # Confirm `_nn` is the intended entry point for this build.
            deep_gemm.fp8_fp4_gemm_nn(
                a=(x_fp8, x_sf),
                b=(b_weight, b_sf),
                d=d,
                recipe_b=(1, 16),  # NVFP4: gran_mn=1, gran_k=16
            )

        if bias is not None:
            d.add_(bias)

        return d.reshape(*lead_shape, n_out)