[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Tyler Michael Smith
2025-09-27 10:22:28 -04:00
committed by GitHub
parent f9df8b4ad7
commit a5354b3ed2
23 changed files with 540 additions and 375 deletions

@@ -13,11 +13,14 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+from vllm.utils import (cdiv, direct_register_custom_op,
+                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)
 
 logger = init_logger(__name__)
@@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
         return hf_config.hidden_size
     text_config = hf_config.get_text_config()
     return text_config.hidden_size
+
+
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.vllm.sequence_parallel_chunk_impl(x)
+
+
+def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        y = nn.functional.pad(x, (0, 0, 0, pad_len))
+    else:
+        y = x
+
+    chunk = y.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(y, 0, start, chunk)
+
+
+def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk_impl",
+    op_func=sequence_parallel_chunk_impl,
+    fake_impl=sequence_parallel_chunk_impl_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
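
To illustrate what the new helper computes, here is a minimal standalone sketch of the pad-and-narrow logic, with tp_size and tp_rank passed explicitly instead of read from the TP group; the name chunk_for_rank is illustrative and not part of the commit.

import torch
import torch.nn as nn


def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the token axis so it divides evenly by tp_size, then take this
    # rank's contiguous slice (mirrors sequence_parallel_chunk_impl above).
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)


# 6 tokens, hidden size 4, tp_size 4: padded to 8 rows, each rank gets 2 rows.
x = torch.randn(6, 4)
chunks = [chunk_for_rank(x, tp_size=4, tp_rank=r) for r in range(4)]
assert all(c.shape == (2, 4) for c in chunks)
# Concatenating the per-rank chunks (what a later all_gather across TP ranks
# does) and dropping the padding recovers the original tensor.
assert torch.equal(torch.cat(chunks, dim=0)[:6], x)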
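
The registration at the end pairs the op with a "fake" implementation so that torch.compile can infer the output shape, always ceil(seq_len / tp_size) tokens, without running the real kernel. As a rough standalone analogue of what direct_register_custom_op sets up, using the public torch.library.custom_op API (PyTorch 2.4+); the op name demo::sp_chunk and the explicit tp_size/tp_rank arguments are illustrative assumptions, not vLLM's actual registration.

import torch
from torch.library import custom_op


@custom_op("demo::sp_chunk", mutates_args=())
def sp_chunk(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Real implementation: pad the token axis, then take this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = torch.nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    # clone() so the output does not alias the input (custom op outputs
    # may not be views of their inputs).
    return torch.narrow(x, 0, tp_rank * chunk, chunk).clone()


@sp_chunk.register_fake
def _(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Shape-only version used while tracing: always ceil(seq_len / tp_size)
    # tokens, so the traced graph never sees a zero-length sequence.
    out_len = (x.size(0) + tp_size - 1) // tp_size
    return x.new_empty((out_len, *x.shape[1:]))


fn = torch.compile(lambda t: sp_chunk(t, 4, 0), backend="eager")
print(fn(torch.randn(6, 8)).shape)  # torch.Size([2, 8])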