[Kernel] Add KernelConfig flag to enable/disable FlashInfer autotune (#34006)
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
edb359cce4
commit
dd6a6e1190
@@ -11,6 +11,7 @@ from vllm.config.compilation import (
|
|||||||
)
|
)
|
||||||
from vllm.config.device import DeviceConfig
|
from vllm.config.device import DeviceConfig
|
||||||
from vllm.config.ec_transfer import ECTransferConfig
|
from vllm.config.ec_transfer import ECTransferConfig
|
||||||
|
from vllm.config.kernel import KernelConfig
|
||||||
from vllm.config.kv_events import KVEventsConfig
|
from vllm.config.kv_events import KVEventsConfig
|
||||||
from vllm.config.kv_transfer import KVTransferConfig
|
from vllm.config.kv_transfer import KVTransferConfig
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
@@ -65,6 +66,8 @@ __all__ = [
|
|||||||
"DeviceConfig",
|
"DeviceConfig",
|
||||||
# From vllm.config.ec_transfer
|
# From vllm.config.ec_transfer
|
||||||
"ECTransferConfig",
|
"ECTransferConfig",
|
||||||
|
# From vllm.config.kernel
|
||||||
|
"KernelConfig",
|
||||||
# From vllm.config.kv_events
|
# From vllm.config.kv_events
|
||||||
"KVEventsConfig",
|
"KVEventsConfig",
|
||||||
# From vllm.config.kv_transfer
|
# From vllm.config.kv_transfer
|
||||||
|
|||||||
44
vllm/config/kernel.py
Normal file
44
vllm/config/kernel.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import Any

from pydantic import Field, field_validator

from vllm.config.utils import config
from vllm.utils.hashing import safe_hash


@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    # FIX: the original annotated this field as plain `bool` while defaulting
    # to `None`; `None` is a legitimate "not yet set" state (filled in later
    # from the optimization-level defaults), so the honest type is
    # `bool | None`. The wrap validator below exists precisely to let the
    # `None` sentinel through.
    enable_flashinfer_autotune: bool | None = Field(default=None)
    """If True, run FlashInfer autotuning during kernel warmup. `None` means
    the value has not been set yet and is resolved from the optimization
    level defaults."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # No factors to consider: this config only toggles warmup-time
        # autotuning and does not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialization is delayed."""
        if value is None:
            return value
        return handler(value)
|
||||||
@@ -30,6 +30,7 @@ from .cache import CacheConfig
|
|||||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||||
from .device import DeviceConfig
|
from .device import DeviceConfig
|
||||||
from .ec_transfer import ECTransferConfig
|
from .ec_transfer import ECTransferConfig
|
||||||
|
from .kernel import KernelConfig
|
||||||
from .kv_events import KVEventsConfig
|
from .kv_events import KVEventsConfig
|
||||||
from .kv_transfer import KVTransferConfig
|
from .kv_transfer import KVTransferConfig
|
||||||
from .load import LoadConfig
|
from .load import LoadConfig
|
||||||
@@ -129,6 +130,9 @@ OPTIMIZATION_LEVEL_00 = {
|
|||||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||||
"use_inductor_graph_partition": False,
|
"use_inductor_graph_partition": False,
|
||||||
},
|
},
|
||||||
|
"kernel_config": {
|
||||||
|
"enable_flashinfer_autotune": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
OPTIMIZATION_LEVEL_01 = {
|
OPTIMIZATION_LEVEL_01 = {
|
||||||
"compilation_config": {
|
"compilation_config": {
|
||||||
@@ -145,6 +149,9 @@ OPTIMIZATION_LEVEL_01 = {
|
|||||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||||
"use_inductor_graph_partition": False,
|
"use_inductor_graph_partition": False,
|
||||||
},
|
},
|
||||||
|
"kernel_config": {
|
||||||
|
"enable_flashinfer_autotune": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
OPTIMIZATION_LEVEL_02 = {
|
OPTIMIZATION_LEVEL_02 = {
|
||||||
"compilation_config": {
|
"compilation_config": {
|
||||||
@@ -161,6 +168,9 @@ OPTIMIZATION_LEVEL_02 = {
|
|||||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||||
"use_inductor_graph_partition": False,
|
"use_inductor_graph_partition": False,
|
||||||
},
|
},
|
||||||
|
"kernel_config": {
|
||||||
|
"enable_flashinfer_autotune": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
OPTIMIZATION_LEVEL_03 = {
|
OPTIMIZATION_LEVEL_03 = {
|
||||||
"compilation_config": {
|
"compilation_config": {
|
||||||
@@ -177,6 +187,9 @@ OPTIMIZATION_LEVEL_03 = {
|
|||||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||||
"use_inductor_graph_partition": False,
|
"use_inductor_graph_partition": False,
|
||||||
},
|
},
|
||||||
|
"kernel_config": {
|
||||||
|
"enable_flashinfer_autotune": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
OPTIMIZATION_LEVEL_TO_CONFIG = {
|
OPTIMIZATION_LEVEL_TO_CONFIG = {
|
||||||
@@ -211,6 +224,8 @@ class VllmConfig:
|
|||||||
"""Load configuration."""
|
"""Load configuration."""
|
||||||
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
|
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
|
||||||
"""Attention configuration."""
|
"""Attention configuration."""
|
||||||
|
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
|
||||||
|
"""Kernel configuration."""
|
||||||
lora_config: LoRAConfig | None = None
|
lora_config: LoRAConfig | None = None
|
||||||
"""LoRA configuration."""
|
"""LoRA configuration."""
|
||||||
speculative_config: SpeculativeConfig | None = None
|
speculative_config: SpeculativeConfig | None = None
|
||||||
@@ -756,6 +771,11 @@ class VllmConfig:
|
|||||||
|
|
||||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||||
self._apply_optimization_level_defaults(default_config)
|
self._apply_optimization_level_defaults(default_config)
|
||||||
|
if self.kernel_config.enable_flashinfer_autotune is None:
|
||||||
|
raise ValueError(
|
||||||
|
"KernelConfig.enable_flashinfer_autotune must be set after applying "
|
||||||
|
"optimization level defaults."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from vllm.config import (
|
|||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
ECTransferConfig,
|
ECTransferConfig,
|
||||||
EPLBConfig,
|
EPLBConfig,
|
||||||
|
KernelConfig,
|
||||||
KVEventsConfig,
|
KVEventsConfig,
|
||||||
KVTransferConfig,
|
KVTransferConfig,
|
||||||
LoadConfig,
|
LoadConfig,
|
||||||
@@ -536,6 +537,10 @@ class EngineArgs:
|
|||||||
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
||||||
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
||||||
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
|
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
|
||||||
|
kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
|
||||||
|
enable_flashinfer_autotune: bool = get_field(
|
||||||
|
KernelConfig, "enable_flashinfer_autotune"
|
||||||
|
)
|
||||||
worker_cls: str = ParallelConfig.worker_cls
|
worker_cls: str = ParallelConfig.worker_cls
|
||||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||||
|
|
||||||
@@ -595,6 +600,8 @@ class EngineArgs:
|
|||||||
self.compilation_config = CompilationConfig(**self.compilation_config)
|
self.compilation_config = CompilationConfig(**self.compilation_config)
|
||||||
if isinstance(self.attention_config, dict):
|
if isinstance(self.attention_config, dict):
|
||||||
self.attention_config = AttentionConfig(**self.attention_config)
|
self.attention_config = AttentionConfig(**self.attention_config)
|
||||||
|
if isinstance(self.kernel_config, dict):
|
||||||
|
self.kernel_config = KernelConfig(**self.kernel_config)
|
||||||
if isinstance(self.eplb_config, dict):
|
if isinstance(self.eplb_config, dict):
|
||||||
self.eplb_config = EPLBConfig(**self.eplb_config)
|
self.eplb_config = EPLBConfig(**self.eplb_config)
|
||||||
if isinstance(self.weight_transfer_config, dict):
|
if isinstance(self.weight_transfer_config, dict):
|
||||||
@@ -1163,6 +1170,17 @@ class EngineArgs:
|
|||||||
**compilation_kwargs["max_cudagraph_capture_size"],
|
**compilation_kwargs["max_cudagraph_capture_size"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Kernel arguments
|
||||||
|
kernel_kwargs = get_kwargs(KernelConfig)
|
||||||
|
kernel_group = parser.add_argument_group(
|
||||||
|
title="KernelConfig",
|
||||||
|
description=KernelConfig.__doc__,
|
||||||
|
)
|
||||||
|
kernel_group.add_argument(
|
||||||
|
"--enable-flashinfer-autotune",
|
||||||
|
**kernel_kwargs["enable_flashinfer_autotune"],
|
||||||
|
)
|
||||||
|
|
||||||
# vLLM arguments
|
# vLLM arguments
|
||||||
vllm_kwargs = get_kwargs(VllmConfig)
|
vllm_kwargs = get_kwargs(VllmConfig)
|
||||||
vllm_group = parser.add_argument_group(
|
vllm_group = parser.add_argument_group(
|
||||||
@@ -1189,6 +1207,7 @@ class EngineArgs:
|
|||||||
vllm_group.add_argument(
|
vllm_group.add_argument(
|
||||||
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
|
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
|
||||||
)
|
)
|
||||||
|
vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
|
||||||
vllm_group.add_argument(
|
vllm_group.add_argument(
|
||||||
"--additional-config", **vllm_kwargs["additional_config"]
|
"--additional-config", **vllm_kwargs["additional_config"]
|
||||||
)
|
)
|
||||||
@@ -1717,6 +1736,17 @@ class EngineArgs:
|
|||||||
else:
|
else:
|
||||||
attention_config.backend = self.attention_backend
|
attention_config.backend = self.attention_backend
|
||||||
|
|
||||||
|
# Kernel config overrides
|
||||||
|
kernel_config = copy.deepcopy(self.kernel_config)
|
||||||
|
if self.enable_flashinfer_autotune is not None:
|
||||||
|
if kernel_config.enable_flashinfer_autotune is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"enable_flashinfer_autotune and "
|
||||||
|
"kernel_config.enable_flashinfer_autotune "
|
||||||
|
"are mutually exclusive"
|
||||||
|
)
|
||||||
|
kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
|
||||||
|
|
||||||
load_config = self.create_load_config()
|
load_config = self.create_load_config()
|
||||||
|
|
||||||
# Pass reasoning_parser into StructuredOutputsConfig
|
# Pass reasoning_parser into StructuredOutputsConfig
|
||||||
@@ -1767,6 +1797,7 @@ class EngineArgs:
|
|||||||
device_config=device_config,
|
device_config=device_config,
|
||||||
load_config=load_config,
|
load_config=load_config,
|
||||||
attention_config=attention_config,
|
attention_config=attention_config,
|
||||||
|
kernel_config=kernel_config,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
structured_outputs_config=self.structured_outputs_config,
|
structured_outputs_config=self.structured_outputs_config,
|
||||||
|
|||||||
@@ -36,8 +36,13 @@ def kernel_warmup(worker: "Worker"):
|
|||||||
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
||||||
deep_gemm_warmup(model, max_tokens)
|
deep_gemm_warmup(model, max_tokens)
|
||||||
|
|
||||||
|
enable_flashinfer_autotune = (
|
||||||
|
worker.vllm_config.kernel_config.enable_flashinfer_autotune
|
||||||
|
)
|
||||||
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||||
if has_flashinfer() and current_platform.has_device_capability(90):
|
if enable_flashinfer_autotune is False:
|
||||||
|
logger.info("Skipping FlashInfer autotune because it is disabled.")
|
||||||
|
elif has_flashinfer() and current_platform.has_device_capability(90):
|
||||||
flashinfer_autotune(worker.model_runner)
|
flashinfer_autotune(worker.model_runner)
|
||||||
|
|
||||||
# FlashInfer attention warmup
|
# FlashInfer attention warmup
|
||||||
|
|||||||
Reference in New Issue
Block a user