[TPU] Support collective communications in XLA devices (#6813)
@@ -5,10 +5,12 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from vllm.distributed import tensor_model_parallel_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 
 
 class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
         self.org_vocab_size = org_vocab_size or vocab_size
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
+        # Whether to use gather or all-gather to gather the logits.
+        self.use_gather = not current_platform.is_tpu()
 
     def forward(
         self,
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        logits = tensor_model_parallel_gather(logits)
+        if self.use_gather:
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]
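
For context, a minimal sketch (not part of the commit) of the behavioral contract the new branch relies on: tensor_model_parallel_gather returns the concatenated logits only on the destination rank and None everywhere else, while tensor_model_parallel_all_gather returns the full logits on every rank. The toy_gather and toy_all_gather helpers below are hypothetical stand-ins for illustration, not the vllm.distributed implementations.

import torch

def toy_gather(shards: list, rank: int, dst: int = 0):
    # Gather along the vocab dim: only the destination rank receives the
    # concatenated logits; every other rank gets None.
    return torch.cat(shards, dim=-1) if rank == dst else None

def toy_all_gather(shards: list, rank: int):
    # All-gather along the vocab dim: every rank receives the full logits,
    # so all devices can run identical post-gather code (vocab-padding
    # removal, sampling), which is what XLA's strict SPMD execution expects.
    return torch.cat(shards, dim=-1)

# Two "ranks", each holding one vocab shard of the logits.
shards = [torch.randn(1, 4), torch.randn(1, 4)]
for rank in range(2):
    g = toy_gather(shards, rank)
    ag = toy_all_gather(shards, rank)
    print(f"rank {rank}: gather -> {None if g is None else tuple(g.shape)}, "
          f"all_gather -> {tuple(ag.shape)}")

On an XLA backend such as TPU, the None result on non-destination ranks would make devices diverge after the collective, so the diff switches those platforms to all-gather and slices off any vocab padding on every rank.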