diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a6df2b20a..962d0fe78 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -848,15 +848,17 @@ class FusedMoEModularKernel(torch.nn.Module): # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. # Construct the entire output that can then be processed in chunks. - # Reuse workspace13 for the output in the non-chunked case as long - # as it is large enough. This will not always be the case for standard + # Share one buffer between workspace13 and the output in the non-chunked + # case; workspace13 alone is not always large enough for standard # format experts and with experts that have empty workspaces. - if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): - workspace13, workspace2 = current_workspace_manager().get_simultaneous( - (workspace13_shape, workspace_dtype), + if num_chunks == 1: + max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape)) + common_workspace, workspace2 = current_workspace_manager().get_simultaneous( + ((max_shape_size,), workspace_dtype), (workspace2_shape, workspace_dtype), ) - fused_out = _resize_cache(workspace13, fused_out_shape) + workspace13 = _resize_cache(common_workspace, workspace13_shape) + fused_out = _resize_cache(common_workspace, fused_out_shape) else: workspace13, workspace2, fused_out = ( current_workspace_manager().get_simultaneous(