[W8A8 Block Linear Refactor][1/N] Keep all quantization types in the QuantFP8 class. (#33047)

Signed-off-by: maral <maralbahari.98@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Maral
Date: 2026-02-01 17:28:01 +08:00
Committed by: GitHub
Commit: b5f8c3092d (parent: 21997f45b1)
3 changed files with 53 additions and 35 deletions

@@ -1570,7 +1570,7 @@ class rocm_aiter_ops:
     def group_fp8_quant(
         input_2d: torch.Tensor,
         group_size: int = 128,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert group_size == 128, "Group size must be 128"
         return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
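For reference, the (quantized tensor, per-group scales) pair that the tightened annotation promises can be sketched in plain PyTorch. This mirrors the contract only, under the assumption of float8_e4m3fn with a finite max of 448; the actual aiter kernel is a fused GPU op with its own numerics.

# Pure-PyTorch reference for 128-wide group quantization; contract sketch only.
import torch

def ref_group_fp8_quant(
    input_2d: torch.Tensor, group_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    assert group_size == 128, "Group size must be 128"
    m, k = input_2d.shape
    assert k % group_size == 0
    groups = input_2d.reshape(m, k // group_size, group_size)
    # One scale per 128-wide group, chosen so the group max maps to FP8 max.
    scales = groups.abs().amax(dim=-1, keepdim=True).float() / 448.0
    scales = scales.clamp(min=1e-12)
    q = (groups / scales).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q.reshape(m, k), scales.squeeze(-1)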

@@ -14,6 +14,11 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     prep_scale_for_group_broadcast,
 )
 from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
+    is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
+)
 _FP8_DTYPE = current_platform.fp8_dtype()
 _FP8_MIN, _FP8_MAX = get_fp8_min_max()
@@ -59,7 +64,8 @@ class QuantFP8(CustomOp):
         self.num_token_padding = num_token_padding
         self.column_major_scales = column_major_scales
         self.tma_aligned_scales = tma_aligned_scales
-        self.use_ue8m0 = use_ue8m0
+        self.use_ue8m0 = is_deep_gemm_e8m0_used() if use_ue8m0 is None else use_ue8m0
+        self.use_deep_gemm_supported = is_deep_gemm_supported()
         self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
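Taken on its own, the new default is the usual tri-state pattern: None defers to the DeepGemm oracle, an explicit bool wins. A minimal sketch with illustrative names:

def resolve_use_ue8m0(use_ue8m0: bool | None, oracle_says_e8m0: bool) -> bool:
    # None -> follow is_deep_gemm_e8m0_used(); explicit True/False is honored.
    return oracle_says_e8m0 if use_ue8m0 is None else use_ue8m0

assert resolve_use_ue8m0(None, True) is True    # deferred to the oracle
assert resolve_use_ue8m0(False, True) is False  # e.g. the aiter path forces it off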
@@ -79,10 +85,23 @@
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        from vllm.model_executor.layers.quantization.utils import fp8_utils
+
+        if (
+            self.is_group_quant
+            and self.use_deep_gemm_supported
+            and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
+        ):
+            return fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
+                x,
+                group_size=self.group_size,
+                use_ue8m0=True,
+            )
         if self.is_group_quant and not self.static:
             assert scale is None, "Dynamic group quantization does not use scale"
-            from vllm.model_executor.layers.quantization.utils import fp8_utils

             return fp8_utils.per_token_group_quant_fp8(
                 x,
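Flattened into a pure function, the dispatch order the new forward encodes looks like this; the predicates mirror the diff, the string labels are illustrative:

def pick_native_quant_path(
    is_group_quant: bool,
    deep_gemm_supported: bool,
    scale_fmt_is_ue8m0: bool,
    static: bool,
) -> str:
    if is_group_quant and deep_gemm_supported and scale_fmt_is_ue8m0:
        return "per_token_group_quant_fp8_packed_for_deepgemm"  # UE8M0 fast path
    if is_group_quant and not static:
        return "per_token_group_quant_fp8"  # dynamic group quantization
    return "per-tensor / per-token handling"

assert pick_native_quant_path(True, True, True, False).endswith("deepgemm")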
@@ -116,25 +135,34 @@
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        use_aiter_quant = (
-            not self.is_group_quant
-            and self.use_aiter
-            and scale_ub is None
-            and x.is_contiguous()
-        )
-        use_aiter_per_tensor_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
-        )
-        use_aiter_per_token_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
-        )
+        use_triton = kwargs.get("use_triton", False)
+        if self.is_group_quant and use_triton:
+            assert scale is None, "Dynamic group quantization does not use scale"
+            return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
+
+        use_aiter_quant = self.use_aiter and scale_ub is None and x.is_contiguous()
+        use_aiter_per_tensor_quant = (
+            use_aiter_quant and self.group_shape.is_per_tensor()
+        )
+        use_aiter_per_token_quant = use_aiter_quant and self.group_shape.is_per_token()
+        use_aiter_per_group_quant = use_aiter_quant and self.group_shape.is_per_group()
+        if use_aiter_per_group_quant:
+            return rocm_aiter_ops.group_fp8_quant(x, self.group_size)
         if use_aiter_per_tensor_quant:
             return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
         if use_aiter_per_token_quant:
             return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
+        # Fallback to native implementation for group quantization.
         if self.is_group_quant:
             assert scale is None, "Dynamic group quantization does not use scale"
             return self._quantize_group_native(x)
         # Fallback to CUDA implementation
         return self.forward_cuda(x, scale, scale_ub)
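The ROCm forward now resolves paths in a fixed precedence: an explicit Triton request, then the aiter per-group, per-tensor, and per-token kernels, then the native group fallback, then the CUDA path. Flattened the same way, with illustrative labels:

def pick_rocm_quant_path(
    is_group_quant: bool,
    use_triton: bool,
    aiter_usable: bool,  # use_aiter and scale_ub is None and x.is_contiguous()
    group_shape: str,    # "per_group" | "per_tensor" | "per_token"
) -> str:
    if is_group_quant and use_triton:
        return "torch.ops.vllm.triton_per_token_group_quant_fp8"
    if aiter_usable and group_shape == "per_group":
        return "rocm_aiter_ops.group_fp8_quant"
    if aiter_usable and group_shape == "per_tensor":
        return "rocm_aiter_ops.per_tensor_quant"
    if aiter_usable and group_shape == "per_token":
        return "rocm_aiter_ops.per_token_quant"
    if is_group_quant:
        return "_quantize_group_native"
    return "forward_cuda"

assert pick_rocm_quant_path(True, False, True, "per_group").endswith("group_fp8_quant")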

@@ -33,7 +33,6 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
-    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     get_tma_aligned_size,
     is_deep_gemm_e8m0_used,
@@ -426,15 +425,8 @@ class W8A8BlockFp8LinearOp:
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
-            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
-                input_2d,
-                group_size=self.act_quant_group_shape.col,
-                use_ue8m0=True,
-            )
-        else:
-            assert self.deepgemm_input_quant_op is not None
-            q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        assert self.deepgemm_input_quant_op is not None
+        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
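This hunk inverts the responsibility: the linear op no longer branches on the DeepGemm scale format, because QuantFP8 decides internally (see the forward hunk above). A stub of the pattern; every name not in the diff is illustrative:

class StubQuantFP8:
    def __init__(self, ue8m0_packed: bool):
        self.ue8m0_packed = ue8m0_packed
    def __call__(self, x):
        # Stands in for QuantFP8.forward picking packed-UE8M0 vs plain group quant.
        path = "packed_for_deepgemm" if self.ue8m0_packed else "per_token_group"
        return x, path  # stands in for (q_input, input_scale)

deepgemm_input_quant_op = StubQuantFP8(ue8m0_packed=True)
q_input, chosen = deepgemm_input_quant_op([1.0, 2.0])
assert chosen == "packed_for_deepgemm"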
@@ -497,15 +489,8 @@ class W8A8BlockFp8LinearOp:
         if input_scale is not None:
             q_input = input_2d
-        elif use_triton:
-            q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
-                input_2d,
-                self.act_quant_group_shape.col,
-            )
         else:
-            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
-                input_2d, self.act_quant_group_shape.col
-            )
+            q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
         return gemm_a8w8_blockscale_op(
             q_input,
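The other half of the same inversion: the aiter runner forwards use_triton into the shared quant op instead of picking a kernel itself. A stubbed sketch of the simplified call site, with illustrative names:

def quantize_activation(input_2d, input_scale, input_quant_op, use_triton: bool):
    if input_scale is not None:
        return input_2d, input_scale  # statically scaled upstream, nothing to do
    # QuantFP8 routes use_triton to the Triton group-quant path internally.
    return input_quant_op(input_2d, use_triton=use_triton)

_, kernel = quantize_activation([1.0], None, lambda x, use_triton: (x, use_triton), True)
assert kernel is True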
@@ -572,7 +557,7 @@ class W8A8BlockFp8LinearOp:
             ],
             torch.Tensor,
         ],
-        QuantFP8 | None,
+        QuantFP8,
     ]:
         if use_cutlass:
             return self._run_cutlass, (
@@ -584,7 +569,12 @@ class W8A8BlockFp8LinearOp:
             )
         )
         if use_aiter_and_is_supported:
-            return self._run_aiter, None
+            return self._run_aiter, QuantFP8(
+                False,
+                self.act_quant_group_shape,
+                column_major_scales=False,
+                use_ue8m0=False,
+            )
         return self._run_triton, (
             QuantFP8(
                 False,
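With the annotation narrowed from QuantFP8 | None to QuantFP8, every runner the selector returns comes paired with a ready quant op, aiter included, so consumers can drop their None-guards. A stubbed sketch of that contract; only the pairing itself comes from this diff:

from typing import Any, Callable

def stub_select_runner() -> tuple[Callable[..., Any], Callable[..., Any]]:
    def run_fn(q, s):          # stands in for _run_aiter / _run_triton
        return q, s
    def quant_op(x):           # stands in for a concrete QuantFP8 instance
        return x, "input_scale"
    return run_fn, quant_op    # second element is never None anymore

run_fn, input_quant_op = stub_select_runner()
q_input, input_scale = input_quant_op([1.0])  # no `if input_quant_op is not None`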