[torch.compile] decouple compile sizes and cudagraph sizes (#12243)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
|
||||
- use_inductor: whether to use inductor compilation.
|
||||
- False: inductor compilation is not used. graph runs in eager.
|
||||
- True: inductor compilation is used. one graph for symbolic shape
|
||||
is compiled. In addition, compile for cudagraph sizes that are
|
||||
in candidate_compile_sizes, using configurations
|
||||
in inductor_compile_config.
|
||||
- candidate_compile_sizes: sizes to compile for inductor.
|
||||
is compiled. In addition, compile for compile_sizes,
|
||||
using configurations in inductor_compile_config.
|
||||
- compile_sizes: sizes to compile for inductor. In addition
|
||||
to integers, it also supports "cudagraph_capture_sizes" to
|
||||
specify the sizes for cudagraph capture.
|
||||
- inductor_compile_config: additional configurations for inductor.
|
||||
- None: use default configurations.
|
||||
- inductor_passes: additional passes for inductor. It is a dictionary
|
||||
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
|
||||
splitting_ops: List[str] = Field(default=None) # type: ignore
|
||||
|
||||
use_inductor: bool = True
|
||||
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
|
||||
compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
|
||||
inductor_compile_config: Dict = Field(default_factory=dict)
|
||||
inductor_passes: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
|
||||
pass_config: PassConfig = Field(default_factory=PassConfig)
|
||||
|
||||
# not configurable, computed after init
|
||||
compile_sizes: List[int] = PrivateAttr
|
||||
capture_sizes: List[int] = PrivateAttr
|
||||
max_capture_size: int = PrivateAttr
|
||||
local_cache_dir: str = PrivateAttr # local cache dir for each rank
|
||||
# optimization:
|
||||
@@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
return VllmBackend(vllm_config)
|
||||
|
||||
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
|
||||
def init_with_cudagraph_sizes(self,
|
||||
cudagraph_capture_sizes: List[int]) -> None:
|
||||
"""To complete the initialization of config,
|
||||
we need to know the cudagraph sizes."""
|
||||
|
||||
if self.cudagraph_capture_sizes is None:
|
||||
self.capture_sizes = sizes_to_specialize
|
||||
self.cudagraph_capture_sizes = cudagraph_capture_sizes
|
||||
else:
|
||||
self.capture_sizes = self.cudagraph_capture_sizes
|
||||
# de-duplicate the sizes provided by the config
|
||||
self.cudagraph_capture_sizes = list(
|
||||
set(self.cudagraph_capture_sizes))
|
||||
logger.info(("cudagraph sizes specified by model runner"
|
||||
" %s is overridden by config %s"),
|
||||
sizes_to_specialize, self.cudagraph_capture_sizes)
|
||||
cudagraph_capture_sizes, self.cudagraph_capture_sizes)
|
||||
|
||||
if self.candidate_compile_sizes is None:
|
||||
self.candidate_compile_sizes = []
|
||||
self.compile_sizes = [
|
||||
x for x in self.candidate_compile_sizes if x in self.capture_sizes
|
||||
]
|
||||
ignored_sizes = [
|
||||
x for x in self.candidate_compile_sizes
|
||||
if x not in self.capture_sizes
|
||||
]
|
||||
if ignored_sizes:
|
||||
logger.warning(("candidate_compile_sizes %s are ignored "
|
||||
"because they are not cudagraph capture sizes."),
|
||||
ignored_sizes)
|
||||
computed_compile_sizes = []
|
||||
if self.compile_sizes is not None:
|
||||
# de-duplicate the sizes provided by the config
|
||||
self.compile_sizes = list(set(self.compile_sizes))
|
||||
for x in self.compile_sizes:
|
||||
if isinstance(x, str):
|
||||
assert x == "cudagraph_capture_sizes", \
|
||||
"Unrecognized size type in compile_sizes, " \
|
||||
f"expect 'cudagraph_capture_sizes', got {x}"
|
||||
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
|
||||
else:
|
||||
assert isinstance(x, int)
|
||||
computed_compile_sizes.append(x)
|
||||
self.compile_sizes = computed_compile_sizes # type: ignore
|
||||
|
||||
# sort to make sure cudagraph capture sizes are in descending order
|
||||
self.capture_sizes.sort(reverse=True)
|
||||
self.max_capture_size = self.capture_sizes[
|
||||
0] if self.capture_sizes else 0
|
||||
self.cudagraph_capture_sizes.sort(reverse=True)
|
||||
self.max_capture_size = self.cudagraph_capture_sizes[
|
||||
0] if self.cudagraph_capture_sizes else 0
|
||||
|
||||
# pre-compute the mapping from batch size to padded graph size
|
||||
self.bs_to_padded_graph_size = [
|
||||
0 for i in range(self.max_capture_size + 1)
|
||||
]
|
||||
for end, start in zip(self.capture_sizes,
|
||||
self.capture_sizes[1:] + [0]):
|
||||
for end, start in zip(self.cudagraph_capture_sizes,
|
||||
self.cudagraph_capture_sizes[1:] + [0]):
|
||||
for bs in range(start, end):
|
||||
if bs == start:
|
||||
self.bs_to_padded_graph_size[bs] = start
|
||||
@@ -3225,14 +3228,14 @@ class VllmConfig:
|
||||
However, if users specify the cudagraph capture sizes through
|
||||
compilation config, we will use the specified sizes instead.
|
||||
|
||||
In the end, `vllm_config.compilation_config.capture_sizes` will be the
|
||||
final sizes to capture cudagraph (in descending order).
|
||||
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
||||
will be the final sizes to capture cudagraph (in descending order).
|
||||
|
||||
During runtime, if the batch size is larger than
|
||||
`vllm_config.compilation_config.capture_sizes`,
|
||||
`vllm_config.compilation_config.cudagraph_capture_sizes`,
|
||||
no cudagraph will be used.
|
||||
If the batch size is no larger than
|
||||
`vllm_config.compilation_config.capture_sizes`,
|
||||
`vllm_config.compilation_config.cudagraph_capture_sizes`,
|
||||
we can quickly find the padded graph size for a given batch size by
|
||||
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user