Files
vllm/vllm/config/kernel.py
2026-02-25 17:44:44 -08:00

77 lines
2.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Literal
from pydantic import Field, field_validator
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
# Closed set of backend names accepted by `KernelConfig.moe_backend`.
# All names are lowercase and use underscores; `KernelConfig` normalizes
# string input (lowercasing, "-" -> "_") before it is checked against
# this Literal.
MoEBackend = Literal[
    "auto",
    "triton",
    "deep_gemm",
    "cutlass",
    "flashinfer_trtllm",
    "flashinfer_cutlass",
    "flashinfer_cutedsl",
    "marlin",
    "aiter",
]
@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    # NOTE: annotated as `bool` but defaults to `None`. The wrap validator
    # `_skip_none_validation` below lets `None` pass through unvalidated so
    # the field may stay unset while initialization is delayed.
    enable_flashinfer_autotune: bool = Field(default=None)
    """If True, run FlashInfer autotuning during kernel warmup."""
    moe_backend: MoEBackend = "auto"
    """Backend for MoE expert computation kernels. Available options:
    - "auto": Automatically select the best backend based on model and hardware\n
    - "triton": Use Triton-based fused MoE kernels\n
    - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
    - "cutlass": Use vLLM CUTLASS kernels\n
    - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
    - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
    - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
    - "marlin": Use Marlin kernels (weight-only quantization)\n
    - "aiter": Use AMD AITer kernels (ROCm only)"""

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Let `None` bypass validation; delegate anything else to `handler`.

        `None` marks a value whose initialization is delayed, so it is
        returned untouched instead of being validated.
        """
        return None if value is None else handler(value)

    @field_validator("moe_backend", mode="before")
    @classmethod
    def _normalize_moe_backend(cls, value: Any) -> Any:
        """Canonicalize a backend name before it is validated.

        Strings are lowercased and every "-" becomes "_" (for example
        "Deep-GEMM" -> "deep_gemm"). Non-string values are returned
        unchanged.
        """
        if not isinstance(value, str):
            return value
        lowered = value.lower()
        return lowered.replace("-", "_")

    def compute_hash(self) -> str:
        """
        Return a hash over the settings of this config that affect the
        structure of the computation graph, i.e. the part running from
        input ids/embeddings to the final hidden states (anything before
        the former or after the latter is excluded).

        WARNING: Whenever a new field is added to this config, include it
        in the factors list below if it affects the computation graph.
        """
        # No field of this config affects the computation graph, so the
        # list of factors is intentionally empty.
        graph_factors: list[Any] = []
        encoded = str(graph_factors).encode()
        return safe_hash(encoded, usedforsecurity=False).hexdigest()