[ROCm] Guard group quant RMS norm fusion patterns (#30239)

This commit is contained in:
Ye (Charlotte) Qi
2025-12-08 06:44:48 -08:00
committed by GitHub
parent 80433e225e
commit eb1051fb95

View File

@@ -490,6 +490,8 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA, where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)