Make microbatch optimization (DBO) work with general models (#37926)
Signed-off-by: Junhao Li <junhao@ubicloud.com>
This commit is contained in:
@@ -389,16 +389,20 @@ class UBatchWrapper:
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
):
|
||||
sliced_input_ids = input_ids[tokens_slice]
|
||||
sliced_input_ids = input_ids[tokens_slice] if input_ids is not None else None
|
||||
# if we are using mrope. Mrope adds an additional dimension to the
|
||||
# positions tensor
|
||||
if positions.ndim == 2:
|
||||
sliced_positions = positions[:, tokens_slice]
|
||||
else:
|
||||
sliced_positions = positions[tokens_slice]
|
||||
sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None
|
||||
sliced_inputs_embeds = (
|
||||
inputs_embeds[tokens_slice] if inputs_embeds is not None else None
|
||||
)
|
||||
sliced_intermediate_tensors = (
|
||||
intermediate_tensors[tokens_slice] if intermediate_tensors else None
|
||||
intermediate_tensors[tokens_slice]
|
||||
if intermediate_tensors is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -478,7 +482,7 @@ class UBatchWrapper:
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
return self._capture_ubatches(ubatch_metadata, self.runnable)
|
||||
elif (
|
||||
num_tokens in self.cudagraphs
|
||||
and cudagraph_runtime_mode is CUDAGraphMode.FULL
|
||||
@@ -504,4 +508,4 @@ class UBatchWrapper:
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
return self._run_ubatches(ubatch_metadata, self.runnable)
|
||||
|
||||
Reference in New Issue
Block a user