[Kernel] Add KernelConfig flag to enable/disable FlashInfer autotune (#34006)

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Mohammad Miadh Angkad
Date: 2026-02-07 21:24:44 +08:00
Committed by: GitHub
Parent: edb359cce4
Commit: dd6a6e1190

5 changed files with 104 additions and 1 deletion


@@ -11,6 +11,7 @@ from vllm.config.compilation import (
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kernel import KernelConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
@@ -65,6 +66,8 @@ __all__ = [
"DeviceConfig",
# From vllm.config.ec_transfer
"ECTransferConfig",
# From vllm.config.kernel
"KernelConfig",
# From vllm.config.kv_events
"KVEventsConfig",
# From vllm.config.kv_transfer

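With the re-export above in place, downstream code can pull the new config straight from the package root. A two-line sanity check (nothing here beyond what the diff exports):

from vllm.config import KernelConfig

cfg = KernelConfig()
print(cfg.enable_flashinfer_autotune)  # None until optimization-level defaults run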
vllm/config/kernel.py (new file, 44 lines added)

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import Any

from pydantic import Field, field_validator

from vllm.config.utils import config
from vllm.utils.hashing import safe_hash


@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    enable_flashinfer_autotune: bool = Field(default=None)
    """If True, run FlashInfer autotuning during kernel warmup."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialization is delayed."""
        if value is None:
            return value
        return handler(value)

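The delayed-initialization trick above is worth a standalone sanity check. A minimal sketch using plain pydantic's BaseModel (assuming vLLM's @config decorator provides equivalent pydantic validation): None slips through the wrap validator untouched, while any explicit value still goes through normal bool coercion.

from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field, field_validator


class KernelConfigSketch(BaseModel):
    # Annotated as bool, but defaulting to None so "unset" can be detected
    # and filled in later by the optimization-level defaults.
    enable_flashinfer_autotune: bool = Field(default=None)

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        # None means "not resolved yet": bypass bool validation entirely.
        if value is None:
            return value
        return handler(value)


print(KernelConfigSketch().enable_flashinfer_autotune)  # None
print(KernelConfigSketch(enable_flashinfer_autotune=1).enable_flashinfer_autotune)  # True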

@@ -30,6 +30,7 @@ from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kernel import KernelConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
@@ -129,6 +130,9 @@ OPTIMIZATION_LEVEL_00 = {
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": False,
},
}
OPTIMIZATION_LEVEL_01 = {
"compilation_config": {
@@ -145,6 +149,9 @@ OPTIMIZATION_LEVEL_01 = {
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": True,
},
}
OPTIMIZATION_LEVEL_02 = {
"compilation_config": {
@@ -161,6 +168,9 @@ OPTIMIZATION_LEVEL_02 = {
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": True,
},
}
OPTIMIZATION_LEVEL_03 = {
"compilation_config": {
@@ -177,6 +187,9 @@ OPTIMIZATION_LEVEL_03 = {
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": True,
},
}
OPTIMIZATION_LEVEL_TO_CONFIG = {
@@ -211,6 +224,8 @@ class VllmConfig:
"""Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration."""
lora_config: LoRAConfig | None = None
"""LoRA configuration."""
speculative_config: SpeculativeConfig | None = None
@@ -756,6 +771,11 @@ class VllmConfig:
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)

        if self.kernel_config.enable_flashinfer_autotune is None:
            raise ValueError(
                "KernelConfig.enable_flashinfer_autotune must be set after applying "
                "optimization level defaults."
            )

        if (
            self.compilation_config.cudagraph_mode.requires_piecewise_compilation()

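The net effect of the four blocks above, as a minimal standalone sketch (the lookup table and resolve helper are illustrative, not part of the diff): -O0 leaves autotune off, -O1 through -O3 turn it on, and an explicitly set value survives the defaults pass. After the pass the field can no longer be None, which is exactly what the new __post_init__ guard enforces.

# Hypothetical helper mirroring _apply_optimization_level_defaults for
# this one field; the table matches OPTIMIZATION_LEVEL_00..03 above.
OPT_LEVEL_AUTOTUNE_DEFAULT = {0: False, 1: True, 2: True, 3: True}


def resolve_autotune(user_value: bool | None, optimization_level: int) -> bool:
    # Only unset (None) fields pick up the level default; explicit
    # user choices are never overridden.
    if user_value is not None:
        return user_value
    return OPT_LEVEL_AUTOTUNE_DEFAULT[optimization_level]


assert resolve_autotune(None, 0) is False   # -O0: autotune skipped
assert resolve_autotune(None, 2) is True    # -O2: autotune on
assert resolve_autotune(False, 3) is False  # explicit override wins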

@@ -40,6 +40,7 @@ from vllm.config import (
    DeviceConfig,
    ECTransferConfig,
    EPLBConfig,
    KernelConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
@@ -536,6 +537,10 @@ class EngineArgs:
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )

    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -595,6 +600,8 @@ class EngineArgs:
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
@@ -1163,6 +1170,17 @@ class EngineArgs:
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

        # Kernel arguments
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        kernel_group.add_argument(
            "--enable-flashinfer-autotune",
            **kernel_kwargs["enable_flashinfer_autotune"],
        )

        # vLLM arguments
        vllm_kwargs = get_kwargs(VllmConfig)
        vllm_group = parser.add_argument_group(
@@ -1189,6 +1207,7 @@ class EngineArgs:
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
        vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
@@ -1717,6 +1736,17 @@ class EngineArgs:
        else:
            attention_config.backend = self.attention_backend

        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune

        load_config = self.create_load_config()

        # Pass reasoning_parser into StructuredOutputsConfig
@@ -1767,6 +1797,7 @@ class EngineArgs:
            device_config=device_config,
            load_config=load_config,
            attention_config=attention_config,
            kernel_config=kernel_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            structured_outputs_config=self.structured_outputs_config,

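This wiring gives the flag two CLI spellings, the dedicated --enable-flashinfer-autotune and the nested --kernel-config '{"enable_flashinfer_autotune": ...}', and the override block above rejects their combination. A minimal standalone sketch of that precedence rule (the helper name is hypothetical):

def merge_autotune_flag(
    cli_flag: bool | None, config_value: bool | None
) -> bool | None:
    # Mirrors the deepcopy-then-override step in create_engine_config:
    # the dedicated flag wins, but only if the nested config left it unset.
    if cli_flag is not None:
        if config_value is not None:
            raise ValueError(
                "enable_flashinfer_autotune and "
                "kernel_config.enable_flashinfer_autotune are mutually exclusive"
            )
        return cli_flag
    return config_value


assert merge_autotune_flag(True, None) is True    # dedicated flag only
assert merge_autotune_flag(None, False) is False  # nested config only
assert merge_autotune_flag(None, None) is None    # resolved later by -O defaults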

@@ -36,8 +36,13 @@ def kernel_warmup(worker: "Worker"):
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)

    enable_flashinfer_autotune = (
        worker.vllm_config.kernel_config.enable_flashinfer_autotune
    )
    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    if enable_flashinfer_autotune is False:
        logger.info("Skipping FlashInfer autotune because it is disabled.")
    elif has_flashinfer() and current_platform.has_device_capability(90):
        flashinfer_autotune(worker.model_runner)

    # FlashInfer attention warmup
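Condensed, the new gate reduces to the predicate below (a sketch; the import paths are assumed from vLLM's current layout, and None cannot reach this point because VllmConfig.__post_init__ rejects an unresolved flag):

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer


def should_run_flashinfer_autotune(vllm_config) -> bool:
    flag = vllm_config.kernel_config.enable_flashinfer_autotune
    if flag is False:
        # Explicitly disabled, e.g. by -O0 defaults or a user override.
        return False
    # Otherwise keep the existing hardware gate: FlashInfer installed
    # and SM 9.0+ (Hopper or Blackwell).
    return has_flashinfer() and current_platform.has_device_capability(90)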