Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@@ -19,6 +19,9 @@ import torch.multiprocessing as mp
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.utils import cuda_device_count_stateless, update_environment_variables

 logger = init_logger(__name__)
@@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )

+    if vllm_kernel_override_batch_invariant():
+        return False
+
     if not is_symmetric_memory_enabled():
         return False
     if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
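For context, the hunks above make batch-invariant mode opt out of NCCL symmetric-memory all-reduce before any other eligibility check runs. Below is a minimal, self-contained sketch of that opt-out pattern; batch_invariant_enabled, SYMM_MEM_CONFIG, and should_use_symm_mem_allreduce are hypothetical stand-ins for vLLM's vllm_kernel_override_batch_invariant(), NCCL_SYMM_MEM_ALL_REDUCE_CONFIG, and should_nccl_symm_mem_allreduce, not the real implementations.

import os

def batch_invariant_enabled() -> bool:
    # Hypothetical stand-in: vLLM reads this flag through vllm.envs; the
    # exact variable name here is an assumption for illustration.
    return os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0") == "1"

# Illustrative values only; the real config lives in vLLM.
SYMM_MEM_CONFIG = {"min_world_size": 4, "max_size_bytes": 64 * 1024 * 1024}

def should_use_symm_mem_allreduce(world_size: int, nbytes: int) -> bool:
    # The determinism check comes first: symmetric-memory all-reduce can
    # choose different reduction strategies per call, which changes the
    # floating-point summation order and breaks batch invariance.
    if batch_invariant_enabled():
        return False
    if world_size < SYMM_MEM_CONFIG["min_world_size"]:
        return False
    return nbytes <= SYMM_MEM_CONFIG["max_size_bytes"]

if __name__ == "__main__":
    print(should_use_symm_mem_allreduce(world_size=8, nbytes=1 << 20))

Placing the batch-invariant guard ahead of every performance heuristic means determinism always wins the decision, regardless of world size or message size.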
@@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
     SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform

 try:
@@ -100,6 +103,8 @@ class SymmMemCommunicator:
             return
         self.force_multimem = force_multimem
         self.disabled = False
+        if vllm_kernel_override_batch_invariant():
+            self.disabled = True

     def should_use_symm_mem(self, inp: torch.Tensor):
         if self.disabled:
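The hunks above apply the same override inside the communicator itself: the constructor latches a disabled flag, and every later should_use_symm_mem() call short-circuits on it. A hedged sketch of that gating pattern follows; the simplified constructor and return values stand in for SymmMemCommunicator, whose real __init__ also checks platform support, world size, and per-dtype buffer limits.

import torch

class SymmMemCommunicatorSketch:
    # Sketch of the gate added in the diff, not the real vLLM class.

    def __init__(self, batch_invariant: bool) -> None:
        self.disabled = False
        # Mirrors the added lines: batch-invariant mode force-disables the
        # symmetric-memory path even when the hardware would support it.
        if batch_invariant:
            self.disabled = True

    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
        # Latched flag short-circuits all later eligibility checks.
        if self.disabled:
            return False
        # The real method also compares inp's size and dtype against
        # SYMM_MEM_ALL_REDUCE_MAX_SIZES; elided in this sketch.
        return inp.is_cuda

Disabling at construction time, rather than re-checking the environment on every call, keeps the per-call fast path to a single attribute test.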