Use the same memory for workspace13 and fused_output. (#31531)

Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
Author: Andrey Khalyavin
Date: 2026-01-18 22:14:22 +03:00
Committed by: GitHub
Parent: afc3622602
Commit: ba29ab441e

@@ -848,15 +848,17 @@ class FusedMoEModularKernel(torch.nn.Module):
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
         # Construct the entire output that can then be processed in chunks.
-        # Reuse workspace13 for the output in the non-chunked case as long
-        # as it is large enough. This will not always be the case for standard
+        # Reuse workspace13 for the output in the non-chunked case.
+        # This will not always be the case for standard
         # format experts and with experts that have empty workspaces.
-        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
-            workspace13, workspace2 = current_workspace_manager().get_simultaneous(
-                (workspace13_shape, workspace_dtype),
+        if num_chunks == 1:
+            max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
+            common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
+                ((max_shape_size,), workspace_dtype),
                 (workspace2_shape, workspace_dtype),
             )
-            fused_out = _resize_cache(workspace13, fused_out_shape)
+            workspace13 = _resize_cache(common_workspace, workspace13_shape)
+            fused_out = _resize_cache(common_workspace, fused_out_shape)
         else:
             workspace13, workspace2, fused_out = (
                 current_workspace_manager().get_simultaneous(
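
A minimal standalone sketch (not part of the commit) of the aliasing trick this hunk applies, in plain PyTorch: allocate one flat buffer sized for the larger of the two tensors, then carve both workspace13 and fused_out out of it as views. The helper _view_as and the shapes below are hypothetical stand-ins for vLLM's _resize_cache and the real workspace shapes.

from math import prod

import torch


def _view_as(buffer: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Hypothetical stand-in for vLLM's _resize_cache: view the first
    # prod(shape) elements of a flat buffer as a tensor of `shape`.
    return buffer[: prod(shape)].view(shape)


workspace13_shape = (4, 128)  # placeholder sizes, for illustration only
fused_out_shape = (2, 300)

# One allocation, sized for the larger of the two tensors.
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
common_workspace = torch.empty(max_shape_size, dtype=torch.bfloat16)

# Both tensors view the same memory. This is safe here because
# workspace13 is dead by the time fused_out is written (see the
# cache1/cache3 comment in the hunk above).
workspace13 = _view_as(common_workspace, workspace13_shape)
fused_out = _view_as(common_workspace, fused_out_shape)

assert workspace13.data_ptr() == fused_out.data_ptr()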