Support torchrun and SPMD-style offline inference (#12071)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-16 19:58:53 +08:00
committed by GitHub
parent dd7c9ad870
commit bf53e0c70b
14 changed files with 248 additions and 30 deletions

View File

@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ class LogitsProcessor(nn.Module):
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu(
) and not envs.VLLM_USE_V1
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
or parallel_config.distributed_executor_backend == "external_launcher" # noqa
def forward(
self,
@@ -88,16 +91,17 @@ class LogitsProcessor(nn.Module):
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
else:
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]