[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
commit a5354b3ed2 (parent f9df8b4ad7), committed by GitHub
@@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)

logger = init_logger(__name__)
@@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
    return hf_config.hidden_size
    text_config = hf_config.get_text_config()
    return text_config.hidden_size


# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.sequence_parallel_chunk_impl(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        y = nn.functional.pad(x, (0, 0, 0, pad_len))
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    seq_len = cdiv(x.size(0), tp_size)
    shape = list(x.shape)
    shape[0] = seq_len
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
    return out


direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
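
Below is a minimal usage sketch, not part of this commit. It assumes the new helper is exposed from vllm.model_executor.models.utils (the module this hunk appears to modify); the wrapper function and MoE module shown are hypothetical. It only illustrates how a model could shard its tokens across TP ranks with sequence_parallel_chunk before an expert-parallel MoE block; per the comment in sequence_parallel_chunk_impl, the padded, chunked results are expected to be recombined later with an all_gather.

import torch

# Assumed import path: the helper lands in models/utils.py as in this diff.
from vllm.model_executor.models.utils import sequence_parallel_chunk


def moe_forward_with_sp(hidden_states: torch.Tensor,
                        moe: torch.nn.Module) -> torch.Tensor:
    """Hypothetical wrapper: run an EP MoE block on this rank's token slice."""
    # hidden_states: [num_tokens, hidden_size], replicated across TP ranks
    # after the all-reduce that ends the TP attention block.
    #
    # Padding arithmetic inside the op: with num_tokens=5 and tp_size=4 the
    # input is padded to 8 tokens and each rank keeps a chunk of 2.
    chunk = sequence_parallel_chunk(hidden_states)
    return moe(chunk)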