Refactor NVFP4 Linear utils for ModelOpt and CT (#33201)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
|
||||
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
|
||||
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -63,9 +62,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
should_ignore_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -627,14 +623,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
|
||||
input_quant
|
||||
):
|
||||
if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Current platform does not support cutlass NVFP4."
|
||||
" Running CompressedTensorsW4A16Fp4."
|
||||
)
|
||||
return CompressedTensorsW4A16Fp4(has_input_global_scale=True)
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
|
||||
@@ -81,7 +81,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
@@ -98,7 +98,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=None,
|
||||
weight_global_scale=None,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
|
||||
@@ -22,8 +22,7 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
def __init__(self, has_input_global_scale: bool = False):
|
||||
self.has_input_global_scale = has_input_global_scale
|
||||
def __init__(self):
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@@ -79,30 +78,16 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
if self.has_input_global_scale:
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_scale_2 = Parameter(
|
||||
1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
|
||||
# ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_global_scale = Parameter(
|
||||
1.0 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
del layer.weight_global_scale
|
||||
|
||||
if self.has_input_global_scale:
|
||||
layer.input_global_scale = torch.nn.Parameter(
|
||||
layer.input_global_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
@@ -116,7 +101,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
weight_global_scale=layer.weight_global_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
|
||||
@@ -5,75 +5,31 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
pad_nvfp4_activation_for_cutlass,
|
||||
pad_nvfp4_weight_for_cutlass,
|
||||
slice_nvfp4_output,
|
||||
swizzle_blockscale,
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
select_nvfp4_linear_backend,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_scaled_fp4_mm,
|
||||
has_flashinfer,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
self.backend = "fbgemm"
|
||||
try:
|
||||
import fbgemm_gpu # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
|
||||
self.backend = "cutlass"
|
||||
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
self.backend = select_nvfp4_linear_backend()
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
return 80
|
||||
return 100
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -129,120 +85,40 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
global_input_scale = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(global_input_scale, requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename CT checkpoint names to standardized names
|
||||
layer.weight = layer.weight_packed
|
||||
del layer.weight_packed
|
||||
# Process global scales (CT stores as divisors, i.e. 1/scale)
|
||||
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(
|
||||
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
|
||||
)
|
||||
weight_global_scale = layer.weight_global_scale.max().to(torch.float32)
|
||||
layer.weight_global_scale = Parameter(
|
||||
layer.weight_global_scale.max().to(torch.float32), requires_grad=False
|
||||
1.0 / weight_global_scale, requires_grad=False
|
||||
)
|
||||
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
# shuffles ourselves.
|
||||
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||
|
||||
weight = layer.weight_packed.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
epilogue_tile_m = 128
|
||||
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
|
||||
weight_scale = (
|
||||
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
|
||||
.reshape(weight_scale.shape)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
if self.backend == "fbgemm":
|
||||
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
|
||||
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
||||
|
||||
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
|
||||
# divisible by 32). fbgemm has its own layout requirements.
|
||||
if self.backend in ("cutlass", "flashinfer-cutlass"):
|
||||
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
|
||||
layer.weight_packed.data
|
||||
)
|
||||
layer.weights_padding_cols = weights_padding_cols
|
||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
layer.weights_padding_cols = 0
|
||||
layer.weight_packed = Parameter(
|
||||
layer.weight_packed.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Pre-compute alpha and inverse for runtime quantization
|
||||
layer.input_global_scale_inv = Parameter(
|
||||
input_global_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.alpha = Parameter(
|
||||
1 / (layer.input_global_scale * layer.weight_global_scale),
|
||||
requires_grad=False,
|
||||
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Convert layer to NVFP4 linear kernel format
|
||||
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
out = run_nvfp4_emulations(
|
||||
x=x,
|
||||
input_global_scale=layer.input_global_scale,
|
||||
weight=layer.weight_packed,
|
||||
weight_scale_swizzled=layer.weight_scale,
|
||||
weight_global_scale=layer.weight_global_scale,
|
||||
)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
|
||||
output_dtype = x.dtype
|
||||
output_size = layer.output_size_per_partition
|
||||
output_shape = [*x.shape[:-1], output_size]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||
x,
|
||||
layer.input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
return apply_nvfp4_linear(
|
||||
backend=self.backend,
|
||||
layer=layer,
|
||||
x=x,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# Pad activations to match weight K-dimension padding
|
||||
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
|
||||
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
|
||||
|
||||
mm_args = (
|
||||
x_fp4,
|
||||
layer.weight_packed,
|
||||
x_blockscale,
|
||||
layer.weight_scale,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
elif self.backend == "fbgemm":
|
||||
out = torch.ops.fbgemm.f4f4bf16(
|
||||
x_fp4,
|
||||
layer.weight_packed,
|
||||
x_blockscale.view(-1).view(torch.uint8),
|
||||
layer.weight_scale,
|
||||
layer.alpha,
|
||||
use_mx=False,
|
||||
).to(output_dtype)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
# Slice output to remove N-dimension padding
|
||||
out = slice_nvfp4_output(out, output_size)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
@@ -5,12 +5,9 @@ from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -66,24 +63,19 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
select_nvfp4_linear_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
cutlass_fp4_supported,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
pad_nvfp4_activation_for_cutlass,
|
||||
pad_nvfp4_weight_for_cutlass,
|
||||
slice_nvfp4_output,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
@@ -96,11 +88,6 @@ from vllm.model_executor.parameter import (
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_scaled_fp4_mm,
|
||||
has_flashinfer,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -498,7 +485,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
if not (layer.weight_scale == layer.weight_scale[0]).all():
|
||||
@@ -580,7 +567,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
@@ -681,7 +668,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
@@ -1108,32 +1095,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.marlin_input_dtype = None
|
||||
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if current_platform.has_device_capability(100) and has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
|
||||
self.backend = "cutlass"
|
||||
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
|
||||
self.backend = "marlin"
|
||||
assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
self.backend = select_nvfp4_linear_backend()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1181,19 +1143,19 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# Input Weight Scale
|
||||
input_scale = PerTensorScaleParameter(
|
||||
# Input Global Scale
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("input_scale", input_global_scale)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_scale_2 = PerTensorScaleParameter(
|
||||
# Weight Global Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale_2", weight_scale_2)
|
||||
layer.register_parameter("weight_scale_2", weight_global_scale)
|
||||
|
||||
# Per Block Weight Scale
|
||||
weight_scale = ModelWeightParameter(
|
||||
@@ -1209,65 +1171,25 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# global scales:
|
||||
input_scale_2 = layer.input_scale.max().to(torch.float32)
|
||||
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename ModelOpt checkpoint names to standardized names
|
||||
input_global_scale = layer.input_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
|
||||
del layer.input_scale
|
||||
weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
|
||||
del layer.weight_scale_2
|
||||
|
||||
# Pre-compute alpha and inverse for runtime quantization
|
||||
layer.alpha = Parameter(
|
||||
layer.input_scale * layer.weight_scale_2, requires_grad=False
|
||||
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
|
||||
)
|
||||
layer.input_global_scale_inv = Parameter(
|
||||
(1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
# Calculate `1 / input_scale` so that we don't need to do so at runtime
|
||||
layer.input_scale_inv = Parameter(
|
||||
(1 / layer.input_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
# Swizzle the weight blockscale.
|
||||
# contracting dimension is input dimension
|
||||
# block_size = 16;
|
||||
assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
|
||||
"Weight Block scale must be represented as FP8-E4M3"
|
||||
)
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
elif self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
# shuffles ourselves.
|
||||
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||
|
||||
weight = layer.weight.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
epilogue_tile_m = 128
|
||||
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
|
||||
weight_scale = (
|
||||
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
|
||||
.reshape(weight_scale.shape)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
# Swizzle block scales and pad the packed NVFP4 weights for kernel
|
||||
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
||||
|
||||
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
|
||||
layer.weight.data
|
||||
)
|
||||
layer.weights_padding_cols = weights_padding_cols
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
# Convert layer to NVFP4 linear kernel format
|
||||
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -1275,63 +1197,13 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.backend == "marlin":
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
|
||||
output_dtype = x.dtype
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
|
||||
return apply_nvfp4_linear(
|
||||
backend=self.backend,
|
||||
layer=layer,
|
||||
x=x,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# validate dtypes of quantized input, input block scale,
|
||||
# weight and weight_blockscale
|
||||
assert x_fp4.dtype == torch.uint8
|
||||
assert layer.weight.dtype == torch.uint8
|
||||
assert x_blockscale.dtype == torch.float8_e4m3fn
|
||||
assert layer.weight_scale.dtype == torch.float8_e4m3fn
|
||||
assert layer.alpha.dtype == torch.float32
|
||||
|
||||
# Pad activations to match weight K-dimension padding
|
||||
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
|
||||
output_size = layer.output_size_per_partition
|
||||
output_shape = [x.shape[0], output_size]
|
||||
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
|
||||
|
||||
mm_args = (
|
||||
x_fp4,
|
||||
layer.weight,
|
||||
x_blockscale,
|
||||
layer.weight_scale,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
# Slice output to remove N-dimension padding
|
||||
out = slice_nvfp4_output(out, output_size)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"""
|
||||
|
||||
@@ -15,11 +15,13 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
|
||||
@@ -92,7 +92,7 @@ def apply_fp4_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_scale_2: torch.Tensor | None,
|
||||
weight_global_scale: torch.Tensor | None,
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
@@ -112,7 +112,7 @@ def apply_fp4_marlin_linear(
|
||||
|
||||
inputs = reshaped_x
|
||||
a_scales = None
|
||||
is_nvfp4 = weight_scale_2 is not None
|
||||
is_nvfp4 = weight_global_scale is not None
|
||||
if input_dtype is not None and input_dtype.itemsize == 1:
|
||||
if is_nvfp4:
|
||||
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
|
||||
@@ -128,7 +128,7 @@ def apply_fp4_marlin_linear(
|
||||
b_bias=bias,
|
||||
b_scales=weight_scale,
|
||||
a_scales=a_scales,
|
||||
global_scale=weight_scale_2,
|
||||
global_scale=weight_global_scale,
|
||||
b_zeros=None,
|
||||
g_idx=None,
|
||||
perm=None,
|
||||
@@ -154,7 +154,7 @@ def prepare_fp4_layer_for_marlin(
|
||||
"performance for compute-heavy workloads."
|
||||
)
|
||||
|
||||
is_nvfp4 = hasattr(layer, "weight_scale_2")
|
||||
is_nvfp4 = hasattr(layer, "weight_global_scale")
|
||||
if input_dtype is not None and input_dtype.itemsize == 1:
|
||||
if is_nvfp4:
|
||||
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
|
||||
@@ -210,9 +210,11 @@ def prepare_fp4_layer_for_marlin(
|
||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
|
||||
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
||||
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
||||
layer.weight_global_scale = torch.nn.Parameter(
|
||||
weight_global_scale, requires_grad=False
|
||||
)
|
||||
else:
|
||||
weight_scale = mxfp4_marlin_process_scales(
|
||||
weight_scale, input_dtype=input_dtype
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
is_fp4_marlin_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
|
||||
|
||||
375
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
Normal file
375
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import (
|
||||
cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4,
|
||||
scaled_fp4_quant,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NvFp4LinearBackend(Enum):
|
||||
VLLM_CUTLASS = "cutlass"
|
||||
FLASHINFER_CUTLASS = "flashinfer-cutlass"
|
||||
FLASHINFER_TRTLLM = "flashinfer-trtllm"
|
||||
FLASHINFER_CUDNN = "flashinfer-cudnn"
|
||||
FBGEMM = "fbgemm"
|
||||
MARLIN = "marlin"
|
||||
EMULATION = "emulation"
|
||||
|
||||
|
||||
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
    """
    Select the best available NVFP4 GEMM backend based on environment
    configuration and platform capabilities.

    Selection order:
      1. ``VLLM_USE_FBGEMM`` forces the FBGEMM backend (requires the
         ``fbgemm-gpu-genai`` package).
      2. ``VLLM_USE_NVFP4_CT_EMULATIONS`` forces emulation.
      3. If ``VLLM_NVFP4_GEMM_BACKEND`` is unset, auto-select the first
         capable backend: FlashInfer-CUTLASS (SM100+ with FlashInfer),
         then vLLM CUTLASS, then Marlin.
      4. Otherwise, use the explicitly requested backend name.

    Returns:
        The selected ``NvFp4LinearBackend``.

    Raises:
        ImportError: if FBGEMM was requested but ``fbgemm_gpu`` is missing.
        ValueError: if auto-selection found no capable backend, or the
            requested name is not a valid ``NvFp4LinearBackend`` value.
    """
    backend: NvFp4LinearBackend | None = None

    if envs.VLLM_USE_FBGEMM:
        try:
            import fbgemm_gpu  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                "Please install with: pip install fbgemm-gpu-genai"
            ) from exc
        backend = NvFp4LinearBackend.FBGEMM
    elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        backend = NvFp4LinearBackend.EMULATION
    elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
        # Auto-select best available backend; may leave `backend` as None
        # when no kernel is usable on this platform.
        if current_platform.has_device_capability(100) and has_flashinfer():
            backend = NvFp4LinearBackend.FLASHINFER_CUTLASS
        elif cutlass_fp4_supported():
            backend = NvFp4LinearBackend.VLLM_CUTLASS
        elif is_fp4_marlin_supported():
            backend = NvFp4LinearBackend.MARLIN
    else:
        # Raises ValueError on an unrecognized backend name.
        backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)

    # Validate that the (possibly user-forced) backend is supported
    if backend in (
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_TRTLLM,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    ):
        assert has_flashinfer(), f"FlashInfer is required for {backend}"
    elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
        assert cutlass_fp4_supported(), f"Cutlass is required for {backend}"
    elif backend == NvFp4LinearBackend.MARLIN:
        assert is_fp4_marlin_supported(), f"Marlin is required for {backend}"
    elif backend is None:
        # Fix: first string part had a useless f-prefix (no placeholders).
        raise ValueError(
            "No NVFP4 GEMM backend selected, "
            f"available backends: {list(NvFp4LinearBackend)}"
        )

    # Lazy %-formatting defers string construction to the logger.
    logger.info_once("Using %s for NVFP4 GEMM", backend)
    return backend
|
||||
|
||||
|
||||
def prepare_weights_for_nvfp4_flashinfer_trtllm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare weights and scales for FlashInfer TRTLLM FP4 GEMM."""
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    # Tile height the TRT-LLM epilogue expects the data shuffled for.
    tile_m = 128
    out_weight = shuffle_matrix_a(weight.view(torch.uint8), tile_m)
    raw_scale = shuffle_matrix_sf_a(weight_scale.view(torch.uint8), tile_m)
    # Restore the original scale shape and reinterpret as FP8-E4M3.
    out_scale = raw_scale.reshape(weight_scale.shape).view(torch.float8_e4m3fn)
    return out_weight, out_scale
|
||||
|
||||
|
||||
def prepare_weights_for_nvfp4_cutlass(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepare weights and scales for CUTLASS/FlashInfer-CUTLASS FP4 GEMM.
    This involves padding weights for alignment (K and N divisible by 32)
    """
    # Interleave the block scales into the kernel's expected layout.
    scale_swizzled = swizzle_blockscale(weight_scale)
    # Remember how many K bytes of padding were added so activations can be
    # padded to match at apply time.
    weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(weight)
    return weight_padded, scale_swizzled, pad_cols
|
||||
|
||||
|
||||
def prepare_weights_for_nvfp4_fbgemm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare weights and scales for FBGEMM FP4 GEMM."""
    # FBGEMM consumes the swizzled block scales as a flat uint8 buffer;
    # the packed weight itself is passed through unchanged.
    flat_scale = swizzle_blockscale(weight_scale).view(-1).view(torch.uint8)
    return weight, flat_scale
|
||||
|
||||
|
||||
def convert_to_nvfp4_linear_kernel_format(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
) -> None:
    """Convert layer to NVFP4 linear kernel format."""

    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3"
    )

    # Default to no padding
    layer.weights_padding_cols = 0

    if backend == NvFp4LinearBackend.MARLIN:
        # Marlin repacks the layer in place; nothing else to do.
        prepare_fp4_layer_for_marlin(layer)
        return

    def _install(new_weight: torch.Tensor, new_scale: torch.Tensor) -> None:
        # Re-register the prepared tensors as frozen parameters.
        layer.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(new_scale, requires_grad=False)

    if backend == NvFp4LinearBackend.FLASHINFER_TRTLLM:
        _install(
            *prepare_weights_for_nvfp4_flashinfer_trtllm(
                layer.weight.data, layer.weight_scale.data
            )
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        _install(
            *prepare_weights_for_nvfp4_fbgemm(
                layer.weight.data, layer.weight_scale.data
            )
        )
    elif backend in (
        NvFp4LinearBackend.VLLM_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
    ):
        new_weight, new_scale, pad_cols = prepare_weights_for_nvfp4_cutlass(
            layer.weight.data, layer.weight_scale.data
        )
        _install(new_weight, new_scale)
        layer.weights_padding_cols = pad_cols
|
||||
|
||||
|
||||
def apply_nvfp4_linear(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Apply NVFP4 linear transformation using the specified backend.

    Args:
        backend: Backend to dispatch to. The layer must already be in this
            backend's format (see ``convert_to_nvfp4_linear_kernel_format``).
        layer: Linear layer holding the packed FP4 weight, its block scales,
            global scales, ``alpha``, and the partition sizes.
        x: Input activations (BF16/FP16 for the quantized paths).
        bias: Optional bias added to the output.

    Returns:
        Output tensor with dtype of ``x`` and trailing dimension
        ``layer.output_size_per_partition``.
    """
    weight = layer.weight
    weight_scale = layer.weight_scale
    weight_global_scale = layer.weight_global_scale
    input_global_scale_inv = layer.input_global_scale_inv
    alpha = layer.alpha
    output_size = layer.output_size_per_partition
    input_size = layer.input_size_per_partition

    # Marlin and emulation handle quantization internally, so they bypass
    # the shared FP4 activation-quantization path below.
    if backend == NvFp4LinearBackend.MARLIN:
        return apply_fp4_marlin_linear(
            input=x,
            weight=weight,
            weight_scale=weight_scale,
            weight_global_scale=weight_global_scale,
            workspace=layer.workspace,
            size_n=output_size,
            size_k=input_size,
            bias=bias,
        )
    elif backend == NvFp4LinearBackend.EMULATION:
        out = run_nvfp4_emulations(
            x=x,
            input_global_scale=input_global_scale_inv,
            weight=weight,
            weight_scale_swizzled=weight_scale,
            weight_global_scale=weight_global_scale,
        )
        if bias is not None:
            out = out + bias
        return out

    output_dtype = x.dtype
    output_shape = [*x.shape[:-1], output_size]

    # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(
        x, input_global_scale_inv, is_sf_swizzled_layout=True, backend=backend.value
    )

    # Validate dtypes
    assert x_fp4.dtype == torch.uint8
    assert weight.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    # weight_scale is fp8 for most backends, but uint8 for fbgemm
    assert weight_scale.dtype in (torch.float8_e4m3fn, torch.uint8)
    assert alpha.dtype == torch.float32

    # Pad activations to match weight K-dimension padding
    # (0 when the layer was never padded, hence the getattr default).
    weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
    x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)

    # Prepare args for the matmul
    mm_args = (
        x_fp4,
        weight,
        x_blockscale,
        weight_scale,
        alpha,
        output_dtype,
    )

    # Call the appropriate backend
    if backend.value.startswith("flashinfer-"):
        # e.g. "flashinfer-cutlass" -> "cutlass"
        backend_name = backend.value[len("flashinfer-") :]
        out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
    elif backend == NvFp4LinearBackend.FBGEMM:
        # FBGEMM takes the activation block scales as a flat uint8 buffer.
        out = torch.ops.fbgemm.f4f4bf16(
            x_fp4,
            weight,
            x_blockscale.view(-1).view(torch.uint8),
            weight_scale,
            alpha,
            use_mx=False,
        ).to(output_dtype)
    else:
        assert backend == NvFp4LinearBackend.VLLM_CUTLASS
        out = cutlass_scaled_fp4_mm(*mm_args)

    # Slice output to remove N-dimension padding
    out = slice_nvfp4_output(out, output_size)

    if bias is not None:
        out = out + bias

    return out.view(*output_shape)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """
    Pad and block-interleave FP4 block scales into the layout expected by the
    CUTLASS / FlashInfer kernels.

    Accepts a 2-D ``(M, K)`` or 3-D ``(B, M, K)`` FP8-E4M3 tensor and returns
    a tensor of the same rank with M rounded up to a multiple of 128 and K to
    a multiple of 4. NOTE(review): the result is always moved to CUDA.
    """
    assert scale.dtype == torch.float8_e4m3fn, (
        "swizzle_blockscale expects the input tensor to be in "
        "torch.float8_e4m3fn format."
    )

    was_2d = scale.ndim == 2
    if was_2d:
        scale = scale.unsqueeze(0)  # promote to (1, M, K)
    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."

    num_batches, rows, cols = scale.shape
    rows_padded = round_up(rows, 128)
    cols_padded = round_up(cols, 4)

    # Zero-pad so the shape factors cleanly into 128x4 scale tiles.
    buf = torch.zeros(
        (num_batches, rows_padded, cols_padded),
        dtype=scale.dtype,
        device=scale.device,
    )
    buf[:, :rows, :cols] = scale

    # Factor into (B, M/128, 4, 32, K/4, 4) tiles, then interleave the
    # tile axes into the order the kernel reads them in.
    tiled = buf.reshape(num_batches, rows_padded // 128, 4, 32, cols_padded // 4, 4)
    swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()

    if was_2d:
        return swizzled.reshape(rows_padded, cols_padded)
    return swizzled.reshape(num_batches, rows_padded, cols_padded)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
    """Whether the CUTLASS scaled-FP4 GEMM is usable on this device."""
    if not current_platform.is_cuda():
        return False
    # -1 signals "unknown capability" to the support check.
    cap = current_platform.get_device_capability()
    return cutlass_scaled_mm_supports_fp4(-1 if cap is None else cap.to_int())
|
||||
|
||||
|
||||
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
    the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.

    CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
    by 32 for aligned memory access and efficient tensor core operations.

    Args:
        weight: Packed FP4 weight of shape (N, K_bytes); each byte packs two
            FP4 values, so the logical K dimension is ``2 * K_bytes``.
        alignment: Required divisor for both logical dimensions.

    Returns:
        Tuple of (possibly padded weight, number of K padding *bytes* added).
        The byte count is what ``pad_nvfp4_activation_for_cutlass`` needs to
        pad activations to match.
    """
    # `-n % m` is the distance to the next multiple of m (0 when aligned),
    # replacing the previous round_up-and-subtract arithmetic.

    # Pad N dimension (rows) if not aligned.
    pad_rows = -weight.shape[0] % alignment
    if pad_rows:
        weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()

    # Pad K dimension if not aligned.
    # 2 FP4 items are packed per byte, so convert the element padding to
    # bytes; pad_elems is always even because alignment and the current
    # element count (2 * bytes) are both even.
    pad_elems = -(weight.shape[1] * 2) % alignment
    pad_bytes = pad_elems // 2
    if pad_bytes:
        weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()

    return weight, pad_bytes
|
||||
|
||||
|
||||
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """
    Pad packed FP4 activations to match the K-dimension padding applied to weights.
    The padding is in bytes (tensor dimension), not FP4 elements.
    """
    if weights_padding_bytes <= 0:
        # Weights were not padded; activations can be used as-is.
        return x_fp4
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes))
    return padded.contiguous()
|
||||
|
||||
|
||||
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Slice the output tensor to remove padding in N dimension if weight was padded.
    """
    if out.shape[-1] == output_size:
        # No N-dimension padding was applied; return unchanged.
        return out
    return out[..., :output_size].contiguous()
|
||||
@@ -11,7 +11,6 @@ import numpy
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
@@ -768,60 +767,6 @@ def awq_pack(
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """
    Pad and block-interleave FP4 block scales into the layout expected by the
    CUTLASS / FlashInfer kernels.

    Accepts a 2-D ``(M, K)`` or 3-D ``(B, M, K)`` FP8-E4M3 tensor and returns
    a tensor of the same rank with M rounded up to a multiple of 128 and K to
    a multiple of 4. NOTE(review): the result is always moved to CUDA.
    """
    assert scale.dtype == torch.float8_e4m3fn, (
        "swizzle_blockscale expects the input tensor to be in "
        "torch.float8_e4m3fn format."
    )

    was_2d = scale.ndim == 2
    if was_2d:
        scale = scale.unsqueeze(0)  # promote to (1, M, K)
    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."

    num_batches, rows, cols = scale.shape

    # Round up to the kernel's 128-row by 4-column scale tile.
    rows_padded = (rows + 127) // 128 * 128
    cols_padded = (cols + 3) // 4 * 4

    buf = torch.zeros(
        (num_batches, rows_padded, cols_padded),
        dtype=scale.dtype,
        device=scale.device,
    )
    buf[:, :rows, :cols] = scale

    # Factor into (B, M/128, 4, 32, K/4, 4) tiles and interleave the tile
    # axes into the order the kernel reads them in.
    tiled = buf.reshape(num_batches, rows_padded // 128, 4, 32, cols_padded // 4, 4)
    swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()

    if was_2d:
        return swizzled.reshape(rows_padded, cols_padded)
    return swizzled.reshape(num_batches, rows_padded, cols_padded)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
    """Whether the CUTLASS scaled-FP4 GEMM is usable on this device."""
    if not current_platform.is_cuda():
        return False
    # to_int() collapses (major, minor) into one comparable value;
    # -1 signals "unknown capability" to the support check.
    cap = current_platform.get_device_capability()
    capability = cap.to_int() if cap is not None else -1
    return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
def convert_bf16_scales_to_fp8(
|
||||
quant_fp8: Callable, scales: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -868,70 +813,3 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
|
||||
t |= ((nib - 8) & 0xF) << shift
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def round_up(x: int, m: int) -> int:
    """Round up x to the nearest multiple of m."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
|
||||
|
||||
|
||||
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
    the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.

    CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
    by 32 for aligned memory access and efficient tensor core operations.

    Args:
        weight: Packed FP4 weight of shape (N, K_bytes); each byte packs two
            FP4 values, so the logical K dimension is ``2 * K_bytes``.
        alignment: Required divisor for both logical dimensions.

    Returns:
        Tuple of (possibly padded weight, number of K padding *bytes* added).

    ``-n % m`` is the distance to the next multiple of m (0 when already
    aligned), replacing the previous round_up-and-subtract arithmetic.
    """
    # Pad N dimension (rows) if not aligned.
    pad_rows = -weight.shape[0] % alignment
    if pad_rows:
        weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()

    # Pad K dimension if not aligned.
    # 2 FP4 items are packed per byte, so convert the element padding to
    # bytes; pad_elems is always even because alignment and the current
    # element count (2 * bytes) are both even.
    pad_elems = -(weight.shape[1] * 2) % alignment
    pad_bytes = pad_elems // 2
    if pad_bytes:
        weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()

    return weight, pad_bytes
|
||||
|
||||
|
||||
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """
    Pad packed FP4 activations to match the K-dimension padding applied to weights.
    The padding is in bytes (tensor dimension), not FP4 elements.
    """
    if weights_padding_bytes <= 0:
        # Nothing to match: the weights carry no K padding.
        return x_fp4
    return torch.nn.functional.pad(
        x_fp4, (0, weights_padding_bytes)
    ).contiguous()
|
||||
|
||||
|
||||
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Slice the output tensor to remove padding in N dimension if weight was padded.
    """
    needs_trim = out.shape[-1] != output_size
    if needs_trim:
        return out[..., :output_size].contiguous()
    return out
|
||||
|
||||
Reference in New Issue
Block a user