diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 1888d78ca..610891ebf 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -1570,7 +1570,7 @@ class rocm_aiter_ops:
     def group_fp8_quant(
         input_2d: torch.Tensor,
         group_size: int = 128,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert group_size == 128, "Group size must be 128"
         return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
 
diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py
index 50c97190d..5bc78afa4 100644
--- a/vllm/model_executor/layers/quantization/input_quant_fp8.py
+++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py
@@ -14,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     prep_scale_for_group_broadcast,
 )
 from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
+    is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
+)
 
 _FP8_DTYPE = current_platform.fp8_dtype()
 _FP8_MIN, _FP8_MAX = get_fp8_min_max()
@@ -59,7 +64,8 @@ class QuantFP8(CustomOp):
         self.num_token_padding = num_token_padding
         self.column_major_scales = column_major_scales
         self.tma_aligned_scales = tma_aligned_scales
-        self.use_ue8m0 = use_ue8m0
+        self.use_ue8m0 = is_deep_gemm_e8m0_used() if use_ue8m0 is None else use_ue8m0
+        self.use_deep_gemm_supported = is_deep_gemm_supported()
 
         self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
 
@@ -79,10 +85,23 @@ class QuantFP8(CustomOp):
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        from vllm.model_executor.layers.quantization.utils import fp8_utils
+
+        if (
+            self.is_group_quant
+            and self.use_deep_gemm_supported
+            and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
+        ):
+            return fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
+                x,
+                group_size=self.group_size,
+                use_ue8m0=True,
+            )
+
         if self.is_group_quant and not self.static:
             assert scale is None, "Dynamic group quantization does not use scale"
-            from vllm.model_executor.layers.quantization.utils import fp8_utils
 
             return fp8_utils.per_token_group_quant_fp8(
                 x,
@@ -116,25 +135,34 @@ class QuantFP8(CustomOp):
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        use_aiter_quant = (
-            not self.is_group_quant
-            and self.use_aiter
-            and scale_ub is None
-            and x.is_contiguous()
-        )
-        use_aiter_per_tensor_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
-        )
-        use_aiter_per_token_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
-        )
+        use_triton = kwargs.get("use_triton", False)
+        if self.is_group_quant and use_triton:
+            assert scale is None, "Dynamic group quantization does not use scale"
+            return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
+
+        use_aiter_quant = self.use_aiter and scale_ub is None and x.is_contiguous()
+        use_aiter_per_tensor_quant = (
+            use_aiter_quant and self.group_shape.is_per_tensor()
+        )
+        use_aiter_per_token_quant = use_aiter_quant and self.group_shape.is_per_token()
+
+        use_aiter_per_group_quant = use_aiter_quant and self.group_shape.is_per_group()
+
+        if use_aiter_per_group_quant:
+            return rocm_aiter_ops.group_fp8_quant(x, self.group_size)
 
         if use_aiter_per_tensor_quant:
             return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
 
         if use_aiter_per_token_quant:
             return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
+
+        # Fallback to native implementation for group quantization.
+        if self.is_group_quant:
+            assert scale is None, "Dynamic group quantization does not use scale"
+            return self._quantize_group_native(x)
+
         # Fallback to CUDA implementation
         return self.forward_cuda(x, scale, scale_ub)
 
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 6aa59218b..cc6c2eee4 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -33,7 +33,6 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
-    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     get_tma_aligned_size,
     is_deep_gemm_e8m0_used,
@@ -426,15 +425,8 @@ class W8A8BlockFp8LinearOp:
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
-            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
-                input_2d,
-                group_size=self.act_quant_group_shape.col,
-                use_ue8m0=True,
-            )
-        else:
-            assert self.deepgemm_input_quant_op is not None
-            q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        assert self.deepgemm_input_quant_op is not None
+        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
@@ -497,15 +489,8 @@ class W8A8BlockFp8LinearOp:
 
         if input_scale is not None:
             q_input = input_2d
-        elif use_triton:
-            q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
-                input_2d,
-                self.act_quant_group_shape.col,
-            )
         else:
-            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
-                input_2d, self.act_quant_group_shape.col
-            )
+            q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
 
         return gemm_a8w8_blockscale_op(
             q_input,
@@ -572,7 +557,7 @@ class W8A8BlockFp8LinearOp:
             ],
             torch.Tensor,
         ],
-        QuantFP8 | None,
+        QuantFP8,
     ]:
         if use_cutlass:
             return self._run_cutlass, (
@@ -584,7 +569,12 @@ class W8A8BlockFp8LinearOp:
                 )
             )
         if use_aiter_and_is_supported:
-            return self._run_aiter, None
+            return self._run_aiter, QuantFP8(
+                False,
+                self.act_quant_group_shape,
+                column_major_scales=False,
+                use_ue8m0=False,
+            )
         return self._run_triton, (
             QuantFP8(
                 False,
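
Usage sketch for reviewers (illustrative only, not part of the diff): the snippet below shows roughly how the refactored `QuantFP8` group-quant entry point is expected to be driven. The tensor shapes and the `GroupShape(1, 128)` group shape are arbitrary, and which backend actually runs (AITER, Triton, DeepGEMM-packed, or the native fallback) depends on the platform and the flags shown in the diff above.

```python
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Dynamic 1x128 per-token-group quantization, mirroring the QuantFP8 instance
# now returned alongside self._run_aiter in the diff above.
quant_op = QuantFP8(
    False,                 # static=False -> dynamic scales
    GroupShape(1, 128),    # group of 128 along the hidden dimension
    column_major_scales=False,
    use_ue8m0=False,
)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# On ROCm with AITER linear FP8 enabled this dispatches to
# rocm_aiter_ops.group_fp8_quant; passing use_triton=True would select the
# Triton per-token-group kernel instead; otherwise the native path runs.
q_x, scales = quant_op(x)
print(q_x.dtype, q_x.shape, scales.shape)  # fp8 tensor plus per-group scales
```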