[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
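For orientation, here is a minimal sketch (not code from this commit) of the SPMD setup the diff below wires in, assuming torch_xla is installed and TPU devices are visible: with the VLLM_XLA_USE_SPMD flag on (read via vllm.envs in the diff), the worker enables the PyTorch/XLA SPMD runtime and the model runner builds a one-dimensional logical mesh over all TPU devices; model weights and the KV cache are later sharded along its 'x' axis.

    # Hedged sketch of the mesh the runner builds; names here are illustrative.
    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # must be enabled before tensors are placed on the XLA device

    num_devices = xr.global_runtime_device_count()
    device_ids = np.array(range(num_devices))
    # (num_devices, 1) mesh with axes ('x', 'y'); sharding specs in the diff use 'x'.
    mesh = xs.Mesh(device_ids, (num_devices, 1), ('x', 'y'))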
@@ -7,21 +7,22 @@ from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                    PlaceholderRange)
@@ -98,6 +99,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: Optional[ParallelConfig] = None,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
@@ -105,6 +107,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
@@ -118,6 +121,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD Related
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            mesh_shape = (num_devices, 1)
            device_ids = np.array(range(num_devices))
            self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))

        self.enforce_eager = model_config.enforce_eager

        self.num_xla_graphs = 0
@@ -271,6 +282,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                max_num_mm_items_decoder_budget)
            self.max_num_mm_items_by_modality[modality] = max_num_mm_items

        if not self.use_spmd:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False)
        else:
            self.sample_from_logits_func = self.sample_from_logits

    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
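A small, hedged illustration of the dispatch pattern above: the sampler is wrapped with torch.compile on the openxla backend only when SPMD is off, because the compiled path currently has a correctness issue under SPMD (see the TODO later in this diff). The names below are local to the example, not vLLM APIs.

    import torch
    import torch_xla  # assumption: importing torch_xla registers the "openxla" backend

    def sample_from_logits(logits: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for the runner's sampler: greedy argmax.
        return logits.argmax(dim=-1)

    use_spmd = False  # stands in for envs.VLLM_XLA_USE_SPMD
    if not use_spmd:
        sample_fn = torch.compile(sample_from_logits, backend="openxla",
                                  fullgraph=True, dynamic=False)
    else:
        sample_fn = sample_from_logits  # eager path under SPMD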
@@ -825,9 +845,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            logits = self.structured_decode(require_struct_decoding,
                                            grammar_bitmask_padded, logits,
                                            arange)
        selected_token_ids = self.sample_from_logits(logits,
                                                     tpu_sampling_metadata)
        selected_token_ids = self.sample_from_logits_func(
            logits, tpu_sampling_metadata)
        # NOTE (NickLucche) Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs. We can't enforce it due
        # to recompilations outside torch.compiled code, so just make sure
@@ -935,18 +954,26 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            # model = get_model(vllm_config=self.vllm_config)
            model_loader = get_model_loader(self.load_config)
            if not hasattr(self, "model"):
                logger.info("Loading model from scratch...")
                model = model_loader.load_model(vllm_config=self.vllm_config,
                                                model_config=self.model_config)
                if self.use_spmd:
                    tpu_loader = TPUModelLoader(
                        load_config=self.vllm_config.load_config)
                    model = tpu_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.vllm_config.model_config,
                        mesh=self.mesh)
            else:
                logger.info(
                    "Model was already initialized. Loading weights inplace..."
                )
                model_loader.load_weights(self.model,
                                          model_config=self.model_config)
            # model = get_model(vllm_config=self.vllm_config)
            model_loader = get_model_loader(self.load_config)
            if not hasattr(self, "model"):
                logger.info("Loading model from scratch...")
                model = model_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.model_config)
            else:
                logger.info("Model was already initialized. \
                    Loading weights inplace...")
                model_loader.load_weights(self.model,
                                          model_config=self.model_config)
        if self.lora_config is not None:
            model = self.load_lora_model(model, self.model_config,
                                         self.scheduler_config,
@@ -970,31 +997,25 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                     device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32,
                                    device=self.device)
                                    dtype=torch.int32).to(self.device)
            inputs_embeds = None
        actual_num_reqs = min(num_tokens, self.max_num_reqs)
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=self.device)
                                   dtype=torch.int32).to(self.device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=self.device)
                                   dtype=torch.int64).to(self.device)
        block_tables = torch.zeros(
            (self.max_num_reqs, self.block_table_cpu.shape[1]),
            dtype=torch.int32,
            device=self.device)
            dtype=torch.int32).to(self.device)
        query_lens = [1] * self.max_num_reqs
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
        context_lens = torch.ones((self.max_num_reqs, ),
                                  dtype=torch.int32,
                                  device=self.device)
                                  dtype=torch.int32).to(self.device)
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32,
                                device=self.device)
                                dtype=torch.int32).to(self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
@@ -1198,7 +1219,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                with self.maybe_select_dummy_loras(
                        self.lora_config, np.array([num_reqs],
                                                   dtype=np.int32)):
                    self.sample_from_logits(dummy_logits, sampling_metadata)
                    self.sample_from_logits_func(dummy_logits,
                                                 sampling_metadata)
            logger.info(" -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
@@ -1332,14 +1354,22 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
            num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
            if isinstance(kv_cache_spec, AttentionSpec):
                if self.use_spmd:
                    num_kv_heads = kv_cache_spec.num_kv_heads
                    assert self.original_parallel_config is not None
                    tp_size = \
                        self.original_parallel_config.tensor_parallel_size
                    # TODO: Handle kv cache duplication under SPMD mode.
                    assert num_kv_heads % tp_size == 0, (
                        f"num_kv_heads {num_kv_heads} must be divisible by "
                        f"tp_size {tp_size} under SPMD mode")
                kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                dtype = kv_cache_spec.dtype

                tpu_kv_cache = torch.zeros(kv_cache_shape,
                                           dtype=dtype,
                                           device=self.device)
                                           dtype=dtype).to(self.device)

                kv_caches[layer_name] = tpu_kv_cache
            else:
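To make the divisibility check above concrete, a hedged worked example with made-up numbers; note the check uses the original (pre-reset) tensor-parallel size, since the live parallel config has already been collapsed to a single worker:

    # Illustrative values only; real ones come from kv_cache_spec and the
    # original_parallel_config kept by the worker.
    num_kv_heads = 8
    tp_size = 8
    assert num_kv_heads % tp_size == 0  # otherwise KV-cache duplication is needed (see TODO)
    heads_per_device = num_kv_heads // tp_size  # 1 KV head per chip once sharded on 'x'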
@@ -1350,6 +1380,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

        if self.use_spmd:
            # Shard KV Cache
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))

    def reset_dynamo_cache(self):
        if self.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
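A hedged, self-contained sketch of the KV-cache sharding step above: xs.mark_sharding annotates the cache so the 'x' mesh axis partitions its second dimension (assumed here to be the KV-head dimension of the Pallas layout), while the other dimensions stay replicated. The toy shape is an assumption; the real one comes from PallasAttentionBackend.get_kv_cache_shape.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))

    # Toy cache: (num_blocks, num_kv_heads, block_size, head_size); assumes
    # num_devices divides the KV-head dimension.
    cache = torch.zeros((64, 8, 16, 128), dtype=torch.bfloat16).to(xm.xla_device())
    xs.mark_sharding(cache, mesh, (None, 'x', None, None))  # split dim 1 across 'x'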
@@ -1370,7 +1405,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.model.compute_logits(sample_hidden_states, None)

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    # TODO: Under SPMD mode, sample_from_logits has correctness issue.
    # Re-enable the torch.compile once the issue is fixed in torchxla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
@@ -45,6 +45,15 @@ class TPUWorker:
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, distributed env is initialized as if there is
            # only one worker/device.
            self.original_parallel_config = self.parallel_config
            self.parallel_config.tensor_parallel_size = 1
            self.parallel_config.pipeline_parallel_size = 1
            self.parallel_config.world_size = 1
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
@@ -95,10 +104,9 @@ class TPUWorker:
        torch.set_default_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        init_tpu_worker_distributed_environment(self.parallel_config,
                                                self.rank,
                                                self.distributed_init_method,
                                                self.local_rank)
        self._init_tpu_worker_distributed_environment(
            self.parallel_config, self.rank, self.distributed_init_method,
            self.local_rank)

        # Device initialization should happen after initializing
        # the distributed runtime.
@@ -132,7 +140,9 @@ class TPUWorker:
            xr.initialize_cache(per_rank_path, readonly=False)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = TPUModelRunner(self.vllm_config, self.device)
        self.model_runner = \
            TPUModelRunner(self.vllm_config, self.device,
                           self.original_parallel_config)

        if rank == 0:
            # If usage stat is enabled, collect relevant info.
@@ -147,9 +157,7 @@ class TPUWorker:
            # Use an empty tensor instead of `None`` to force Dynamo to pass
            # it by reference, rather by specializing on the value ``None``.
            tpu_kv_cache = torch.tensor([],
                                        dtype=dtype,
                                        device=self.device)
            tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
            kv_caches[layer_name] = tpu_kv_cache
        else:
            raise NotImplementedError(
@@ -178,9 +186,20 @@ class TPUWorker:
        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        current_mem = m["bytes_used"]
        if self.use_spmd:
            # This is a workaround for the TPU SPMD mode. The get_memory_info
            # API doesn't work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it's supported in
            # PyTorch/XLA.
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            total_memory_size = device_usage[0].total_memory
            current_mem = device_usage[0].memory_usage
        else:
            m = xm.get_memory_info(self.device)
            total_memory_size = m["bytes_limit"]
            current_mem = m["bytes_used"]
        # Ideally we would use profiled = m["peak_bytes_used"] to
        # get weights + activations. But there is memory used during
        # compilation / weight loading that impacts the peak and
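The fallback above can be factored into a small helper; a hedged sketch mirroring the worker's logic (the tpu_info calls are taken from the diff, the xm path is the non-SPMD default):

    import torch_xla.core.xla_model as xm

    def tpu_memory_usage(device, use_spmd: bool):
        """Return (total_bytes, used_bytes) for the local TPU."""
        if use_spmd:
            # xm.get_memory_info does not yet work under SPMD, so read chip
            # metrics from tpu_info instead (first local chip only).
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            usage = tpu_info.metrics.get_chip_usage(chip_type)
            return usage[0].total_memory, usage[0].memory_usage
        m = xm.get_memory_info(device)
        return m["bytes_limit"], m["bytes_used"]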
@@ -241,28 +260,30 @@ class TPUWorker:
        # worker will always be healthy as long as it's running.
        return


def init_tpu_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""

    # NOTE(woosuk): This is just to initialize the TP group and broadcast
    # the input objects on CPU. The all-reduce and all-gather ops on TPU
    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
    # own context.
    init_distributed_environment(
        world_size=parallel_config.world_size,
        rank=rank,
        local_rank=local_rank,
        distributed_init_method=distributed_init_method,
        backend="gloo",
    )
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    def _init_tpu_worker_distributed_environment(
        self,
        parallel_config: ParallelConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)


try: