[Performance] Fused blockwise quant RMS norm (#27883)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
@@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     QuantKey,
     ScaleDesc,
+    kFp8Dynamic64Sym,
+    kFp8Dynamic128Sym,
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
     kNvfp4Quant,
     kStaticTensorScale,
 )
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    cutlass_block_fp8_supported,
+)
 from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import (
+    is_deep_gemm_e8m0_used,
+    should_use_deepgemm_for_fp8_linear_for_nk,
+)
 
 from .inductor_pass import enable_fake_mode
 from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
@@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
 }
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
+if current_platform.is_cuda():
+    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
+    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
 
 
 class FusedRMSQuantKey(NamedTuple):
@@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(
         kFp8DynamicTokenSym, True
     ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic128Sym, False
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic128Sym, True
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic64Sym, False
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic64Sym, True
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
 }
 
 
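For context, the new `rms_norm_per_block_quant` entries fuse the RMS-norm epilogue with groupwise (per-block) FP8 quantization. Below is a minimal unfused sketch of the quantization half only; `per_block_quant_ref` and its clamp floor are illustrative, not the CUDA kernel:

```python
import torch

# Reference sketch (assumption: not the actual vLLM kernel) of per-block FP8
# quantization with group_shape (1, group_size): one float32 scale per
# contiguous group of `group_size` channels in each token row.
def per_block_quant_ref(x: torch.Tensor, group_size: int = 128):
    t, h = x.shape
    assert h % group_size == 0
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    groups = x.view(t, h // group_size, group_size)
    # One scale per (token, group), derived from the group's max magnitude.
    scale = groups.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max
    q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max).to(fp8)
    return q.view(t, h), scale  # scale has shape (t, h // group_size)

x = torch.randn(4, 256)
q, s = per_block_quant_ref(x)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```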
@@ -100,6 +124,15 @@ class RMSNormQuantPattern:
         config = get_current_vllm_config()
         self.model_dtype = config.model_config.dtype if config.model_config else None
 
+        # groupwise FP8 linear uses col-major scales
+        # if deepgemm or cutlass is used
+        using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
+            self.model_dtype,
+            config.model_config.hf_config.intermediate_size,
+            config.model_config.hf_config.hidden_size,
+        )
+        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
+        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
 
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
 
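When the DeepGEMM path is active, `use_e8m0` additionally constrains scales to the UE8M0 format: an 8-bit exponent with no mantissa, i.e. powers of two. A hedged sketch of that rounding (assuming ceil-rounding of the exponent, the usual choice so the quantized values stay within the FP8 range the original scale guaranteed):

```python
import torch

# Sketch of the UE8M0 constraint: keep only a power-of-two scale.
# Assumption: the exponent is rounded up (ceil), not to nearest.
def to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    return torch.exp2(torch.ceil(torch.log2(scale)))

s = torch.tensor([0.3, 1.0, 2.5])
print(to_ue8m0(s))  # tensor([0.5000, 1.0000, 4.0000])
```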
@@ -108,7 +141,9 @@ class RMSNormQuantPattern:
             if not key.fused_add
             else MatcherFusedAddRMSNorm(epsilon)
         )
-        self.quant_matcher = MatcherQuantFP8(key.quant)
+        self.quant_matcher = MatcherQuantFP8(
+            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
+        )
 
 
 class RMSNormStaticQuantPattern(RMSNormQuantPattern):
@@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
         )
 
 
+class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
+    def __init__(
+        self,
+        epsilon: float,
+        quant_dtype: torch.dtype,
+        group_shape: GroupShape,
+        symmetric=True,
+    ):
+        scale = ScaleDesc(torch.float32, False, group_shape)
+        key = FusedRMSQuantKey(
+            fused_add=True,
+            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
+        )
+        self.group_shape = group_shape
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
+            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
+            result, scale = self.quant_matcher(result_rms)
+            return result, residual, scale
+
+        def replacement(
+            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
+        ):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(
+                input, transposed=self.quant_matcher.use_col_major_scales
+            )
+            at = auto_functionalized(
+                self.FUSED_OP,
+                result=result,
+                input=input,
+                weight=weight,
+                scale=scale,
+                epsilon=self.epsilon,
+                scale_ub=None,
+                residual=residual,
+                group_size=self.group_shape[1],
+                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+            )
+
+            # result, residual, scale
+            return at[1], at[3], at[2]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            self.rmsnorm_matcher.inputs(),
+            pm.fwd_only,
+            pm_pass,
+        )
+
+
+class RMSNormGroupQuantPattern(RMSNormQuantPattern):
+    def __init__(
+        self,
+        epsilon: float,
+        quant_dtype: torch.dtype,
+        group_shape: GroupShape,
+        symmetric=True,
+    ):
+        scale = ScaleDesc(torch.float32, False, group_shape)
+        key = FusedRMSQuantKey(
+            fused_add=False,
+            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
+        )
+        self.group_shape = group_shape
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(input: torch.Tensor, weight: torch.Tensor):
+            result_rms = self.rmsnorm_matcher(input, weight)
+            result, scale = self.quant_matcher(result_rms)
+            return result, scale
+
+        def replacement(input: torch.Tensor, weight: torch.Tensor):
+            # In case we're matching native rms-norm, conversions might be
+            # optimized out. We convert here just to be safe.
+            input = input.to(dtype=self.model_dtype)
+
+            result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(
+                input, transposed=self.quant_matcher.use_col_major_scales
+            )
+            at = auto_functionalized(
+                self.FUSED_OP,
+                result=result,
+                input=input,
+                weight=weight,
+                scale=scale,
+                epsilon=self.epsilon,
+                scale_ub=None,
+                residual=None,
+                group_size=self.group_shape[1],
+                is_scale_transposed=self.quant_matcher.use_col_major_scales,
+            )
+
+            # result, scale
+            return at[1], at[2]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            self.rmsnorm_matcher.inputs(),
+            pm.fwd_only,
+            pm_pass,
+        )
+
+
 class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
     def __init__(
         self,
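These two classes match the unfused graph (RMSNorm, optionally with a fused residual add, followed by per-group FP8 quantization) and rewrite it into a single `rms_norm_per_block_quant` call. A rough unfused reference of what the matched half computes, under my own naming (a sketch, not vLLM's matcher code):

```python
import torch

# Unfused reference (sketch) of the pattern being replaced:
# fused_add_rms_norm, whose output would then feed the groupwise
# quantization sketched earlier. The pass collapses both steps.
def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    x = x + residual   # the pre-norm sum is also returned as the new residual
    residual = x
    var = x.float().pow(2).mean(-1, keepdim=True)
    x = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return x, residual

x = torch.randn(4, 256, dtype=torch.bfloat16)
res = torch.randn_like(x)
w = torch.ones(256, dtype=torch.bfloat16)
normed, new_res = fused_add_rms_norm_ref(x, res, w)
print(normed.shape, new_res.shape)  # torch.Size([4, 256]) twice
```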
@@ -340,6 +489,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
         # Make sure fused add patterns are before simple rms norm,
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
+            # Fuse fused_add_rms_norm + fp8 group quant
+            FusedAddRMSNormGroupQuantPattern(
+                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+            ).register(self.patterns)
+
+            # Fuse rms_norm + fp8 group quant
+            RMSNormGroupQuantPattern(
+                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
+            ).register(self.patterns)
+
+            FusedAddRMSNormGroupQuantPattern(
+                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+            ).register(self.patterns)
+
+            # Fuse rms_norm + fp8 group quant
+            RMSNormGroupQuantPattern(
+                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+            ).register(self.patterns)
+
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
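The registrations cover `GroupShape(1, 128)` and `GroupShape(1, 64)`: one scale per token row per 128 (or 64) contiguous channels. Both epsilon values are enumerated presumably because the pattern matcher traces with concrete constants. A tiny illustration of the scale grid a group shape implies (illustrative arithmetic only):

```python
# For hidden states of shape (tokens, hidden) = (8, 1024):
#   GroupShape(1, 128) -> scales of shape (8, 1024 // 128) = (8, 8)
#   GroupShape(1, 64)  -> scales of shape (8, 1024 // 64)  = (8, 16)
tokens, hidden = 8, 1024
for group in (128, 64):
    print(group, (tokens, hidden // group))
```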
@@ -366,9 +534,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
     def uuid(self) -> Any:
         return self.hash_source(
             self,
+            RMSNormGroupQuantPattern,
             RMSNormQuantPattern,
             RMSNormStaticQuantPattern,
             RMSNormDynamicQuantPattern,
             FusedAddRMSNormStaticQuantPattern,
             FusedAddRMSNormDynamicQuantPattern,
+            FusedAddRMSNormGroupQuantPattern,
         )
@@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     _normalize_quant_group_shape,
+    kFp8Dynamic64Sym,
+    kFp8Dynamic128Sym,
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
@@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
+if current_platform.is_cuda():
+    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
+    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
+
 SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 
 
@@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
 
 
 class MatcherQuantFP8(MatcherCustomOp):
-    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
+    def __init__(
+        self,
+        quant_key: QuantKey,
+        enabled: bool | None = None,
+        use_col_major_scales: bool = False,
+        use_e8m0: bool = False,
+    ):
         if enabled is None:
             enabled = QuantFP8.enabled()
 
         super().__init__(enabled)
         self.quant_key = quant_key
+        self.use_col_major_scales = use_col_major_scales
+        self.use_e8m0 = use_e8m0
         assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
@@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp):
             input.shape, device=input.device, dtype=self.quant_key.dtype
         )
 
+        if self.quant_key.scale.group_shape.is_per_group():
+            assert scale is None
+            scale = self.make_scale(input, transposed=self.use_col_major_scales)
+
+            finfo = torch.finfo(self.quant_key.dtype)
+            fp8_min = finfo.min
+            fp8_max = finfo.max
+
+            _, result, scale = auto_functionalized(
+                self.QUANT_OP,
+                input=input,
+                output_q=result,
+                output_s=scale,
+                group_size=self.quant_key.scale.group_shape[1],
+                eps=1e-10,
+                fp8_min=fp8_min,
+                fp8_max=fp8_max,
+                scale_ue8m0=self.use_e8m0,
+            )
+            return result, scale
+
         if self.quant_key.scale.static:
             assert scale is not None
             _, result = auto_functionalized(
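The `fp8_min`/`fp8_max` bounds handed to the grouped quant op come straight from `torch.finfo` on the FP8 dtype; for the usual `torch.float8_e4m3fn` that is ±448, and `eps=1e-10` is the floor that keeps scales nonzero:

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)
print(finfo.min, finfo.max)  # -448.0 448.0
```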
@@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.quant_fp8(input, scale)
 
-    def make_scale(self, input: torch.Tensor):
+    def make_scale(self, input: torch.Tensor, transposed: bool = False):
         normalized_group_shape = _normalize_quant_group_shape(
             input, self.quant_key.scale.group_shape
         )
@@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp):
             input.shape[0] // normalized_group_shape[0],
             input.shape[1] // normalized_group_shape[1],
         )
+        if transposed:
+            scale_shape = tuple(reversed(scale_shape))
+            return torch.empty(
+                scale_shape, device=input.device, dtype=torch.float32
+            ).permute(-1, -2)
 
         return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
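The `transposed` branch is what "column-major scales" means concretely: allocate the scale tensor with swapped dimensions, then `permute` back, so the logical shape is unchanged but the strides run column-major. A small demonstration:

```python
import torch

tokens, n_groups = 4, 2
row_major = torch.empty(tokens, n_groups)
col_major = torch.empty(n_groups, tokens).permute(-1, -2)

print(row_major.shape, row_major.stride())  # torch.Size([4, 2]) (2, 1)
print(col_major.shape, col_major.stride())  # torch.Size([4, 2]) (1, 4)
```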