[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-19 10:09:03 -08:00
committed by GitHub
parent fd9f124971
commit 803f37eaaa
15 changed files with 129 additions and 141 deletions

View File

@@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig)
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -189,6 +190,7 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
def __post_init__(self):
if not self.tokenizer:
@@ -868,6 +870,20 @@ class EngineArgs:
help="Override or set the pooling method in the embedding model. "
"e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
parser.add_argument('--compilation-config',
'-O',
type=CompilationConfig.from_cli,
default=None,
help='torch.compile configuration for the model.'
'When it is a number (0, 1, 2, 3), it will be '
'interpreted as the optimization level.\n'
'NOTE: level 0 is the default level without '
'any optimization. level 1 and 2 are for internal '
'testing only. level 3 is the recommended level '
'for production.\n'
'To specify the full compilation config, '
'use a JSON string.')
return parser
@classmethod
@@ -1142,6 +1158,7 @@ class EngineArgs:
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
)