Make microbatch optimization (DBO) work with general models (#37926)

Signed-off-by: Junhao Li <junhao@ubicloud.com>
This commit is contained in:
Junhao
2026-03-24 17:40:08 -04:00
committed by GitHub
parent 0f0e03890e
commit b73b5b0629

View File

@@ -389,16 +389,20 @@ class UBatchWrapper:
inputs_embeds,
intermediate_tensors,
):
sliced_input_ids = input_ids[tokens_slice]
sliced_input_ids = input_ids[tokens_slice] if input_ids is not None else None
# if we are using mrope. Mrope adds an additional dimension to the
# positions tensor
if positions.ndim == 2:
sliced_positions = positions[:, tokens_slice]
else:
sliced_positions = positions[tokens_slice]
sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None
sliced_inputs_embeds = (
inputs_embeds[tokens_slice] if inputs_embeds is not None else None
)
sliced_intermediate_tensors = (
intermediate_tensors[tokens_slice] if intermediate_tensors else None
intermediate_tensors[tokens_slice]
if intermediate_tensors is not None
else None
)
return (
@@ -478,7 +482,7 @@ class UBatchWrapper:
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model)
return self._capture_ubatches(ubatch_metadata, self.runnable)
elif (
num_tokens in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL
@@ -504,4 +508,4 @@ class UBatchWrapper:
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)
return self._run_ubatches(ubatch_metadata, self.runnable)