Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -126,6 +126,10 @@ class PassConfig:
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
|
||||
# ROCm/AITER-specific fusions
|
||||
fuse_act_padding: bool = Field(default=None)
|
||||
"""Fuse the custom RMSNorm + padding ops."""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
@@ -194,6 +198,7 @@ class PassConfig:
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"fuse_act_padding",
|
||||
mode="wrap",
|
||||
)
|
||||
@classmethod
|
||||
@@ -222,12 +227,23 @@ class PassConfig:
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.fuse_act_padding:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm + padding fusion might not work"
|
||||
)
|
||||
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
|
||||
logger.warning_once(
|
||||
"QK Norm + RoPE fusion enabled but the current platform is not "
|
||||
"CUDA or ROCm. The fusion will be disabled."
|
||||
)
|
||||
self.enable_qk_norm_rope_fusion = False
|
||||
if self.fuse_act_padding and not current_platform.is_rocm():
|
||||
logger.warning_once(
|
||||
"Padding fusion enabled but the current platform is not ROCm. "
|
||||
"The fusion will be disabled."
|
||||
)
|
||||
self.fuse_act_padding = False
|
||||
|
||||
|
||||
class DynamicShapesType(str, enum.Enum):
|
||||
|
||||
Reference in New Issue
Block a user