[W8A8 Block Linear Refactor][1/N] Keep all quantization types in the QuantFP8 class. (#33047)
Signed-off-by: maral <maralbahari.98@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -1570,7 +1570,7 @@ class rocm_aiter_ops:
     def group_fp8_quant(
         input_2d: torch.Tensor,
         group_size: int = 128,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert group_size == 128, "Group size must be 128"
         return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
 
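The tightened annotation pins down what the op actually returns: the FP8 tensor plus one scale per 128-wide group. A reference-only sketch of that contract in plain PyTorch (this is not the AITER kernel, just the shape of its output):

```python
import torch

def group_fp8_quant_ref(
    x: torch.Tensor, group_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference per-token-group FP8 quantization: (quantized, scales)."""
    assert x.shape[-1] % group_size == 0
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    g = x.reshape(-1, x.shape[-1] // group_size, group_size).float()
    # One dynamic scale per (token, group), clamped away from zero.
    scale = (g.abs().amax(dim=-1, keepdim=True) / fp8_max).clamp(min=1e-12)
    q = (g / scale).clamp(-fp8_max, fp8_max).to(fp8).reshape(x.shape)
    return q, scale.squeeze(-1)
```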
@@ -14,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     prep_scale_for_group_broadcast,
 )
 from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
+    is_deep_gemm_e8m0_used,
+    is_deep_gemm_supported,
+)
 
 _FP8_DTYPE = current_platform.fp8_dtype()
 _FP8_MIN, _FP8_MAX = get_fp8_min_max()
@@ -59,7 +64,8 @@ class QuantFP8(CustomOp):
         self.num_token_padding = num_token_padding
         self.column_major_scales = column_major_scales
         self.tma_aligned_scales = tma_aligned_scales
-        self.use_ue8m0 = use_ue8m0
+        self.use_ue8m0 = is_deep_gemm_e8m0_used() if use_ue8m0 is None else use_ue8m0
+        self.use_deep_gemm_supported = is_deep_gemm_supported()
 
         self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
 
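The constructor now treats `use_ue8m0` as tri-state: an explicit bool always wins, and `None` defers to the DeepGEMM E8M0 oracle. A minimal sketch of that resolution rule (`resolve_ue8m0` is an illustrative name, not a vLLM helper):

```python
def resolve_ue8m0(use_ue8m0: bool | None, oracle_says_e8m0: bool) -> bool:
    # None -> "ask the platform oracle"; an explicit flag is never overridden.
    return oracle_says_e8m0 if use_ue8m0 is None else use_ue8m0

assert resolve_ue8m0(None, True) is True
assert resolve_ue8m0(False, True) is False
```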
@@ -79,10 +85,23 @@ class QuantFP8(CustomOp):
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        from vllm.model_executor.layers.quantization.utils import fp8_utils
+
+        if (
+            self.is_group_quant
+            and self.use_deep_gemm_supported
+            and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
+        ):
+            return fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
+                x,
+                group_size=self.group_size,
+                use_ue8m0=True,
+            )
+
         if self.is_group_quant and not self.static:
             assert scale is None, "Dynamic group quantization does not use scale"
-            from vllm.model_executor.layers.quantization.utils import fp8_utils
-
             return fp8_utils.per_token_group_quant_fp8(
                 x,
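For context on the UE8M0 fast path taken above: E8M0 scales are exponent-only (unsigned powers of two), so a float scale has to be snapped to a power of two before packing. A toy illustration of that snapping, rounding up to stay conservative (this mirrors the idea, not vLLM's actual packing kernel):

```python
import torch

def snap_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # E8M0 has 8 exponent bits and no mantissa: representable values are
    # powers of two. Rounding up never shrinks the scale, so quantized
    # values still fit inside the FP8 range.
    return torch.exp2(torch.ceil(torch.log2(scale.clamp(min=2.0**-127))))

assert snap_scale_to_ue8m0(torch.tensor([0.3])).item() == 0.5
```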
@@ -116,25 +135,34 @@ class QuantFP8(CustomOp):
         x: torch.Tensor,
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        use_aiter_quant = (
-            not self.is_group_quant
-            and self.use_aiter
-            and scale_ub is None
-            and x.is_contiguous()
-        )
-        use_aiter_per_tensor_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
-        )
-        use_aiter_per_token_quant = (
-            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
-        )
+        use_triton = kwargs.get("use_triton", False)
+        if self.is_group_quant and use_triton:
+            assert scale is None, "Dynamic group quantization does not use scale"
+
+            return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
+
+        use_aiter_quant = self.use_aiter and scale_ub is None and x.is_contiguous()
+        use_aiter_per_tensor_quant = (
+            use_aiter_quant and self.group_shape.is_per_tensor()
+        )
+        use_aiter_per_token_quant = use_aiter_quant and self.group_shape.is_per_token()
+
+        use_aiter_per_group_quant = use_aiter_quant and self.group_shape.is_per_group()
+
+        if use_aiter_per_group_quant:
+            return rocm_aiter_ops.group_fp8_quant(x, self.group_size)
         if use_aiter_per_tensor_quant:
             return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
         if use_aiter_per_token_quant:
             return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
 
         # Fallback to native implementation for group quantization.
         if self.is_group_quant:
             assert scale is None, "Dynamic group quantization does not use scale"
             return self._quantize_group_native(x)
 
         # Fallback to CUDA implementation
         return self.forward_cuda(x, scale, scale_ub)
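The rewritten forward resolves to one of five paths in a fixed priority order: an explicit Triton request first, then the AITER group/per-tensor/per-token kernels, then the native group fallback or plain `forward_cuda`. A paraphrase of that ladder as a standalone function (illustrative, not the vLLM code):

```python
def pick_quant_path(is_group: bool, use_triton: bool,
                    use_aiter: bool, group_shape: str) -> str:
    if is_group and use_triton:
        return "triton_group"            # explicit opt-in wins
    if use_aiter and group_shape == "per_group":
        return "aiter_group"
    if use_aiter and group_shape == "per_tensor":
        return "aiter_per_tensor"
    if use_aiter and group_shape == "per_token":
        return "aiter_per_token"
    return "native_group" if is_group else "forward_cuda"

assert pick_quant_path(True, False, True, "per_group") == "aiter_group"
assert pick_quant_path(True, True, True, "per_group") == "triton_group"
```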
@@ -33,7 +33,6 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
-    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     get_tma_aligned_size,
     is_deep_gemm_e8m0_used,
@@ -426,15 +425,8 @@ class W8A8BlockFp8LinearOp:
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
-            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
-                input_2d,
-                group_size=self.act_quant_group_shape.col,
-                use_ue8m0=True,
-            )
-        else:
-            assert self.deepgemm_input_quant_op is not None
-            q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        assert self.deepgemm_input_quant_op is not None
+        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
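This is the payoff of moving the oracle check into `QuantFP8`: the linear op no longer re-derives the scale format per call and delegates unconditionally. The pattern in miniature (illustrative names, not vLLM's API):

```python
from typing import Callable

def build_input_quant_op(use_packed_ue8m0: bool,
                         packed: Callable, default: Callable) -> Callable:
    # Decide the backend once, when the op is constructed, so every call
    # site becomes a plain delegation.
    return packed if use_packed_ue8m0 else default

op = build_input_quant_op(True, lambda x: ("packed", x), lambda x: ("plain", x))
assert op(1.0) == ("packed", 1.0)
```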
@@ -497,15 +489,8 @@ class W8A8BlockFp8LinearOp:
 
         if input_scale is not None:
             q_input = input_2d
-        elif use_triton:
-            q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
-                input_2d,
-                self.act_quant_group_shape.col,
-            )
-        else:
-            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
-                input_2d, self.act_quant_group_shape.col
-            )
+        else:
+            q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
 
         return gemm_a8w8_blockscale_op(
             q_input,
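The three-way branch collapses because the backend choice is now a per-call keyword on the shared quant op. A minimal stand-in showing the call shape (`input_quant_op` here is a toy, not `QuantFP8` itself):

```python
def input_quant_op(x, *, use_triton: bool = False):
    # Stand-ins for torch.ops.vllm.triton_per_token_group_quant_fp8 and
    # rocm_aiter_ops.group_fp8_quant, selected by the keyword flag.
    return ("triton" if use_triton else "aiter", x)

assert input_quant_op(1.0, use_triton=True)[0] == "triton"
assert input_quant_op(1.0)[0] == "aiter"
```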
@@ -572,7 +557,7 @@ class W8A8BlockFp8LinearOp:
             ],
             torch.Tensor,
         ],
-        QuantFP8 | None,
+        QuantFP8,
     ]:
         if use_cutlass:
             return self._run_cutlass, (
@@ -584,7 +569,12 @@ class W8A8BlockFp8LinearOp:
             )
         )
         if use_aiter_and_is_supported:
-            return self._run_aiter, None
+            return self._run_aiter, QuantFP8(
+                False,
+                self.act_quant_group_shape,
+                column_major_scales=False,
+                use_ue8m0=False,
+            )
         return self._run_triton, (
             QuantFP8(
                 False,
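Together, these last two hunks tighten the dispatch contract: every arm now returns a concrete `QuantFP8`, so the `QuantFP8 | None` annotation and the downstream None-guards disappear. The effect in miniature (illustrative types, not the vLLM signatures):

```python
from typing import Callable

def _quant_stub(x: float) -> float:
    return x  # stand-in for a configured QuantFP8

def dispatch(use_cutlass: bool, use_aiter: bool) -> tuple[str, Callable]:
    if use_cutlass:
        return "cutlass", _quant_stub
    if use_aiter:
        return "aiter", _quant_stub  # previously returned ("aiter", None)
    return "triton", _quant_stub

_, quant_op = dispatch(False, True)
quant_op(3.0)  # no "if quant_op is not None" guard needed any more
```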