[Bugfix] Fix potential EAGLE spec decode segfault during graph capture (#32818)
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -1222,10 +1222,14 @@ class SpecDecodeBaseProposer:
 num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
     num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
 )
-cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
-    num_tokens_dp_padded
-)
-num_input_tokens = batch_desc.num_tokens
+if use_cudagraphs:
+    cudagraph_runtime_mode, batch_desc = (
+        self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
+    )
+    num_input_tokens = batch_desc.num_tokens
+else:
+    cudagraph_runtime_mode = CUDAGraphMode.NONE
+    num_input_tokens = num_tokens_dp_padded
 if num_tokens_across_dp is not None:
     num_tokens_across_dp[self.dp_rank] = num_input_tokens
 
Reference in New Issue
Block a user