diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index d215d2ab6..981f99342 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( cutlass_fp4_supported, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index d9d7f5e2f..795591ec3 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp -from vllm.model_executor.layers.quantization.utils.quant_utils import ( +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( cutlass_fp4_supported, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 390247364..df3d733b7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -18,7 +18,6 @@ from compressed_tensors.quantization import ( ) from compressed_tensors.transform import TransformConfig -import vllm.envs as envs from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -63,9 +62,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( should_ignore_layer, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, -) from vllm.platforms import current_platform if TYPE_CHECKING: @@ -627,14 +623,7 @@ class CompressedTensorsConfig(QuantizationConfig): if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format( input_quant ): - if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS: - return CompressedTensorsW4A4Fp4() - else: - logger.warning_once( - "Current platform does not support cutlass NVFP4." - " Running CompressedTensorsW4A16Fp4." - ) - return CompressedTensorsW4A16Fp4(has_input_global_scale=True) + return CompressedTensorsW4A4Fp4() if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py index 1c76adebe..77cea0f83 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py @@ -81,7 +81,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): ) layer.register_parameter("weight_scale", weight_scale) - def process_weights_after_loading(self, layer) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Rename weight_packed to weight that marlin expects layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed @@ -98,7 +98,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): input=x, weight=layer.weight, weight_scale=layer.weight_scale, - weight_scale_2=None, + weight_global_scale=None, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index d0a924471..87ef9162a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -22,8 +22,7 @@ __all__ = ["CompressedTensorsW4A16Fp4"] class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): - def __init__(self, has_input_global_scale: bool = False): - self.has_input_global_scale = has_input_global_scale + def __init__(self): self.group_size = 16 @classmethod @@ -79,30 +78,16 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): layer.register_parameter("weight_scale", weight_scale) - if self.has_input_global_scale: - input_global_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("input_global_scale", input_global_scale) - - def process_weights_after_loading(self, layer) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Process parameters for marlin repacking # Rename weight_packed to weight that marlin expects layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed - # Rename weight_global_scale to weight_scale_2 that marlin expects - # Note: ct stores the inverse of what is expected by the marlin kernel - layer.weight_scale_2 = Parameter( - 1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False + # ct stores the inverse of what is expected by the marlin kernel + layer.weight_global_scale = Parameter( + 1.0 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False ) - del layer.weight_global_scale - - if self.has_input_global_scale: - layer.input_global_scale = torch.nn.Parameter( - layer.input_global_scale.data, requires_grad=False - ) prepare_fp4_layer_for_marlin(layer) @@ -116,7 +101,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): input=x, weight=layer.weight, weight_scale=layer.weight_scale, - weight_scale_2=layer.weight_scale_2, + weight_global_scale=layer.weight_global_scale, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 7ed769bad..a3b53626b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -5,75 +5,31 @@ from collections.abc import Callable import torch from torch.nn.parameter import Parameter -import vllm.envs as envs -from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) -from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 - run_nvfp4_emulations, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, - pad_nvfp4_activation_for_cutlass, - pad_nvfp4_weight_for_cutlass, - slice_nvfp4_output, - swizzle_blockscale, +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( + apply_nvfp4_linear, + convert_to_nvfp4_linear_kernel_format, + select_nvfp4_linear_backend, ) from vllm.model_executor.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.utils.flashinfer import ( - flashinfer_scaled_fp4_mm, - has_flashinfer, -) - -logger = init_logger(__name__) __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): - self.backend = "none" - if envs.VLLM_NVFP4_GEMM_BACKEND is None: - if has_flashinfer(): - self.backend = "flashinfer-cutlass" - elif cutlass_fp4_supported(): - self.backend = "cutlass" - elif envs.VLLM_USE_FBGEMM: - self.backend = "fbgemm" - try: - import fbgemm_gpu # noqa: F401 - except ImportError as exc: - raise ImportError( - "Backend fbgemm requires fbgemm.f4f4bf16 operator, " - "Please install with: pip install fbgemm-gpu-genai" - ) from exc - elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): - self.backend = envs.VLLM_NVFP4_GEMM_BACKEND - assert has_flashinfer(), f"FlashInfer is required for {self.backend}" - elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": - self.backend = "cutlass" - assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" - - if self.backend == "none": - raise ValueError( - "No valid NVFP4 GEMM backend found. " - "Please check your platform capability." - ) - - logger.info_once(f"Using {self.backend} for NVFP4 GEMM") + self.backend = select_nvfp4_linear_backend() self.group_size = 16 @classmethod def get_min_capability(cls) -> int: - if envs.VLLM_USE_NVFP4_CT_EMULATIONS: - return 80 - return 100 + return 75 def create_weights( self, @@ -129,120 +85,40 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ) layer.register_parameter("input_global_scale", input_global_scale) - def process_weights_after_loading(self, layer) -> None: - global_input_scale = layer.input_global_scale.max().to(torch.float32) - layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Rename CT checkpoint names to standardized names + layer.weight = layer.weight_packed + del layer.weight_packed + # Process global scales (CT stores as divisors, i.e. 1/scale) + input_global_scale_inv = layer.input_global_scale.max().to(torch.float32) + layer.input_global_scale = Parameter( + (1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False + ) + weight_global_scale = layer.weight_global_scale.max().to(torch.float32) layer.weight_global_scale = Parameter( - layer.weight_global_scale.max().to(torch.float32), requires_grad=False + 1.0 / weight_global_scale, requires_grad=False ) - if self.backend == "flashinfer-trtllm": - # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. - # FlashInfer provides nvfp4_quantize to quantize + shuffle the - # layout but we use our own quantization so we have to call - # shuffles ourselves. - from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - - weight = layer.weight_packed.data - weight_scale = layer.weight_scale.data - - epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) - weight_scale = ( - shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) - .reshape(weight_scale.shape) - .view(torch.float8_e4m3fn) - ) - - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.weight_packed = Parameter(weight, requires_grad=False) - else: - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - if self.backend == "fbgemm": - swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) - layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) - - # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N - # divisible by 32). fbgemm has its own layout requirements. - if self.backend in ("cutlass", "flashinfer-cutlass"): - weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass( - layer.weight_packed.data - ) - layer.weights_padding_cols = weights_padding_cols - layer.weight_packed = Parameter(weight, requires_grad=False) - else: - layer.weights_padding_cols = 0 - layer.weight_packed = Parameter( - layer.weight_packed.data, requires_grad=False - ) - + # Pre-compute alpha and inverse for runtime quantization + layer.input_global_scale_inv = Parameter( + input_global_scale_inv, requires_grad=False + ) layer.alpha = Parameter( - 1 / (layer.input_global_scale * layer.weight_global_scale), - requires_grad=False, + layer.input_global_scale * layer.weight_global_scale, requires_grad=False ) + # Convert layer to NVFP4 linear kernel format + convert_to_nvfp4_linear_kernel_format(self.backend, layer) + def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if envs.VLLM_USE_NVFP4_CT_EMULATIONS: - out = run_nvfp4_emulations( - x=x, - input_global_scale=layer.input_global_scale, - weight=layer.weight_packed, - weight_scale_swizzled=layer.weight_scale, - weight_global_scale=layer.weight_global_scale, - ) - if bias is not None: - out = out + bias - return out - - output_dtype = x.dtype - output_size = layer.output_size_per_partition - output_shape = [*x.shape[:-1], output_size] - - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant( - x, - layer.input_global_scale, - is_sf_swizzled_layout=True, + return apply_nvfp4_linear( backend=self.backend, + layer=layer, + x=x, + bias=bias, ) - - # Pad activations to match weight K-dimension padding - weights_padding_cols = getattr(layer, "weights_padding_cols", 0) - x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols) - - mm_args = ( - x_fp4, - layer.weight_packed, - x_blockscale, - layer.weight_scale, - layer.alpha, - output_dtype, - ) - if self.backend.startswith("flashinfer-"): - backend_name = self.backend[len("flashinfer-") :] - out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) - elif self.backend == "fbgemm": - out = torch.ops.fbgemm.f4f4bf16( - x_fp4, - layer.weight_packed, - x_blockscale.view(-1).view(torch.uint8), - layer.weight_scale, - layer.alpha, - use_mx=False, - ).to(output_dtype) - else: - assert self.backend == "cutlass" - out = cutlass_scaled_fp4_mm(*mm_args) - - # Slice output to remove N-dimension padding - out = slice_nvfp4_output(out, output_size) - - if bias is not None: - out = out + bias - return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 9a9480ffe..e76c109ec 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -5,12 +5,9 @@ from fnmatch import fnmatch from typing import TYPE_CHECKING, Any import torch -from torch.nn import Module from torch.nn.parameter import Parameter -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe.config import ( @@ -66,24 +63,19 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - is_fp4_marlin_supported, - prepare_fp4_layer_for_marlin, +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( + apply_nvfp4_linear, + convert_to_nvfp4_linear_kernel_format, + select_nvfp4_linear_backend, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, - cutlass_fp4_supported, is_layer_skipped, kFp8DynamicTokenSym, kFp8StaticTensorSym, kFp8StaticTokenSym, kNvfp4Dynamic, kNvfp4Static, - pad_nvfp4_activation_for_cutlass, - pad_nvfp4_weight_for_cutlass, - slice_nvfp4_output, - swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -96,11 +88,6 @@ from vllm.model_executor.parameter import ( PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter -from vllm.platforms import current_platform -from vllm.utils.flashinfer import ( - flashinfer_scaled_fp4_mm, - has_flashinfer, -) if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -498,7 +485,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) - def process_weights_after_loading(self, layer: Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = layer.weight max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): @@ -580,7 +567,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) - def process_weights_after_loading(self, layer: Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) @@ -681,7 +668,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase): weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) - def process_weights_after_loading(self, layer: Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp. layer.weight = Parameter(layer.weight.data, requires_grad=False) @@ -1108,32 +1095,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config self.marlin_input_dtype = None - - self.backend = "none" - if envs.VLLM_NVFP4_GEMM_BACKEND is None: - if current_platform.has_device_capability(100) and has_flashinfer(): - self.backend = "flashinfer-cutlass" - elif cutlass_fp4_supported(): - self.backend = "cutlass" - elif is_fp4_marlin_supported(): - self.backend = "marlin" - elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): - self.backend = envs.VLLM_NVFP4_GEMM_BACKEND - assert has_flashinfer(), f"FlashInfer is required for {self.backend}" - elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": - self.backend = "cutlass" - assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" - elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin": - self.backend = "marlin" - assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}" - - if self.backend == "none": - raise ValueError( - "No valid NVFP4 GEMM backend found. " - "Please check your platform capability." - ) - - logger.info_once(f"Using {self.backend} for NVFP4 GEMM") + self.backend = select_nvfp4_linear_backend() def create_weights( self, @@ -1181,19 +1143,19 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ) layer.register_parameter("weight", weight) - # Input Weight Scale - input_scale = PerTensorScaleParameter( + # Input Global Scale + input_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) - layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_scale", input_global_scale) - # Global Weight Scale - weight_scale_2 = PerTensorScaleParameter( + # Weight Global Scale + weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) - layer.register_parameter("weight_scale_2", weight_scale_2) + layer.register_parameter("weight_scale_2", weight_global_scale) # Per Block Weight Scale weight_scale = ModelWeightParameter( @@ -1209,65 +1171,25 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.register_parameter("weight_scale", weight_scale) - def process_weights_after_loading(self, layer: Module) -> None: - # global scales: - input_scale_2 = layer.input_scale.max().to(torch.float32) - layer.input_scale = Parameter(input_scale_2, requires_grad=False) - - weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) - layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Rename ModelOpt checkpoint names to standardized names + input_global_scale = layer.input_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(input_global_scale, requires_grad=False) + del layer.input_scale + weight_global_scale = layer.weight_scale_2.max().to(torch.float32) + layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + del layer.weight_scale_2 + # Pre-compute alpha and inverse for runtime quantization layer.alpha = Parameter( - layer.input_scale * layer.weight_scale_2, requires_grad=False + layer.input_global_scale * layer.weight_global_scale, requires_grad=False + ) + layer.input_global_scale_inv = Parameter( + (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False ) - # Calculate `1 / input_scale` so that we don't need to do so at runtime - layer.input_scale_inv = Parameter( - (1 / layer.input_scale).to(torch.float32), requires_grad=False - ) - - # Swizzle the weight blockscale. - # contracting dimension is input dimension - # block_size = 16; - assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( - "Weight Block scale must be represented as FP8-E4M3" - ) - - if self.backend == "marlin": - prepare_fp4_layer_for_marlin(layer) - del layer.alpha - del layer.input_scale - elif self.backend == "flashinfer-trtllm": - # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. - # FlashInfer provides nvfp4_quantize to quantize + shuffle the - # layout but we use our own quantization so we have to call - # shuffles ourselves. - from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - - weight = layer.weight.data - weight_scale = layer.weight_scale.data - - epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) - weight_scale = ( - shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) - .reshape(weight_scale.shape) - .view(torch.float8_e4m3fn) - ) - - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.weight = Parameter(weight, requires_grad=False) - else: - # Swizzle block scales and pad the packed NVFP4 weights for kernel - # alignment (CUTLASS/FlashInfer require K and N divisible by 32). - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) - - weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass( - layer.weight.data - ) - layer.weights_padding_cols = weights_padding_cols - layer.weight = Parameter(weight, requires_grad=False) + # Convert layer to NVFP4 linear kernel format + convert_to_nvfp4_linear_kernel_format(self.backend, layer) def apply( self, @@ -1275,63 +1197,13 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if self.backend == "marlin": - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_scale_2=layer.weight_scale_2, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - input_dtype=self.marlin_input_dtype, - ) - - output_dtype = x.dtype - - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant( - x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend + return apply_nvfp4_linear( + backend=self.backend, + layer=layer, + x=x, + bias=bias, ) - # validate dtypes of quantized input, input block scale, - # weight and weight_blockscale - assert x_fp4.dtype == torch.uint8 - assert layer.weight.dtype == torch.uint8 - assert x_blockscale.dtype == torch.float8_e4m3fn - assert layer.weight_scale.dtype == torch.float8_e4m3fn - assert layer.alpha.dtype == torch.float32 - - # Pad activations to match weight K-dimension padding - weights_padding_cols = getattr(layer, "weights_padding_cols", 0) - output_size = layer.output_size_per_partition - output_shape = [x.shape[0], output_size] - x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols) - - mm_args = ( - x_fp4, - layer.weight, - x_blockscale, - layer.weight_scale, - layer.alpha, - output_dtype, - ) - - if self.backend.startswith("flashinfer-"): - backend_name = self.backend[len("flashinfer-") :] - out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) - else: - assert self.backend == "cutlass" - out = cutlass_scaled_fp4_mm(*mm_args) - - # Slice output to remove N-dimension padding - out = slice_nvfp4_output(out, output_size) - - if bias is not None: - out = out + bias - return out.view(*output_shape) - class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index ae5a934fb..6f3d19e09 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -15,11 +15,13 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, RoutingMethodType, ) +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( + swizzle_blockscale, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kNvfp4Dynamic, kNvfp4Static, - swizzle_blockscale, ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 789ed5dba..41d529393 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -92,7 +92,7 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor | None, + weight_global_scale: torch.Tensor | None, workspace: torch.Tensor, size_n: int, size_k: int, @@ -112,7 +112,7 @@ def apply_fp4_marlin_linear( inputs = reshaped_x a_scales = None - is_nvfp4 = weight_scale_2 is not None + is_nvfp4 = weight_global_scale is not None if input_dtype is not None and input_dtype.itemsize == 1: if is_nvfp4: raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.") @@ -128,7 +128,7 @@ def apply_fp4_marlin_linear( b_bias=bias, b_scales=weight_scale, a_scales=a_scales, - global_scale=weight_scale_2, + global_scale=weight_global_scale, b_zeros=None, g_idx=None, perm=None, @@ -154,7 +154,7 @@ def prepare_fp4_layer_for_marlin( "performance for compute-heavy workloads." ) - is_nvfp4 = hasattr(layer, "weight_scale_2") + is_nvfp4 = hasattr(layer, "weight_global_scale") if input_dtype is not None and input_dtype.itemsize == 1: if is_nvfp4: raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.") @@ -210,9 +210,11 @@ def prepare_fp4_layer_for_marlin( weight_scale = nvfp4_marlin_process_scales(weight_scale) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) + weight_global_scale = layer.weight_global_scale.to(param_dtype) + weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale) + layer.weight_global_scale = torch.nn.Parameter( + weight_global_scale, requires_grad=False + ) else: weight_scale = mxfp4_marlin_process_scales( weight_scale, input_dtype=input_dtype diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py index 44c5b027d..199a81c42 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( is_fp4_marlin_supported, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( +from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( cutlass_fp4_supported, ) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py new file mode 100644 index 000000000..8b2549be0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + +import torch + +import vllm.envs as envs +from vllm._custom_ops import ( + cutlass_scaled_fp4_mm, + cutlass_scaled_mm_supports_fp4, + scaled_fp4_quant, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + run_nvfp4_emulations, +) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer +from vllm.utils.math_utils import round_up + +logger = init_logger(__name__) + + +class NvFp4LinearBackend(Enum): + VLLM_CUTLASS = "cutlass" + FLASHINFER_CUTLASS = "flashinfer-cutlass" + FLASHINFER_TRTLLM = "flashinfer-trtllm" + FLASHINFER_CUDNN = "flashinfer-cudnn" + FBGEMM = "fbgemm" + MARLIN = "marlin" + EMULATION = "emulation" + + +def select_nvfp4_linear_backend() -> NvFp4LinearBackend: + """ + Select the best available NVFP4 GEMM backend based on environment + configuration and platform capabilities. + """ + backend: NvFp4LinearBackend | None = None + + if envs.VLLM_USE_FBGEMM: + try: + import fbgemm_gpu # noqa: F401 + except ImportError as exc: + raise ImportError( + "Backend fbgemm requires fbgemm.f4f4bf16 operator, " + "Please install with: pip install fbgemm-gpu-genai" + ) from exc + backend = NvFp4LinearBackend.FBGEMM + elif envs.VLLM_USE_NVFP4_CT_EMULATIONS: + backend = NvFp4LinearBackend.EMULATION + elif envs.VLLM_NVFP4_GEMM_BACKEND is None: + # Auto-select best available backend + if current_platform.has_device_capability(100) and has_flashinfer(): + backend = NvFp4LinearBackend.FLASHINFER_CUTLASS + elif cutlass_fp4_supported(): + backend = NvFp4LinearBackend.VLLM_CUTLASS + elif is_fp4_marlin_supported(): + backend = NvFp4LinearBackend.MARLIN + else: + backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND) + + # Validate that the backend is supported + if backend in ( + NvFp4LinearBackend.FLASHINFER_CUTLASS, + NvFp4LinearBackend.FLASHINFER_TRTLLM, + NvFp4LinearBackend.FLASHINFER_CUDNN, + ): + assert has_flashinfer(), f"FlashInfer is required for {backend}" + elif backend == NvFp4LinearBackend.VLLM_CUTLASS: + assert cutlass_fp4_supported(), f"Cutlass is required for {backend}" + elif backend == NvFp4LinearBackend.MARLIN: + assert is_fp4_marlin_supported(), f"Marlin is required for {backend}" + elif backend is None: + raise ValueError( + f"No NVFP4 GEMM backend selected, " + f"available backends: {list(NvFp4LinearBackend)}" + ) + + logger.info_once(f"Using {backend} for NVFP4 GEMM") + return backend + + +def prepare_weights_for_nvfp4_flashinfer_trtllm( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare weights and scales for FlashInfer TRTLLM FP4 GEMM.""" + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + epilogue_tile_m = 128 + shuffled_weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + shuffled_weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) + + return shuffled_weight, shuffled_weight_scale + + +def prepare_weights_for_nvfp4_cutlass( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Prepare weights and scales for CUTLASS/FlashInfer-CUTLASS FP4 GEMM. + This involves padding weights for alignment (K and N divisible by 32) + """ + swizzled_weight_scale = swizzle_blockscale(weight_scale) + padded_weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(weight) + return padded_weight, swizzled_weight_scale, weights_padding_cols + + +def prepare_weights_for_nvfp4_fbgemm( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare weights and scales for FBGEMM FP4 GEMM.""" + swizzled_weight_scale = swizzle_blockscale(weight_scale) + swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) + return weight, swizzled_weight_scale + + +def convert_to_nvfp4_linear_kernel_format( + backend: NvFp4LinearBackend, + layer: torch.nn.Module, +) -> None: + """Convert layer to NVFP4 linear kernel format.""" + + assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Block scale must be represented as FP8-E4M3" + ) + + # Default to no padding + layer.weights_padding_cols = 0 + + if backend == NvFp4LinearBackend.MARLIN: + prepare_fp4_layer_for_marlin(layer) + elif backend == NvFp4LinearBackend.FLASHINFER_TRTLLM: + weight, weight_scale = prepare_weights_for_nvfp4_flashinfer_trtllm( + layer.weight.data, layer.weight_scale.data + ) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + elif backend == NvFp4LinearBackend.FBGEMM: + weight, weight_scale = prepare_weights_for_nvfp4_fbgemm( + layer.weight.data, layer.weight_scale.data + ) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + elif ( + backend == NvFp4LinearBackend.VLLM_CUTLASS + or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS + ): + weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass( + layer.weight.data, layer.weight_scale.data + ) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + layer.weights_padding_cols = weights_padding_cols + + +def apply_nvfp4_linear( + backend: NvFp4LinearBackend, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Apply NVFP4 linear transformation using the specified backend. + """ + weight = layer.weight + weight_scale = layer.weight_scale + weight_global_scale = layer.weight_global_scale + input_global_scale_inv = layer.input_global_scale_inv + alpha = layer.alpha + output_size = layer.output_size_per_partition + input_size = layer.input_size_per_partition + + if backend == NvFp4LinearBackend.MARLIN: + return apply_fp4_marlin_linear( + input=x, + weight=weight, + weight_scale=weight_scale, + weight_global_scale=weight_global_scale, + workspace=layer.workspace, + size_n=output_size, + size_k=input_size, + bias=bias, + ) + elif backend == NvFp4LinearBackend.EMULATION: + out = run_nvfp4_emulations( + x=x, + input_global_scale=input_global_scale_inv, + weight=weight, + weight_scale_swizzled=weight_scale, + weight_global_scale=weight_global_scale, + ) + if bias is not None: + out = out + bias + return out + + output_dtype = x.dtype + output_shape = [*x.shape[:-1], output_size] + + # Quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant( + x, input_global_scale_inv, is_sf_swizzled_layout=True, backend=backend.value + ) + + # Validate dtypes + assert x_fp4.dtype == torch.uint8 + assert weight.dtype == torch.uint8 + assert x_blockscale.dtype == torch.float8_e4m3fn + # weight_scale is fp8 for most backends, but uint8 for fbgemm + assert weight_scale.dtype in (torch.float8_e4m3fn, torch.uint8) + assert alpha.dtype == torch.float32 + + # Pad activations to match weight K-dimension padding + weights_padding_cols = getattr(layer, "weights_padding_cols", 0) + x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols) + + # Prepare args for the matmul + mm_args = ( + x_fp4, + weight, + x_blockscale, + weight_scale, + alpha, + output_dtype, + ) + + # Call the appropriate backend + if backend.value.startswith("flashinfer-"): + backend_name = backend.value[len("flashinfer-") :] + out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) + elif backend == NvFp4LinearBackend.FBGEMM: + out = torch.ops.fbgemm.f4f4bf16( + x_fp4, + weight, + x_blockscale.view(-1).view(torch.uint8), + weight_scale, + alpha, + use_mx=False, + ).to(output_dtype) + else: + assert backend == NvFp4LinearBackend.VLLM_CUTLASS + out = cutlass_scaled_fp4_mm(*mm_args) + + # Slice output to remove N-dimension padding + out = slice_nvfp4_output(out, output_size) + + if bias is not None: + out = out + bias + + return out.view(*output_shape) + + +def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: + """ + Pad and block-interleave the FP4 block-scales so that they match the data + layout expected by the CUTLASS / FlashInfer kernels. + + Parameters + ---------- + scale: torch.Tensor + + Returns + ------- + torch.Tensor + The swizzled tensor with the same logical shape as *scale*. + """ + assert scale.dtype == torch.float8_e4m3fn, ( + "swizzle_blockscale expects the input tensor to be in " + "torch.float8_e4m3fn format." + ) + + scale_ndim = scale.ndim + if scale_ndim == 2: + scale = scale.unsqueeze(0) # (1, M, K) + assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales." + + B, M, K = scale.shape + + M_padded = round_up(M, 128) + K_padded = round_up(K, 4) + + padded = torch.zeros( + (B, M_padded, K_padded), dtype=scale.dtype, device=scale.device + ) + padded[:B, :M, :K] = scale + + # Reshape / permute to the layout required by the kernel. + padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4) + swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() + + if scale_ndim == 2: + return swizzled.reshape(M_padded, K_padded) + return swizzled.reshape(B, M_padded, K_padded) + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +def pad_nvfp4_weight_for_cutlass( + weight: torch.Tensor, + alignment: int = 32, +) -> tuple[torch.Tensor, int]: + """ + Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy + the alignment constraints required by CUTLASS / FlashInfer FP4 kernels. + + CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible + by 32 for aligned memory access and efficient tensor core operations. + """ + weight_current_rows = weight.shape[0] + + # Pad N dimension (rows) if not aligned + if weight_current_rows % alignment != 0: + total_rows = round_up(weight_current_rows, alignment) + pad_rows = total_rows - weight_current_rows + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous() + + # Check K dimension alignment + # 2 FP4 items are packed per byte in the input dimension + weight_current_col_bytes = weight.shape[1] + weight_current_col_elements = weight_current_col_bytes * 2 + + weights_padding_bytes = 0 + if weight_current_col_elements % alignment != 0: + total_cols = round_up(weight_current_col_elements, alignment) + pad_cols = total_cols - weight_current_col_elements + # Convert from FP4 element count to bytes (2 FP4 values per byte) + # pad_cols is always even since alignment=32 and current elements are even + pad_bytes = pad_cols // 2 + weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous() + weights_padding_bytes = pad_bytes + + return weight, weights_padding_bytes + + +def pad_nvfp4_activation_for_cutlass( + x_fp4: torch.Tensor, + weights_padding_bytes: int, +) -> torch.Tensor: + """ + Pad packed FP4 activations to match the K-dimension padding applied to weights. + The padding is in bytes (tensor dimension), not FP4 elements. + """ + if weights_padding_bytes > 0: + return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous() + return x_fp4 + + +def slice_nvfp4_output( + out: torch.Tensor, + output_size: int, +) -> torch.Tensor: + """ + Slice the output tensor to remove padding in N dimension if weight was padded. + """ + if out.shape[-1] != output_size: + return out[..., :output_size].contiguous() + return out diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 5dbd05f16..e42868e41 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -11,7 +11,6 @@ import numpy import torch from torch import fx -from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -768,60 +767,6 @@ def awq_pack( return pack_cols(q_w, num_bits, size_k, size_n) -def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: - """ - Pad and block-interleave the FP4 block-scales so that they match the data - layout expected by the CUTLASS / FlashInfer kernels. - - Parameters - ---------- - scale: torch.Tensor - - Returns - ------- - torch.Tensor - The swizzled tensor with the same logical shape as *scale*. - """ - assert scale.dtype == torch.float8_e4m3fn, ( - "swizzle_blockscale expects the input tensor to be in " - "torch.float8_e4m3fn format." - ) - - scale_ndim = scale.ndim - if scale_ndim == 2: - scale = scale.unsqueeze(0) # (1, M, K) - assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales." - - B, M, K = scale.shape - - def _round_up(x: int, m: int) -> int: - return (x + m - 1) // m * m - - M_padded = _round_up(M, 128) - K_padded = _round_up(K, 4) - - padded = torch.zeros( - (B, M_padded, K_padded), dtype=scale.dtype, device=scale.device - ) - padded[:B, :M, :K] = scale - - # Reshape / permute to the layout required by the kernel. - padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4) - swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() - - if scale_ndim == 2: - return swizzled.reshape(M_padded, K_padded) - return swizzled.reshape(B, M_padded, K_padded) - - -def cutlass_fp4_supported() -> bool: - if not current_platform.is_cuda(): - return False - capability_tuple = current_platform.get_device_capability() - capability = -1 if capability_tuple is None else capability_tuple.to_int() - return cutlass_scaled_mm_supports_fp4(capability) - - def convert_bf16_scales_to_fp8( quant_fp8: Callable, scales: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: @@ -868,70 +813,3 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens t |= ((nib - 8) & 0xF) << shift return t - - -def round_up(x: int, m: int) -> int: - """Round up x to the nearest multiple of m.""" - return (x + m - 1) // m * m - - -def pad_nvfp4_weight_for_cutlass( - weight: torch.Tensor, - alignment: int = 32, -) -> tuple[torch.Tensor, int]: - """ - Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy - the alignment constraints required by CUTLASS / FlashInfer FP4 kernels. - - CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible - by 32 for aligned memory access and efficient tensor core operations. - """ - weight_current_rows = weight.shape[0] - - # Pad N dimension (rows) if not aligned - if weight_current_rows % alignment != 0: - total_rows = round_up(weight_current_rows, alignment) - pad_rows = total_rows - weight_current_rows - weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous() - - # Check K dimension alignment - # 2 FP4 items are packed per byte in the input dimension - weight_current_col_bytes = weight.shape[1] - weight_current_col_elements = weight_current_col_bytes * 2 - - weights_padding_bytes = 0 - if weight_current_col_elements % alignment != 0: - total_cols = round_up(weight_current_col_elements, alignment) - pad_cols = total_cols - weight_current_col_elements - # Convert from FP4 element count to bytes (2 FP4 values per byte) - # pad_cols is always even since alignment=32 and current elements are even - pad_bytes = pad_cols // 2 - weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous() - weights_padding_bytes = pad_bytes - - return weight, weights_padding_bytes - - -def pad_nvfp4_activation_for_cutlass( - x_fp4: torch.Tensor, - weights_padding_bytes: int, -) -> torch.Tensor: - """ - Pad packed FP4 activations to match the K-dimension padding applied to weights. - The padding is in bytes (tensor dimension), not FP4 elements. - """ - if weights_padding_bytes > 0: - return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous() - return x_fp4 - - -def slice_nvfp4_output( - out: torch.Tensor, - output_size: int, -) -> torch.Tensor: - """ - Slice the output tensor to remove padding in N dimension if weight was padded. - """ - if out.shape[-1] != output_size: - return out[..., :output_size].contiguous() - return out