[Kernel] Add KernelConfig flag to enable/disable FlashInfer autotune (#34006)
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
edb359cce4
commit
dd6a6e1190
@@ -11,6 +11,7 @@ from vllm.config.compilation import (
|
||||
)
|
||||
from vllm.config.device import DeviceConfig
|
||||
from vllm.config.ec_transfer import ECTransferConfig
|
||||
from vllm.config.kernel import KernelConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
@@ -65,6 +66,8 @@ __all__ = [
|
||||
"DeviceConfig",
|
||||
# From vllm.config.ec_transfer
|
||||
"ECTransferConfig",
|
||||
# From vllm.config.kernel
|
||||
"KernelConfig",
|
||||
# From vllm.config.kv_events
|
||||
"KVEventsConfig",
|
||||
# From vllm.config.kv_transfer
|
||||
|
||||
44
vllm/config/kernel.py
Normal file
44
vllm/config/kernel.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
|
||||
@config
|
||||
class KernelConfig:
|
||||
"""Configuration for kernel selection and warmup behavior."""
|
||||
|
||||
enable_flashinfer_autotune: bool = Field(default=None)
|
||||
"""If True, run FlashInfer autotuning during kernel warmup."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("enable_flashinfer_autotune", mode="wrap")
|
||||
@classmethod
|
||||
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||
"""Skip validation if the value is `None` when initialization is delayed."""
|
||||
if value is None:
|
||||
return value
|
||||
return handler(value)
|
||||
@@ -30,6 +30,7 @@ from .cache import CacheConfig
|
||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
from .device import DeviceConfig
|
||||
from .ec_transfer import ECTransferConfig
|
||||
from .kernel import KernelConfig
|
||||
from .kv_events import KVEventsConfig
|
||||
from .kv_transfer import KVTransferConfig
|
||||
from .load import LoadConfig
|
||||
@@ -129,6 +130,9 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
"use_inductor_graph_partition": False,
|
||||
},
|
||||
"kernel_config": {
|
||||
"enable_flashinfer_autotune": False,
|
||||
},
|
||||
}
|
||||
OPTIMIZATION_LEVEL_01 = {
|
||||
"compilation_config": {
|
||||
@@ -145,6 +149,9 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
},
|
||||
"kernel_config": {
|
||||
"enable_flashinfer_autotune": True,
|
||||
},
|
||||
}
|
||||
OPTIMIZATION_LEVEL_02 = {
|
||||
"compilation_config": {
|
||||
@@ -161,6 +168,9 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
},
|
||||
"kernel_config": {
|
||||
"enable_flashinfer_autotune": True,
|
||||
},
|
||||
}
|
||||
OPTIMIZATION_LEVEL_03 = {
|
||||
"compilation_config": {
|
||||
@@ -177,6 +187,9 @@ OPTIMIZATION_LEVEL_03 = {
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
},
|
||||
"kernel_config": {
|
||||
"enable_flashinfer_autotune": True,
|
||||
},
|
||||
}
|
||||
|
||||
OPTIMIZATION_LEVEL_TO_CONFIG = {
|
||||
@@ -211,6 +224,8 @@ class VllmConfig:
|
||||
"""Load configuration."""
|
||||
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
|
||||
"""Attention configuration."""
|
||||
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
|
||||
"""Kernel configuration."""
|
||||
lora_config: LoRAConfig | None = None
|
||||
"""LoRA configuration."""
|
||||
speculative_config: SpeculativeConfig | None = None
|
||||
@@ -756,6 +771,11 @@ class VllmConfig:
|
||||
|
||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||
self._apply_optimization_level_defaults(default_config)
|
||||
if self.kernel_config.enable_flashinfer_autotune is None:
|
||||
raise ValueError(
|
||||
"KernelConfig.enable_flashinfer_autotune must be set after applying "
|
||||
"optimization level defaults."
|
||||
)
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.config import (
|
||||
DeviceConfig,
|
||||
ECTransferConfig,
|
||||
EPLBConfig,
|
||||
KernelConfig,
|
||||
KVEventsConfig,
|
||||
KVTransferConfig,
|
||||
LoadConfig,
|
||||
@@ -536,6 +537,10 @@ class EngineArgs:
|
||||
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
||||
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
||||
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
|
||||
kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
|
||||
enable_flashinfer_autotune: bool = get_field(
|
||||
KernelConfig, "enable_flashinfer_autotune"
|
||||
)
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
@@ -595,6 +600,8 @@ class EngineArgs:
|
||||
self.compilation_config = CompilationConfig(**self.compilation_config)
|
||||
if isinstance(self.attention_config, dict):
|
||||
self.attention_config = AttentionConfig(**self.attention_config)
|
||||
if isinstance(self.kernel_config, dict):
|
||||
self.kernel_config = KernelConfig(**self.kernel_config)
|
||||
if isinstance(self.eplb_config, dict):
|
||||
self.eplb_config = EPLBConfig(**self.eplb_config)
|
||||
if isinstance(self.weight_transfer_config, dict):
|
||||
@@ -1163,6 +1170,17 @@ class EngineArgs:
|
||||
**compilation_kwargs["max_cudagraph_capture_size"],
|
||||
)
|
||||
|
||||
# Kernel arguments
|
||||
kernel_kwargs = get_kwargs(KernelConfig)
|
||||
kernel_group = parser.add_argument_group(
|
||||
title="KernelConfig",
|
||||
description=KernelConfig.__doc__,
|
||||
)
|
||||
kernel_group.add_argument(
|
||||
"--enable-flashinfer-autotune",
|
||||
**kernel_kwargs["enable_flashinfer_autotune"],
|
||||
)
|
||||
|
||||
# vLLM arguments
|
||||
vllm_kwargs = get_kwargs(VllmConfig)
|
||||
vllm_group = parser.add_argument_group(
|
||||
@@ -1189,6 +1207,7 @@ class EngineArgs:
|
||||
vllm_group.add_argument(
|
||||
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
|
||||
)
|
||||
vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
|
||||
vllm_group.add_argument(
|
||||
"--additional-config", **vllm_kwargs["additional_config"]
|
||||
)
|
||||
@@ -1717,6 +1736,17 @@ class EngineArgs:
|
||||
else:
|
||||
attention_config.backend = self.attention_backend
|
||||
|
||||
# Kernel config overrides
|
||||
kernel_config = copy.deepcopy(self.kernel_config)
|
||||
if self.enable_flashinfer_autotune is not None:
|
||||
if kernel_config.enable_flashinfer_autotune is not None:
|
||||
raise ValueError(
|
||||
"enable_flashinfer_autotune and "
|
||||
"kernel_config.enable_flashinfer_autotune "
|
||||
"are mutually exclusive"
|
||||
)
|
||||
kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
# Pass reasoning_parser into StructuredOutputsConfig
|
||||
@@ -1767,6 +1797,7 @@ class EngineArgs:
|
||||
device_config=device_config,
|
||||
load_config=load_config,
|
||||
attention_config=attention_config,
|
||||
kernel_config=kernel_config,
|
||||
lora_config=lora_config,
|
||||
speculative_config=speculative_config,
|
||||
structured_outputs_config=self.structured_outputs_config,
|
||||
|
||||
@@ -36,8 +36,13 @@ def kernel_warmup(worker: "Worker"):
|
||||
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
||||
deep_gemm_warmup(model, max_tokens)
|
||||
|
||||
enable_flashinfer_autotune = (
|
||||
worker.vllm_config.kernel_config.enable_flashinfer_autotune
|
||||
)
|
||||
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||
if has_flashinfer() and current_platform.has_device_capability(90):
|
||||
if enable_flashinfer_autotune is False:
|
||||
logger.info("Skipping FlashInfer autotune because it is disabled.")
|
||||
elif has_flashinfer() and current_platform.has_device_capability(90):
|
||||
flashinfer_autotune(worker.model_runner)
|
||||
|
||||
# FlashInfer attention warmup
|
||||
|
||||
Reference in New Issue
Block a user