[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Author: Siyuan Liu
Date: 2025-06-02 17:06:20 -07:00
Committer: GitHub
Commit: 9112b443a0 (parent c57d577e8d)

11 changed files with 605 additions and 72 deletions
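
In short, this change lets a single TPU worker drive every local device through PyTorch/XLA's SPMD mode instead of spawning one worker per device: the worker pretends to be a world of size 1, builds a device mesh, and annotates the KV cache with sharding specs so the XLA compiler partitions the computation. A minimal sketch of that pattern, using only the PyTorch/XLA calls that appear in this diff (the tensor shape and names are illustrative, not vLLM's):

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd as xs
    import torch_xla.runtime as xr

    # SPMD must be enabled before any XLA computation is traced.
    xr.use_spmd()

    # One logical mesh over all visible TPU devices: 'x' is the
    # model-parallel axis, 'y' a degenerate axis of size 1.
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))

    # Marking a tensor tells the compiler how to partition it: dim 1 is
    # split across the 'x' axis, all other dims are replicated.
    kv_cache = torch.zeros(64, 16, 8, 128, device=xm.xla_device())
    xs.mark_sharding(kv_cache, mesh, (None, 'x', None, None))

Every computation that touches an annotated tensor is then partitioned automatically, which is how a model written for a single device ends up running tensor-parallel across the whole slice.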

vllm/v1/worker/tpu_model_runner.py

@@ -7,21 +7,22 @@ from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
@@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self,
vllm_config: VllmConfig,
device: torch.device,
+original_parallel_config: Optional[ParallelConfig] = None,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
+self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
@@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
+# SPMD Related
+self.use_spmd = envs.VLLM_XLA_USE_SPMD
+if self.use_spmd:
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices, 1)
+    device_ids = np.array(range(num_devices))
+    self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
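+    # With shape (num_devices, 1), every device lands on the 'x' axis,
+    # i.e. this is effectively a 1-D tensor-parallel mesh.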
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
@@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
max_num_mm_items_decoder_budget)
self.max_num_mm_items_by_modality[modality] = max_num_mm_items
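+# The sampling step is compiled only outside SPMD mode: under SPMD,
+# sample_from_logits currently has a correctness issue with
+# torch.compile (see the TODO on sample_from_logits below).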
+if not self.use_spmd:
+    self.sample_from_logits_func = torch.compile(
+        self.sample_from_logits,
+        backend="openxla",
+        fullgraph=True,
+        dynamic=False)
+else:
+    self.sample_from_logits_func = self.sample_from_logits
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
@@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logits = self.structured_decode(require_struct_decoding,
grammar_bitmask_padded, logits,
arange)
-selected_token_ids = self.sample_from_logits(logits,
-                                             tpu_sampling_metadata)
+selected_token_ids = self.sample_from_logits_func(
+    logits, tpu_sampling_metadata)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
@@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
-# model = get_model(vllm_config=self.vllm_config)
-model_loader = get_model_loader(self.load_config)
-if not hasattr(self, "model"):
-    logger.info("Loading model from scratch...")
-    model = model_loader.load_model(
-        vllm_config=self.vllm_config,
-        model_config=self.model_config)
-else:
-    logger.info("Model was already initialized. \
-        Loading weights inplace...")
-    model_loader.load_weights(self.model,
-                              model_config=self.model_config)
+# model = get_model(vllm_config=self.vllm_config)
+model_loader = get_model_loader(self.load_config)
+if not hasattr(self, "model"):
+    logger.info("Loading model from scratch...")
+    if self.use_spmd:
+        tpu_loader = TPUModelLoader(
+            load_config=self.vllm_config.load_config)
+        model = tpu_loader.load_model(
+            vllm_config=self.vllm_config,
+            model_config=self.vllm_config.model_config,
+            mesh=self.mesh)
+    else:
+        model = model_loader.load_model(vllm_config=self.vllm_config,
+                                        model_config=self.model_config)
+else:
+    logger.info(
+        "Model was already initialized. Loading weights inplace..."
+    )
+    model_loader.load_weights(self.model,
+                              model_config=self.model_config)
if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
@@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
else:
-input_ids = torch.zeros((num_tokens),
-                        dtype=torch.int32,
-                        device=self.device)
+input_ids = torch.zeros((num_tokens),
+                        dtype=torch.int32).to(self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, self.max_num_reqs)
-position_ids = torch.zeros(num_tokens,
-                           dtype=torch.int32,
-                           device=self.device)
+position_ids = torch.zeros(num_tokens,
+                           dtype=torch.int32).to(self.device)
-slot_mapping = torch.zeros(num_tokens,
-                           dtype=torch.int64,
-                           device=self.device)
+slot_mapping = torch.zeros(num_tokens,
+                           dtype=torch.int64).to(self.device)
-block_tables = torch.zeros(
-    (self.max_num_reqs, self.block_table_cpu.shape[1]),
-    dtype=torch.int32,
-    device=self.device)
+block_tables = torch.zeros(
+    (self.max_num_reqs, self.block_table_cpu.shape[1]),
+    dtype=torch.int32).to(self.device)
query_lens = [1] * self.max_num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.int32),
                               dim=0,
                               dtype=torch.int32).to(self.device)
-context_lens = torch.ones((self.max_num_reqs, ),
-                          dtype=torch.int32,
-                          device=self.device)
+context_lens = torch.ones((self.max_num_reqs, ),
+                          dtype=torch.int32).to(self.device)
-num_seqs = torch.tensor([actual_num_reqs],
-                        dtype=torch.int32,
-                        device=self.device)
+num_seqs = torch.tensor([actual_num_reqs],
+                        dtype=torch.int32).to(self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
@@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs],
dtype=np.int32)):
-self.sample_from_logits(dummy_logits, sampling_metadata)
+self.sample_from_logits_func(dummy_logits,
+                             sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
@@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
+if self.use_spmd:
+    num_kv_heads = kv_cache_spec.num_kv_heads
+    assert self.original_parallel_config is not None
+    tp_size = \
+        self.original_parallel_config.tensor_parallel_size
+    # TODO: Handle KV cache duplication under SPMD mode.
+    assert num_kv_heads % tp_size == 0, (
+        f"num_kv_heads {num_kv_heads} must be divisible by "
+        f"tp_size {tp_size} under SPMD mode")
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
-tpu_kv_cache = torch.zeros(kv_cache_shape,
-                           dtype=dtype,
-                           device=self.device)
+tpu_kv_cache = torch.zeros(kv_cache_shape,
+                           dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
@@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
+if self.use_spmd:
+    # Shard the KV cache across devices along the 'x' mesh axis.
+    for cache in self.kv_caches:
+        xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model
@@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
sample_hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.compute_logits(sample_hidden_states, None)
-@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+# TODO: Under SPMD mode, sample_from_logits has a correctness issue.
+# Re-enable torch.compile once the issue is fixed in torch_xla.
+# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:

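The worker-side changes below gate all of this behind an environment variable and stash the user-requested parallel configuration as original_parallel_config. A hypothetical launch (the model name and flag values are illustrative; the environment variable is the one this diff reads via envs.VLLM_XLA_USE_SPMD) might look like:

    VLLM_XLA_USE_SPMD=1 vllm serve meta-llama/Llama-3.1-8B --tensor-parallel-size 8
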
vllm/v1/worker/tpu_worker.py

@@ -45,6 +45,15 @@ class TPUWorker:
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
+self.use_spmd = envs.VLLM_XLA_USE_SPMD
+self.original_parallel_config = None
+if self.use_spmd:
+    # Under SPMD mode, the distributed env is initialized as if there
+    # is only one worker/device.
+    self.original_parallel_config = self.parallel_config
+    self.parallel_config.tensor_parallel_size = 1
+    self.parallel_config.pipeline_parallel_size = 1
+    self.parallel_config.world_size = 1
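+    # The original config is stashed so TPUModelRunner can still size
+    # the KV cache by the user-requested tensor_parallel_size.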
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
@@ -95,10 +104,9 @@ class TPUWorker:
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
-init_tpu_worker_distributed_environment(self.parallel_config,
-                                        self.rank,
-                                        self.distributed_init_method,
-                                        self.local_rank)
+self._init_tpu_worker_distributed_environment(
+    self.parallel_config, self.rank, self.distributed_init_method,
+    self.local_rank)
# Device initialization should happen after initializing
# the distributed runtime.
@@ -132,7 +140,9 @@ class TPUWorker:
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
-self.model_runner = TPUModelRunner(self.vllm_config, self.device)
+self.model_runner = \
+    TPUModelRunner(self.vllm_config, self.device,
+                   self.original_parallel_config)
if rank == 0:
# If usage stat is enabled, collect relevant info.
@@ -147,9 +157,7 @@ class TPUWorker:
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
-tpu_kv_cache = torch.tensor([],
-                            dtype=dtype,
-                            device=self.device)
+tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
@@ -178,9 +186,20 @@ class TPUWorker:
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
-m = xm.get_memory_info(self.device)
-total_memory_size = m["bytes_limit"]
-current_mem = m["bytes_used"]
+if self.use_spmd:
+    # This is a workaround for the TPU SPMD mode. The get_memory_info
+    # API doesn't work with SPMD mode in PyTorch/XLA.
+    # TODO: use xm.get_memory_info for SPMD once it's supported in
+    # PyTorch/XLA.
+    import tpu_info
+    chip_type, _ = tpu_info.device.get_local_chips()
+    device_usage = tpu_info.metrics.get_chip_usage(chip_type)
+    total_memory_size = device_usage[0].total_memory
+    current_mem = device_usage[0].memory_usage
+else:
+    m = xm.get_memory_info(self.device)
+    total_memory_size = m["bytes_limit"]
+    current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
@@ -241,28 +260,30 @@ class TPUWorker:
# worker will always be healthy as long as it's running.
return
-def init_tpu_worker_distributed_environment(
-    parallel_config: ParallelConfig,
-    rank: int,
-    distributed_init_method: Optional[str] = None,
-    local_rank: int = -1,
-) -> None:
-    """Initialize the distributed environment."""
-    # NOTE(woosuk): This is just to initialize the TP group and broadcast
-    # the input objects on CPU. The all-reduce and all-gather ops on TPU
-    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
-    # own context.
-    init_distributed_environment(
-        world_size=parallel_config.world_size,
-        rank=rank,
-        local_rank=local_rank,
-        distributed_init_method=distributed_init_method,
-        backend="gloo",
-    )
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    def _init_tpu_worker_distributed_environment(
+        self,
+        parallel_config: ParallelConfig,
+        rank: int,
+        distributed_init_method: Optional[str] = None,
+        local_rank: int = -1,
+    ) -> None:
+        """Initialize the distributed environment."""
+        if self.use_spmd:
+            xr.use_spmd()
+        # NOTE(woosuk): This is just to initialize the TP group and broadcast
+        # the input objects on CPU. The all-reduce and all-gather ops on TPU
+        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
+        # own context.
+        init_distributed_environment(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            local_rank=local_rank,
+            distributed_init_method=distributed_init_method,
+            backend="gloo",
+        )
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)
try: