[core] clean up cudagraph batchsize padding logic (#10996)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
171
vllm/config.py
171
vllm/config.py
@@ -2354,6 +2354,12 @@ class CompilationConfig(BaseModel):
|
||||
# not configurable, computed after init
|
||||
compile_sizes: List[int] = PrivateAttr
|
||||
capture_sizes: List[int] = PrivateAttr
|
||||
max_capture_size: int = PrivateAttr
|
||||
# optimization:
|
||||
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
|
||||
# since we know all keys are in a range [0, max_capture_size],
|
||||
# we can optimize it to List[int] for better lookup performance.
|
||||
bs_to_padded_graph_size: List[int] = PrivateAttr
|
||||
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||
@@ -2365,6 +2371,19 @@ class CompilationConfig(BaseModel):
|
||||
# Map from layer name to the attention cls
|
||||
static_forward_context: Dict[str, Any] = PrivateAttr
|
||||
|
||||
def __repr__(self) -> str:
|
||||
exclude = {
|
||||
"static_forward_context",
|
||||
"enabled_custom_ops",
|
||||
"disabled_custom_ops",
|
||||
"compilation_time",
|
||||
"bs_to_padded_graph_size",
|
||||
"pass_config",
|
||||
}
|
||||
return self.model_dump_json(exclude=exclude, exclude_unset=True)
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
||||
"""Parse the CLI value for the compilation config."""
|
||||
@@ -2450,18 +2469,22 @@ class CompilationConfig(BaseModel):
|
||||
|
||||
# sort to make sure cudagraph capture sizes are in descending order
|
||||
self.capture_sizes.sort(reverse=True)
|
||||
self.max_capture_size = self.capture_sizes[
|
||||
0] if self.capture_sizes else 0
|
||||
|
||||
|
||||
_BATCH_SIZE_ALIGNMENT = 8
|
||||
# all the token sizes that **can** be captured by cudagraph.
|
||||
# they can be arbitrarily large.
|
||||
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
|
||||
# the actual sizes to capture will be determined by the model,
|
||||
# depending on the model's max_num_seqs.
|
||||
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
||||
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
|
||||
]
|
||||
# pre-compute the mapping from batch size to padded graph size
|
||||
self.bs_to_padded_graph_size = [
|
||||
0 for i in range(self.max_capture_size + 1)
|
||||
]
|
||||
for end, start in zip(self.capture_sizes,
|
||||
self.capture_sizes[1:] + [0]):
|
||||
for bs in range(start, end):
|
||||
if bs == start:
|
||||
self.bs_to_padded_graph_size[bs] = start
|
||||
else:
|
||||
self.bs_to_padded_graph_size[bs] = end
|
||||
self.bs_to_padded_graph_size[
|
||||
self.max_capture_size] = self.max_capture_size
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -2491,40 +2514,12 @@ class VllmConfig:
|
||||
init=True) # type: ignore
|
||||
instance_id: str = ""
|
||||
|
||||
@staticmethod
|
||||
def get_graph_batch_size(batch_size: int) -> int:
|
||||
"""Returns the padded batch size given actual batch size.
|
||||
|
||||
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
|
||||
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
|
||||
"""
|
||||
if batch_size <= 2:
|
||||
return batch_size
|
||||
elif batch_size <= 4:
|
||||
return 4
|
||||
else:
|
||||
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
|
||||
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
|
||||
|
||||
@staticmethod
|
||||
def get_max_graph_batch_size(max_num_seqs: int) -> int:
|
||||
"""
|
||||
max_num_seqs: Maximum number of sequences in a batch.
|
||||
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
|
||||
|
||||
pad the max_num_seqs if necessary by calling get_graph_batch_size,
|
||||
which will deal with some edge cases like 1, 2, 4.
|
||||
|
||||
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
|
||||
size. if not, it means the padded size is larger than the largest size
|
||||
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
|
||||
_BATCH_SIZES_TO_CAPTURE.
|
||||
"""
|
||||
padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
|
||||
if padded_size in _BATCH_SIZES_TO_CAPTURE:
|
||||
return padded_size
|
||||
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
return _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
def pad_for_cudagraph(self, batch_size: int) -> int:
|
||||
# if batch_size > self.compilation_config.max_capture_size,
|
||||
# it should raise an IndexError.
|
||||
# the caller should make sure the batch_size is within the range,
|
||||
# i.e., batch_size <= self.compilation_config.max_capture_size
|
||||
return self.compilation_config.bs_to_padded_graph_size[batch_size]
|
||||
|
||||
@staticmethod
|
||||
def _get_quantization_config(
|
||||
@@ -2618,27 +2613,7 @@ class VllmConfig:
|
||||
self.compilation_config.pass_config.enable_reshape = False
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
max_batchsize_to_capture = 0
|
||||
if self.scheduler_config is not None and \
|
||||
self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
max_batchsize_to_capture = \
|
||||
self.get_max_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs)
|
||||
batch_size_capture_list = [
|
||||
size for size in _BATCH_SIZES_TO_CAPTURE
|
||||
if size <= max_batchsize_to_capture
|
||||
]
|
||||
else:
|
||||
batch_size_capture_list = []
|
||||
if self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
batch_size_capture_list = [1, 2, 4
|
||||
] + [i for i in range(8, 513, 8)]
|
||||
|
||||
self.compilation_config.init_with_cudagraph_sizes(
|
||||
batch_size_capture_list)
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
if self.cache_config is not None and \
|
||||
self.cache_config.cpu_offload_gb > 0 and \
|
||||
@@ -2659,6 +2634,70 @@ class VllmConfig:
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
def _set_cudagraph_sizes(self):
|
||||
"""
|
||||
cudagraph batchsize padding logic:
|
||||
|
||||
`[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
|
||||
batch sizes that cudagraph will capture.
|
||||
|
||||
Depending on the engine's configuration of `max_num_seqs`, the
|
||||
candidate batch sizes to capture cudagraph will shrink to the subset
|
||||
which just cover the range of `[1, max_num_seqs]`. In the common case,
|
||||
`max_num_seqs` is 256, and the cudagraph batch sizes will be
|
||||
`[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.
|
||||
|
||||
However, if users specify the cudagraph capture sizes through
|
||||
compilation config, we will use the specified sizes instead.
|
||||
|
||||
In the end, `vllm_config.compilation_config.capture_sizes` will be the
|
||||
final sizes to capture cudagraph (in descending order).
|
||||
|
||||
During runtime, if batchsize is larger than
|
||||
`vllm_config.compilation_config.capture_sizes`,
|
||||
no cudagraph will be used.
|
||||
If the batch size is no larger than
|
||||
`vllm_config.compilation_config.capture_sizes`,
|
||||
we can quickly find the padded graph size for a given batch size by
|
||||
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
|
||||
"""
|
||||
|
||||
# calculate the default `batch_size_capture_list`
|
||||
if not envs.VLLM_USE_V1:
|
||||
batch_size_capture_list = []
|
||||
max_batchsize_to_capture = 0
|
||||
if self.scheduler_config is not None and \
|
||||
self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
|
||||
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
|
||||
# find the minimum size that is larger than max_num_seqs,
|
||||
# which then becomes the max_batchsize_to_capture
|
||||
larger_sizes = [
|
||||
x for x in possible_sizes
|
||||
if x >= self.scheduler_config.max_num_seqs
|
||||
]
|
||||
if larger_sizes:
|
||||
max_batchsize_to_capture = larger_sizes[0]
|
||||
else:
|
||||
max_batchsize_to_capture = possible_sizes[-1]
|
||||
|
||||
# filter out the sizes that are
|
||||
# larger than max_batchsize_to_capture
|
||||
batch_size_capture_list = [
|
||||
size for size in possible_sizes
|
||||
if size <= max_batchsize_to_capture
|
||||
]
|
||||
else:
|
||||
batch_size_capture_list = []
|
||||
if self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
batch_size_capture_list = [1, 2, 4
|
||||
] + [i for i in range(8, 513, 8)]
|
||||
|
||||
self.compilation_config.init_with_cudagraph_sizes(
|
||||
batch_size_capture_list)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r},"
|
||||
|
||||
Reference in New Issue
Block a user