Use the same memory for workspace13 and fused_output. (#31531)
Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
@@ -848,15 +848,17 @@ class FusedMoEModularKernel(torch.nn.Module):
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
         # Construct the entire output that can then be processed in chunks.
-        # Reuse workspace13 for the output in the non-chunked case as long
-        # as it is large enough. This will not always be the case for standard
+        # Reuse workspace13 for the output in the non-chunked case.
+        # This will not always be the case for standard
         # format experts and with experts that have empty workspaces.
-        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
-            workspace13, workspace2 = current_workspace_manager().get_simultaneous(
-                (workspace13_shape, workspace_dtype),
+        if num_chunks == 1:
+            max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
+            common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
+                ((max_shape_size,), workspace_dtype),
                 (workspace2_shape, workspace_dtype),
             )
-            fused_out = _resize_cache(workspace13, fused_out_shape)
+            workspace13 = _resize_cache(common_workspace, workspace13_shape)
+            fused_out = _resize_cache(common_workspace, fused_out_shape)
         else:
             workspace13, workspace2, fused_out = (
                 current_workspace_manager().get_simultaneous(
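The change replaces the old conditional reuse (only when workspace13 happened to be large enough) with a single flat allocation sized to the larger of workspace13 and fused_out, which is then viewed as each shape. The commit's premise is that, in the non-chunked path, the two tensors are not needed at the same time, so letting them alias the same storage saves one allocation. Below is a minimal sketch of that aliasing pattern, assuming a resize_cache helper that mirrors what vLLM's _resize_cache appears to do (slice a flat buffer, then view it with the requested shape); the shapes, dtype, and helper name here are illustrative, not taken from the commit.

    from math import prod

    import torch


    def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
        # Hypothetical stand-in for vLLM's _resize_cache: view the first
        # prod(shape) elements of a flat buffer as a tensor of that shape.
        return buf.flatten()[: prod(shape)].view(*shape)


    # Illustrative shapes and dtype only; not taken from the commit.
    workspace13_shape = (64, 128)
    fused_out_shape = (32, 512)

    # One allocation sized for the larger of the two tensors. Both views
    # alias the same storage, which the commit treats as safe in the
    # non-chunked path because the tensors are not live simultaneously.
    max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
    common_workspace = torch.empty(max_shape_size, dtype=torch.float16)

    workspace13 = resize_cache(common_workspace, workspace13_shape)
    fused_out = resize_cache(common_workspace, fused_out_shape)

    # Both views begin at the same storage address.
    assert workspace13.data_ptr() == common_workspace.data_ptr()
    assert fused_out.data_ptr() == common_workspace.data_ptr()

In the actual diff, the flat buffer comes from current_workspace_manager().get_simultaneous(), which also hands back workspace2 from the same request; the sketch above only demonstrates the workspace13/fused_out overlap, not the workspace manager itself.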