[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
return VllmBackend(self)
|
||||
|
||||
def init_during_runtime(self):
|
||||
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
|
||||
"""To complete the initialization of config,
|
||||
we need to know the compile context, which is only available
|
||||
during the first run of the model.
|
||||
"""
|
||||
from vllm.compilation.compile_context import get_compile_context
|
||||
context = get_compile_context()
|
||||
context = copy.deepcopy(context) if context is not None else []
|
||||
sizes_to_specialize: List[int] = context
|
||||
we need to know the cudagraph sizes."""
|
||||
|
||||
if self.cudagraph_capture_sizes is None:
|
||||
self.capture_sizes = sizes_to_specialize
|
||||
else:
|
||||
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
|
||||
self.inductor_compile_sizes = []
|
||||
self.compile_sizes = self.inductor_compile_sizes
|
||||
|
||||
# sort to make sure cudagraph capture sizes are in descending order
|
||||
self.capture_sizes.sort(reverse=True)
|
||||
|
||||
|
||||
# Granularity (in sequences) to which batch sizes larger than 4 are padded.
_BATCH_SIZE_ALIGNMENT = 8
# Every token size that **can** be captured by cudagraph; the list may be
# arbitrarily large. It currently holds 1, 2, 4 followed by each multiple of
# _BATCH_SIZE_ALIGNMENT up to 8192 (8, 16, 24, 32, 40, ..., 8192).
# The sizes that are actually captured are decided per model, depending on
# the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [
    1,
    2,
    4,
    *range(
        _BATCH_SIZE_ALIGNMENT,
        1024 * _BATCH_SIZE_ALIGNMENT + 1,
        _BATCH_SIZE_ALIGNMENT,
    ),
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmConfig:
|
||||
@@ -2413,6 +2423,41 @@ class VllmConfig:
|
||||
kv_transfer_config: KVTransferConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
|
||||
@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
    """Return the padded batch size for an actual batch size.

    The padded sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ...

    Args:
        batch_size: The actual batch size.

    Returns:
        ``batch_size`` unchanged when it is at most 2, 4 when it is above
        2 and at most 4, and otherwise ``batch_size`` rounded up to the
        next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size > 4:
        # Round up to the next multiple of the alignment.
        aligned_chunks = ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                          _BATCH_SIZE_ALIGNMENT)
        return aligned_chunks * _BATCH_SIZE_ALIGNMENT
    # Small batches: sizes up to 2 are kept as-is, the rest pad to 4.
    return batch_size if batch_size <= 2 else 4
|
||||
|
||||
@staticmethod
def get_max_graph_batch_size(max_num_seqs: int) -> int:
    """Return the largest batch size to capture for ``max_num_seqs``.

    ``max_num_seqs`` is first padded with ``get_graph_batch_size``, which
    takes care of the small edge cases such as 1, 2 and 4.

    Args:
        max_num_seqs: Maximum number of sequences in a batch.

    Returns:
        The padded size when it is one of ``_BATCH_SIZES_TO_CAPTURE``
        (all the sizes that we want to capture). Otherwise the padded
        size exceeds the largest entry of that list, and that largest
        entry is returned instead.
    """
    candidate = VllmConfig.get_graph_batch_size(max_num_seqs)
    if candidate not in _BATCH_SIZES_TO_CAPTURE:
        # A padded size missing from the list is expected to lie beyond
        # its last element, so clamp to that last element.
        largest_capturable = _BATCH_SIZES_TO_CAPTURE[-1]
        assert candidate > largest_capturable
        return largest_capturable
    return candidate
|
||||
|
||||
@staticmethod
|
||||
def _get_quantization_config(
|
||||
model_config: ModelConfig,
|
||||
@@ -2496,6 +2541,28 @@ class VllmConfig:
|
||||
self.compilation_config.pass_config.enable_reshape = False
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
max_batchsize_to_capture = 0
|
||||
if self.scheduler_config is not None and \
|
||||
self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
max_batchsize_to_capture = \
|
||||
self.get_max_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs)
|
||||
batch_size_capture_list = [
|
||||
size for size in _BATCH_SIZES_TO_CAPTURE
|
||||
if size <= max_batchsize_to_capture
|
||||
]
|
||||
else:
|
||||
batch_size_capture_list = []
|
||||
if self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
batch_size_capture_list = [1, 2, 4
|
||||
] + [i for i in range(8, 513, 8)]
|
||||
|
||||
self.compilation_config.init_with_cudagraph_sizes(
|
||||
batch_size_capture_list)
|
||||
|
||||
if self.cache_config is not None and \
|
||||
self.cache_config.cpu_offload_gb > 0 and \
|
||||
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
|
||||
|
||||
Reference in New Issue
Block a user