diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py
index 42da24ccb..c39d42c75 100644
--- a/tests/kernels/core/test_layernorm.py
+++ b/tests/kernels/core/test_layernorm.py
@@ -6,7 +6,7 @@ import torch
 
 from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed
@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant(
         atol=1e-3,
         rtol=1e-3,
     )
+
+
+@torch.inference_mode()
+def test_gemma_rms_norm_mixed_input_weight_dtype(default_vllm_config) -> None:
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+
+    device = CUDA_DEVICES[0]
+    torch.set_default_device(device)
+
+    num_tokens, hidden_size = 32, 1024
+    x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
+    layer = GemmaRMSNorm(hidden_size, eps=1e-6).to(device=device)
+    layer.weight.data.normal_(mean=0.0, std=0.1)
+
+    # Gemma uses fp32 weight parameter while activations can be bf16.
+    assert layer.weight.dtype == torch.float32
+    out = layer(x)
+
+    x_fp32 = x.float()
+    weight_fp32 = layer.weight.data.float() + 1.0
+    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+    ref = (x_fp32 * torch.rsqrt(variance + layer.variance_epsilon) * weight_fp32).to(
+        x.dtype
+    )
+
+    assert out.dtype == x.dtype
+    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index 86cdd7c5e..09b9a557f 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -12,6 +12,9 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
 
 import vllm.ir.ops
+from vllm.compilation.passes.fusion.rms_quant_fusion import (
+    _rms_input_weight_dtype_match,
+)
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
@@ -320,7 +323,12 @@ class AllReduceRMSNormPattern(BasePattern):
             return allreduce[3], allreduce[1]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -459,7 +467,12 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
             return allreduce[4], allreduce[1]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -621,7 +634,12 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
             return allreduce[4], allreduce[1], allreduce[5]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py
index 0e5121c78..850e434a3 100644
--- a/vllm/compilation/passes/fusion/rms_quant_fusion.py
+++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py
@@ -38,6 +38,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
+_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
+
+
+# TODO: extend rmsnorm quant kernels to support mixed input/weight dtypes,
+# and remove this check.
+def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
+    """Prevent fusion when rms_norm input and weight dtypes differ."""
+    for node in match.nodes:
+        if node.target == _RMS_NORM_OP:
+            # rms_norm(x, weight, epsilon, variance_size)
+            x, weight = node.args[0], node.args[1]
+            if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
+                return x.meta["val"].dtype == weight.meta["val"].dtype
+    return True
+
+
 def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
     return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
@@ -186,7 +202,14 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
         ]
 
         pattern(*inputs)
-        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
+        )
 
 
 class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
@@ -249,6 +272,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
             inputs,
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -350,6 +374,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
             self.rmsnorm_matcher.inputs() + [scale],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -445,6 +470,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
             ],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -503,6 +529,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
             ],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
@@ -559,6 +586,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
             self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py
index 8471aa043..ac0c38a9e 100644
--- a/vllm/ir/ops/layernorm.py
+++ b/vllm/ir/ops/layernorm.py
@@ -16,7 +16,6 @@ def rms_norm(
     x_var = x if variance_size is None else x[..., :variance_size]
     variance = x_var.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + epsilon)
-    x = x.to(orig_dtype)
     if weight is not None:
-        x = x * weight
-    return x
+        x = x.to(weight.dtype) * weight
+    return x.to(orig_dtype)
diff --git a/vllm/kernels/aiter_ops.py b/vllm/kernels/aiter_ops.py
index 1980051dd..14c2b87fb 100644
--- a/vllm/kernels/aiter_ops.py
+++ b/vllm/kernels/aiter_ops.py
@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found()
 
 rms_no_var_16bit_only = (
     lambda x, weight, epsilon, variance_size=None: variance_size is None
-    and x.dtype
-    in (
-        torch.float16,
-        torch.bfloat16,
-    )
+    and x.dtype in (torch.float16, torch.bfloat16)
+    and (weight is None or weight.dtype == x.dtype)
 )
-"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
+"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
+and requires weight dtype to match x dtype."""
 
 
 @ir.ops.rms_norm.register_impl(
diff --git a/vllm/kernels/vllm_c.py b/vllm/kernels/vllm_c.py
index fabb36d7b..124b02e4e 100644
--- a/vllm/kernels/vllm_c.py
+++ b/vllm/kernels/vllm_c.py
@@ -11,8 +11,11 @@ current_platform.import_kernels()
 CUDA_ALIKE = current_platform.is_cuda_alike()
 """Most kernels in this file are supported on all CUDA-alike platforms."""
 
-rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
-"""vLLM kernel does not support variance_size parameter.""" +rms_no_var_size = ( + lambda x, weight, epsilon, variance_size=None: variance_size is None + and (weight is None or weight.dtype == x.dtype) +) +"""vLLM kernel requires no variance_size override and matching input/weight dtype.""" @ir.ops.rms_norm.register_impl( diff --git a/vllm/kernels/xpu_ops.py b/vllm/kernels/xpu_ops.py index 3548fb868..c680c542c 100644 --- a/vllm/kernels/xpu_ops.py +++ b/vllm/kernels/xpu_ops.py @@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool: XPU_KERNELS_SUPPORTED = is_xpu_kernels_found() """Kernels in this file are supported if vLLM XPU kernels are installed.""" -rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None +rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and ( + weight is None or weight.dtype == x.dtype +) @ir.ops.rms_norm.register_impl( diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7b222f9c4..9afc4c9c0 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -376,77 +376,32 @@ class GemmaRMSNorm(CustomOp): self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps - @staticmethod - def _forward_static_no_residual( - weight: torch.Tensor, - variance_epsilon: float, - x: torch.Tensor, - ) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward() without residual.""" - orig_dtype = x.dtype - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - x = x * (1.0 + weight.float()) - x = x.to(orig_dtype) - return x - - @staticmethod - def _forward_static_with_residual( - weight: torch.Tensor, - variance_epsilon: float, - x: torch.Tensor, - residual: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """PyTorch-native implementation equivalent to forward() with residual.""" - orig_dtype = x.dtype - x = ( - x.float() + residual.float() - if orig_dtype == torch.float16 - else x + residual - ) - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + weight.float()) - x = x.to(orig_dtype) - return x, residual - def forward_native( self, x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" - if residual is None: - return self._forward_static_no_residual( - self.weight.data, self.variance_epsilon, x - ) - else: - return self._forward_static_with_residual( - self.weight.data, self.variance_epsilon, x, residual + orig_dtype = x.dtype + weight = self.weight.data.float() + 1.0 + if residual is not None: + x = ( + x.float() + residual.float() + if orig_dtype == torch.float16 + else x + residual ) + residual = x + # ir.ops.rms_norm handles fp32 upcast internally + out = ir.ops.rms_norm(x, weight, self.variance_epsilon) + return ( + out.to(orig_dtype) if residual is None else (out.to(orig_dtype), residual) + ) def forward_cuda( self, x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if torch.compiler.is_compiling(): - return self.forward_native(x, residual) - - if not getattr(self, "_is_compiled", False): - self._forward_static_no_residual = 
torch.compile( # type: ignore - self._forward_static_no_residual - ) - self._forward_static_with_residual = torch.compile( # type: ignore - self._forward_static_with_residual - ) - self._is_compiled = True return self.forward_native(x, residual)