[TPU] Support collective communications in XLA devices (#6813)

This commit is contained in:
Woosuk Kwon
2024-07-26 18:45:57 -07:00
committed by GitHub
parent bb5494676f
commit d09b94ca58
4 changed files with 70 additions and 2 deletions

View File

@@ -5,10 +5,12 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_gather
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
def forward(
self,
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = tensor_model_parallel_gather(logits)
if self.use_gather:
logits = tensor_model_parallel_gather(logits)
else:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]