Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Callable
|
|
from typing import Any, Literal
|
|
|
|
from pydantic import Field, field_validator
|
|
|
|
from vllm.config.utils import config
|
|
from vllm.utils.hashing import safe_hash
|
|
|
|
# Closed set of backend names accepted by `KernelConfig.moe_backend`.
# Names use lowercase with underscores.
MoEBackend = Literal[
    "auto",
    "triton",
    "deep_gemm",
    "cutlass",
    "flashinfer_trtllm",
    "flashinfer_cutlass",
    "flashinfer_cutedsl",
    "marlin",
    "aiter",
]
|
|
|
|
|
|
@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    # Annotated as `bool` but defaulting to None on purpose: the wrap
    # validator `_skip_none_validation` below lets None pass through
    # unvalidated for the case where initialization is delayed.
    enable_flashinfer_autotune: bool = Field(default=None)
    """If True, run FlashInfer autotuning during kernel warmup."""

    moe_backend: MoEBackend = "auto"
    """Backend for MoE expert computation kernels. Available options:

    - "auto": Automatically select the best backend based on model and hardware\n
    - "triton": Use Triton-based fused MoE kernels\n
    - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
    - "cutlass": Use vLLM CUTLASS kernels\n
    - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
    - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
    - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
    - "marlin": Use Marlin kernels (weight-only quantization)\n
    - "aiter": Use AMD AITer kernels (ROCm only)"""

    @field_validator("moe_backend", mode="before")
    @classmethod
    def _normalize_moe_backend(cls, value: Any) -> Any:
        """Canonicalize a string backend name before it is validated.

        String inputs are lower-cased and every "-" is replaced with "_",
        so a spelling such as "FlashInfer-TRTLLM" becomes
        "flashinfer_trtllm". Non-string inputs are returned unchanged and
        left for the regular field validation to accept or reject.
        """
        if isinstance(value, str):
            return value.lower().replace("-", "_")
        return value

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        # With an empty factors list the digest is the same for every
        # instance, regardless of field values.
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable[[Any], Any]) -> Any:
        """Skip validation if the value is `None` when initialization is delayed.

        Any non-None value is passed to `handler`, i.e. the field's normal
        validation, and that result is returned.
        """
        if value is None:
            return value
        return handler(value)
|