Refactor: Move CUDA graph dispatch logic earlier (#27382)
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -3740,6 +3740,31 @@ class GPUModelRunner(
|
|||||||
dp_rank = self.parallel_config.data_parallel_rank
|
dp_rank = self.parallel_config.data_parallel_rank
|
||||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
||||||
|
|
||||||
|
# filter out the valid batch descriptor
|
||||||
|
_cg_mode, batch_descriptor = (
|
||||||
|
self.cudagraph_dispatcher.dispatch(
|
||||||
|
BatchDescriptor(
|
||||||
|
num_tokens=num_tokens_after_padding,
|
||||||
|
uniform_decode=uniform_decode,
|
||||||
|
has_lora=activate_lora and self.lora_config is not None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not is_profile
|
||||||
|
else (CUDAGraphMode.NONE, None)
|
||||||
|
)
|
||||||
|
if cudagraph_runtime_mode is not None:
|
||||||
|
# we allow forcing NONE when the dispatcher disagrees to support
|
||||||
|
# warm ups for cudagraph capture
|
||||||
|
assert (
|
||||||
|
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||||
|
or cudagraph_runtime_mode == _cg_mode
|
||||||
|
), (
|
||||||
|
f"Cudagraph runtime mode mismatch at dummy_run. "
|
||||||
|
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cudagraph_runtime_mode = _cg_mode
|
||||||
|
|
||||||
attn_metadata: PerLayerAttnMetadata | None = None
|
attn_metadata: PerLayerAttnMetadata | None = None
|
||||||
|
|
||||||
# If force_attention is True, we always capture attention. Otherwise,
|
# If force_attention is True, we always capture attention. Otherwise,
|
||||||
@@ -3814,31 +3839,6 @@ class GPUModelRunner(
|
|||||||
num_tokens_after_padding, None, False
|
num_tokens_after_padding, None, False
|
||||||
)
|
)
|
||||||
|
|
||||||
# filter out the valid batch descriptor
|
|
||||||
_cg_mode, batch_descriptor = (
|
|
||||||
self.cudagraph_dispatcher.dispatch(
|
|
||||||
BatchDescriptor(
|
|
||||||
num_tokens=num_tokens_after_padding,
|
|
||||||
uniform_decode=uniform_decode,
|
|
||||||
has_lora=activate_lora and self.lora_config is not None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not is_profile
|
|
||||||
else (CUDAGraphMode.NONE, None)
|
|
||||||
)
|
|
||||||
if cudagraph_runtime_mode is not None:
|
|
||||||
# we allow forcing NONE when the dispatcher disagrees to support
|
|
||||||
# warm ups for cudagraph capture
|
|
||||||
assert (
|
|
||||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
|
||||||
or cudagraph_runtime_mode == _cg_mode
|
|
||||||
), (
|
|
||||||
f"Cudagraph runtime mode mismatch at dummy_run. "
|
|
||||||
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cudagraph_runtime_mode = _cg_mode
|
|
||||||
|
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
# Adjust values to reflect a single ubatch.
|
# Adjust values to reflect a single ubatch.
|
||||||
# TODO(sage,lucas): this is cruft that should be addressed in
|
# TODO(sage,lucas): this is cruft that should be addressed in
|
||||||
|
|||||||
Reference in New Issue
Block a user