[ROCm] Aiter Quant Kernels (#25552)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2025-12-09 22:27:37 +08:00
committed by GitHub
parent 1166c31cc7
commit ee14644ba9
3 changed files with 123 additions and 2 deletions

View File

@@ -381,6 +381,8 @@ class RocmPlatform(Platform):
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
@@ -400,8 +402,6 @@ class RocmPlatform(Platform):
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
@@ -415,6 +415,9 @@ class RocmPlatform(Platform):
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: