[TPU] Refactor TPU worker & model runner (#6506)
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,6 @@
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -12,10 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
-                           SamplerOutput, SequenceGroupMetadata,
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           Logprob, SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.utils import make_tensor_with_pad
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase,
+    _add_attn_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict)
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
 logger = init_logger(__name__)
 
@@ -27,7 +34,44 @@ _ENABLE_TOP_P = False
 _MAX_NUM_SAMPLES = 128
 
 
-class TPUModelRunner:
+@dataclass(frozen=True)
+class ModelInputForTPU(ModelRunnerInputBase):
+    token_ids: torch.Tensor
+    position_ids: torch.Tensor
+    attn_metadata: AttentionMetadata
+    input_lens: torch.Tensor
+    t: torch.Tensor
+    p: torch.Tensor
+    num_samples: int
+    best_of: List[int]
+    seq_groups: List[List[int]]
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        tensor_dict = {
+            "token_ids": self.token_ids,
+            "position_ids": self.position_ids,
+            "input_lens": self.input_lens,
+            "t": self.t,
+            "p": self.p,
+            "num_samples": self.num_samples,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls: Type["ModelInputForTPU"],
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForTPU":
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
+
+
+class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
 
     def __init__(
         self,
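A note on the broadcast plumbing above: `as_broadcastable_tensor_dict` and `from_broadcasted_tensor_dict` exist so the driver can ship model inputs to other workers as a flat dict of tensors and ints and rebuild the structured input on the receiving side. A minimal sketch of that round-trip, using a toy dataclass rather than vLLM's real `AttentionMetadata` plumbing:

```python
from dataclasses import dataclass
from typing import Any, Dict

import torch


@dataclass(frozen=True)
class ToyInput:
    token_ids: torch.Tensor
    num_samples: int

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Flatten the input into primitives that can be broadcast.
        return {"token_ids": self.token_ids, "num_samples": self.num_samples}

    @classmethod
    def from_broadcasted_tensor_dict(cls,
                                     tensor_dict: Dict[str, Any]) -> "ToyInput":
        # Rebuild the structured input on the receiving worker.
        return cls(**tensor_dict)


inp = ToyInput(token_ids=torch.tensor([1, 2, 3]), num_samples=1)
rebuilt = ToyInput.from_broadcasted_tensor_dict(
    inp.as_broadcastable_tensor_dict())
assert torch.equal(rebuilt.token_ids, inp.token_ids)
```

Fields that the non-driver workers do not need (here `best_of` and `seq_groups`) are simply left out of the dict, which is why the real `as_broadcastable_tensor_dict` above serializes only a subset of the dataclass.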
@@ -79,6 +123,7 @@ class TPUModelRunner:
                 multimodal_config=self.multimodal_config,
                 lora_config=None,
             )
+        model = model.eval()
         xm.wait_device_ops()
 
         model = ModelWrapper(model)
@@ -147,8 +192,8 @@ class TPUModelRunner:
 
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
-        self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, t, p, num_samples)
+        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
+                   num_samples, kv_caches)
 
     def warmup_model(
         self,
@@ -177,7 +222,7 @@ class TPUModelRunner:
         # Decode
         start = time.time()
         seq_len = 1
-        batch_size = 1
+        batch_size = 8  # Must be in sync with _get_padded_batch_size()
         while True:
             self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
             xm.wait_device_ops()
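The decode warmup now starts at batch size 8 because the runner pads decode batches to a fixed set of bucket sizes, and warming up exactly those buckets is what keeps XLA from recompiling at serving time. `_get_padded_batch_size` itself is not shown in this diff; a plausible sketch of the bucketing it implies, assuming the smallest bucket is 8 and larger batches round up to a multiple of 16 (the real helper may differ):

```python
def padded_batch_size(batch_size: int) -> int:
    # Hypothetical bucketing: the smallest decode bucket is 8, matching the
    # warmup loop's starting batch size above; larger batches round up to a
    # multiple of 16. The real _get_padded_batch_size() may use other buckets.
    if batch_size <= 8:
        return 8
    return -(-batch_size // 16) * 16  # ceil to a multiple of 16


assert padded_batch_size(1) == 8
assert padded_batch_size(8) == 8
assert padded_batch_size(20) == 32
```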
@@ -195,10 +240,10 @@ class TPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
         prompt_lens: List[int] = []
-        slot_mapping: List[List[int]] = []
+        slot_mapping: List[int] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -212,50 +257,46 @@ class TPUModelRunner:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
-            input_tokens.append(prompt_tokens)
-            input_positions.append(list(range(prompt_len)))
+            input_tokens.extend(prompt_tokens)
+            input_positions.extend(list(range(prompt_len)))
 
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
-            slot_mapping.append([])
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
+                slot_mapping.append(slot)
 
+            # Add paddings to EACH prompt to the smallest power of 2 that is
+            # greater than or equal to the prompt length.
+            # We pad the seq_len to reduce the compilation overhead.
+            # We execute each prompt individually (i.e., with batch_size 1)
+            # because the FlashAttention kernel does not support ragged inputs.
+            # TODO(woosuk): Use SplashAttention to support ragged inputs.
+            padded_prompt_len = _get_padded_prefill_len(prompt_len)
+            num_paddings = padded_prompt_len - prompt_len
+            input_tokens += [0] * num_paddings
+            input_positions += [0] * num_paddings
+            slot_mapping += [_PAD_SLOT_ID] * num_paddings
+
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
-        num_prefill_tokens = sum(prompt_lens)
-
-        # Add paddings to make the shape [batch_size, max_prompt_len] where
-        # max_prompt_len is smallest power of 2 that is greater than or equal
-        # to the maximum prompt length.
-        # We need the 2D input shape because the Pallas FlashAttention kernel
-        # does not support packed 1D inputs.
-        # We pad the seq_len to powers of 2 to reduce the compilation overhead.
-        max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
-        input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_prompt_len,
-                                            pad=0,
-                                            dtype=torch.int32,
-                                            device=self.device)
-        input_positions = make_tensor_with_pad(input_positions,
-                                               max_prompt_len,
-                                               pad=0,
-                                               dtype=torch.int32,
-                                               device=self.device)
-        slot_mapping = make_tensor_with_pad(slot_mapping,
-                                            max_prompt_len,
-                                            pad=_PAD_SLOT_ID,
-                                            dtype=torch.int64,
-                                            device=self.device)
+        input_tokens = torch.tensor(input_tokens,
+                                    dtype=torch.int32,
+                                    device="cpu")
+        input_positions = torch.tensor(input_positions,
+                                       dtype=torch.int32,
+                                       device="cpu")
+        slot_mapping = torch.tensor(slot_mapping,
+                                    dtype=torch.int64,
+                                    device="cpu")
         prompt_lens = torch.tensor(prompt_lens,
                                    dtype=torch.int32,
-                                   device=self.device)
+                                   device="cpu")
         attn_metadata = self.attn_backend.make_metadata(
             num_prefills=num_prefills,
-            num_prefill_tokens=num_prefill_tokens,  # NOTE: This is not used.
+            num_prefill_tokens=0,  # NOTE: This is not used.
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             block_tables=None,
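The new `_prepare_prompt` flattens every prompt into one 1D token stream, padding each prompt to the next power of two and mapping each token position to a physical KV-cache slot (`block_number * block_size + block_offset`). A self-contained sketch of both pieces, with the power-of-two helper reconstructed from its usage here (the exact body of `_get_padded_prefill_len` is not part of this diff and may also enforce a minimum length):

```python
from typing import List

_PAD_SLOT_ID = -1  # placeholder; the real constant lives in the runner module


def get_padded_prefill_len(x: int) -> int:
    # Smallest power of 2 >= x, as the comments in the diff describe.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def prompt_slot_mapping(prompt_len: int, block_table: List[int],
                        block_size: int) -> List[int]:
    # Token i lives in block i // block_size at offset i % block_size.
    slots = [
        block_table[i // block_size] * block_size + i % block_size
        for i in range(prompt_len)
    ]
    # Pad each prompt to its power-of-2 length with the sentinel slot id.
    slots += [_PAD_SLOT_ID] * (get_padded_prefill_len(prompt_len) - prompt_len)
    return slots


# A 5-token prompt with block_size=4 spanning physical blocks 7 and 2:
assert prompt_slot_mapping(5, [7, 2], 4) == [28, 29, 30, 31, 8, -1, -1, -1]
```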
@@ -306,22 +347,22 @@ class TPUModelRunner:
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.int32,
-                                       device=self.device)
+                                       device="cpu")
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.int64,
-                                    device=self.device)
+                                    device="cpu")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         block_tables = torch.tensor(self.block_tables[:batch_size],
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         input_lens = torch.tensor([1] * batch_size,
                                   dtype=torch.int32,
-                                  device=self.device)
+                                  device="cpu")
         attn_metadata = self.attn_backend.make_metadata(
             num_prefills=0,
             num_prefill_tokens=0,
@@ -382,16 +423,18 @@ class TPUModelRunner:
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
 
-        t = torch.tensor(t, dtype=torch.float32, device=self.device)
-        p = torch.tensor(p, dtype=torch.float32, device=self.device)
+        t = torch.tensor(t, dtype=torch.float32, device="cpu")
+        p = torch.tensor(p, dtype=torch.float32, device="cpu")
         return t, p, best_of
 
-    def _execute_model(
+    def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> List[CompletionSequenceGroupOutput]:
-        # Prepare inputs.
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None,
+    ) -> ModelInputForTPU:
+        del finished_requests_ids  # Unused.
+        assert virtual_engine == 0
         assert len(seq_group_metadata_list) > 0
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
@@ -400,16 +443,104 @@ class TPUModelRunner:
             inputs = self._prepare_prompt(seq_group_metadata_list)
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
-        padded_batch_size = inputs[0].shape[0]
+        input_tokens, input_positions, attn_metadata, input_lens = inputs
+        padded_batch_size = input_tokens.shape[0]
         t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                              padded_batch_size)
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
 
-        # Execute the model.
-        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:], t, p, num_samples)
-        # Retrieve the outputs to CPU.
-        next_token_ids = next_token_ids.cpu().tolist()
+        seq_groups = [
+            list(metadata.seq_data.keys())
+            for metadata in seq_group_metadata_list
+        ]
+        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
+                                input_lens, t, p, num_samples, best_of,
+                                seq_groups)
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
+        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=self.attn_backend)
+        return model_input
+
+    def execute_model(
+        self,
+        model_input: ModelInputForTPU,
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> List[SamplerOutput]:
+        assert intermediate_tensors is None
+        if num_steps > 1:
+            raise ValueError(
+                "TPUModelRunner does not support multi-step execution.")
+
+        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+            """Move input args from CPU to device and execute the model."""
+
+            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
+                if clone:
+                    # When x is a slice of a CPU tensor, XLA may copy the whole
+                    # original tensor to TPU instead of only copying x.
+                    # To avoid this, we copy x after cloning.
+                    x = x.clone()
+                return x.to(self.device)
+
+            new_args = []
+            for arg in args:
+                if isinstance(arg, torch.Tensor):
+                    arg = _copy_to_device(arg)
+                elif isinstance(arg, AttentionMetadata):
+                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    if getattr(arg, "block_tables", None) is not None:
+                        arg.block_tables = _copy_to_device(arg.block_tables)
+                    if getattr(arg, "context_lens", None) is not None:
+                        arg.context_lens = _copy_to_device(arg.context_lens)
+                new_args.append(arg)
+            return self.model(*new_args)
+
+        num_prefills = model_input.attn_metadata.num_prefills
+        is_prompt = num_prefills > 0
+        if is_prompt:
+            # NOTE(woosuk): Since the FlashAttention kernel does not support
+            # ragged inputs, we split the prompts into different batches and
+            # process them separately. This is a temporary hack that should be
+            # optimized by using SplashAttention.
+            next_token_ids = []
+            orig_slot_mapping = model_input.attn_metadata.slot_mapping
+            batch_size = model_input.input_lens.shape[0]
+            start_idx = 0
+            for i in range(batch_size):
+                # Get the actual prefill_len.
+                prefill_len = model_input.input_lens[i:i + 1].item()
+                prefill_len = _get_padded_prefill_len(prefill_len)
+                end_idx = start_idx + prefill_len
+
+                model_input.attn_metadata.slot_mapping = orig_slot_mapping[
+                    None, start_idx:end_idx]
+                model_input.attn_metadata.num_prefills = 1
+                output_token_ids = _execute_model(
+                    model_input.token_ids[None, start_idx:end_idx],
+                    model_input.position_ids[None, start_idx:end_idx],
+                    model_input.attn_metadata,
+                    model_input.input_lens[i:i + 1],
+                    model_input.t[i:i + 1],
+                    model_input.p[i:i + 1],
+                    model_input.num_samples,
+                    kv_caches,
+                    clone=True)
+                # Retrieve the outputs to CPU.
+                next_token_ids += output_token_ids.cpu().tolist()
+                start_idx = end_idx
+        else:
+            # Execute the model.
+            output_token_ids = _execute_model(
+                model_input.token_ids, model_input.position_ids,
+                model_input.attn_metadata, model_input.input_lens,
+                model_input.t, model_input.p, model_input.num_samples,
+                kv_caches)
+            # Retrieve the outputs to CPU.
+            next_token_ids = output_token_ids.cpu().tolist()
 
         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
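The `_copy_to_device(..., clone=True)` dance above encodes a real XLA pitfall called out in the comment: transferring a view of a large CPU tensor can upload the whole base tensor. A sketch of the pattern in isolation; the device here is plain CPU for runnability, whereas on a TPU host it would be an `xm.xla_device()`:

```python
import torch


def copy_to_device(x: torch.Tensor, device: torch.device,
                   clone: bool = False) -> torch.Tensor:
    if clone:
        # Materialize the slice into its own storage first, so the device
        # transfer moves only these bytes rather than the whole base tensor
        # the view points into.
        x = x.clone()
    return x.to(device)


big = torch.zeros(1_000_000, dtype=torch.int32)
view = big[:128]  # a view sharing big's storage
safe = copy_to_device(view, torch.device("cpu"), clone=True)
assert safe.data_ptr() != big.data_ptr()  # detached from the base storage
```

This is also why the per-prompt prefill loop passes `clone=True`: each call slices `token_ids`, `position_ids`, and the slot mapping out of large batched CPU tensors.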
@@ -417,13 +548,13 @@ class TPUModelRunner:
         zero_logprob = Logprob(0.0)
         batch_idx = 0
         sampler_outputs = []
-        for seq_group_metadata in seq_group_metadata_list:
+        for seq_group in model_input.seq_groups:
+            seq_ids = seq_group
             seq_outputs = []
-            seq_ids = list(seq_group_metadata.seq_data.keys())
             if is_prompt:
                 assert len(seq_ids) == 1
                 seq_id = seq_ids[0]
-                for i in range(best_of[batch_idx]):
+                for i in range(model_input.best_of[batch_idx]):
                     next_token_id = next_token_ids[batch_idx][i]
                     seq_outputs.append(
                         SequenceOutput(seq_id, next_token_id,
@@ -438,35 +569,6 @@ class TPUModelRunner:
             batch_idx += 1
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
-        return sampler_outputs
-
-    def execute_model(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        num_steps: int = 1,
-    ) -> List[SamplerOutput]:
-        if num_steps > 1:
-            raise ValueError(
-                "TPUModelRunner does not support multi-step execution.")
-
-        assert seq_group_metadata_list is not None
-        assert len(seq_group_metadata_list) > 0
-        if seq_group_metadata_list[0].is_prompt:
-            # NOTE(woosuk): To reduce the compilation time, we only compile the
-            # prefill inputs with batch size 1. Because the scheduler is not
-            # aware of this limitation, we need to handle batch size > 1
-            # internally by calling the model multiple times and concatenating
-            # the outputs.
-            # FIXME(woosuk): This is a temporary hack to not change the existing
-            # scheduler. We need to fix this in the future.
-            sampler_outputs = []
-            for seq_group_metadata in seq_group_metadata_list:
-                sampler_outputs += self._execute_model([seq_group_metadata],
-                                                       kv_caches)
-        else:
-            sampler_outputs = self._execute_model(seq_group_metadata_list,
-                                                  kv_caches)
         return [SamplerOutput(sampler_outputs)]
 
 
@@ -474,36 +576,37 @@ class ModelWrapper(nn.Module):
 
     def __init__(self, model: nn.Module):
         super().__init__()
-        self.model = model.eval()
+        self.model = model
 
     def forward(
         self,
         token_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
+        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
         Args:
             token_ids: The input token IDs of shape [batch_size, seq_len].
             position_ids: The input position IDs of shape [batch_size, seq_len].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
+            num_samples: Number of samples to draw from each logits vector.
+            kv_caches: The key and value caches. They can be None during the
+                memory profiling at initialization.
         """
         batch_size, seq_len = token_ids.shape
         # Calculate the positions to sample from.
-        base_indicies = torch.arange(
+        start_indicies = torch.arange(
             batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
-        logits_indices = base_indicies + input_lens - 1
+        logits_indices = start_indicies + input_lens - 1
 
         # FIXME(woosuk): This is a temporary hack to avoid using the existing
         # sampler and sampling metadata.
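For context on the `start_indicies`/`logits_indices` arithmetic at the end of this file: with 2D inputs of shape [batch_size, seq_len], the hidden states are flattened to [batch_size * seq_len, ...], so the last real token of sequence i sits at index i * seq_len + input_lens[i] - 1. A small runnable sketch of that index computation (keeping the code's own spelling of `indicies`):

```python
import torch

batch_size, seq_len = 3, 8
input_lens = torch.tensor([5, 8, 2], dtype=torch.int32)

start_indicies = torch.arange(batch_size, dtype=torch.int32) * seq_len
logits_indices = start_indicies + input_lens - 1

# Sequence 0 ends at 0*8+4, sequence 1 at 1*8+7, sequence 2 at 2*8+1.
assert logits_indices.tolist() == [4, 15, 17]
```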
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 from vllm.worker.tpu_model_runner import TPUModelRunner
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
+                                     LoraNotSupportedWorkerBase, WorkerInput)
 
 logger = init_logger(__name__)
 
 
-class TPUWorker(LoraNotSupportedWorkerBase):
+class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
 
     def __init__(
         self,
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             self.cache_config.cache_dtype]
 
-        self.model_runner = TPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           cache_config,
-                                           load_config,
-                                           multimodal_config,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner: TPUModelRunner = TPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            cache_config,
+            load_config,
+            multimodal_config,
+            is_driver_worker=is_driver_worker)
 
     def init_device(self) -> None:
         os.environ["PJRT_DEVICE"] = "TPU"
@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         dtype_size = get_dtype_size(self.cache_dtype)
         return dtype_size * total
 
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> List[SamplerOutput]:
-        if not self.is_driver_worker:
-            self._execute_model_non_driver()
-            return []
-        assert execute_model_req is not None
-        # Issue cache operations.
-        self.cache_swap(
-            execute_model_req.blocks_to_swap_in,
-            execute_model_req.blocks_to_swap_out,
-            execute_model_req.blocks_to_copy,
-        )
-        # Run the model.
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        assert len(seq_group_metadata_list) > 0
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.tpu_cache)
-        return output
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        # TODO(woosuk): Support TP.
+        return False
 
-    def cache_swap(
+    @property
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
+        # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
+        # parallelism.
+        return [self.tpu_cache]
+
+    def prepare_worker_input(
         self,
-        blocks_to_swap_in: List[Tuple[int, int]],
-        blocks_to_swap_out: List[Tuple[int, int]],
-        blocks_to_copy: List[Tuple[int, int]],
-    ) -> None:
+        execute_model_req: ExecuteModelRequest,
+    ) -> WorkerInput:
+        virtual_engine = execute_model_req.virtual_engine
+        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
+        blocks_to_swap_in = _make_src_to_dst(
+            execute_model_req.blocks_to_swap_in, "cpu", self.device)
+        blocks_to_swap_out = _make_src_to_dst(
+            execute_model_req.blocks_to_swap_out, self.device, "cpu")
+        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
+                                          self.device, self.device)
+        return WorkerInput(
+            num_seq_groups=num_seq_groups,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            virtual_engine=virtual_engine,
+        )
+
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        virtual_engine = worker_input.virtual_engine
+        assert virtual_engine == 0
         attn_backend = self.model_runner.attn_backend
         num_layers = self.model_config.get_num_layers(self.parallel_config)
 
-        if blocks_to_swap_in:
-            # Swap from CPU to TPU.
-            src_indices, dst_indices = _make_src_to_dst(
-                blocks_to_swap_in, "cpu", self.device)
-            for i in range(num_layers):
-                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
-                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
-                k = cpu_k_cache[:, src_indices].to(self.device)
-                v = cpu_v_cache[:, src_indices].to(self.device)
-                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+        # Issue cache operations.
+        if worker_input.blocks_to_swap_in is not None:
+            src_indices, dst_indices = worker_input.blocks_to_swap_in
+            if src_indices.numel() > 0:
+                # Swap from CPU to TPU.
+                for i in range(num_layers):
+                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                    k = cpu_k_cache[:, src_indices].to(self.device)
+                    v = cpu_v_cache[:, src_indices].to(self.device)
+                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
 
-        if blocks_to_swap_out:
-            # Swap from TPU to CPU.
-            src_indices, dst_indices = _make_src_to_dst(
-                blocks_to_swap_out, self.device, "cpu")
-            for i in range(num_layers):
-                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
-                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
-                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
-                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+        if worker_input.blocks_to_swap_out is not None:
+            src_indices, dst_indices = worker_input.blocks_to_swap_out
+            if src_indices.numel() > 0:
+                # Swap from TPU to CPU.
+                for i in range(num_layers):
+                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                    cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                    cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
 
-        if blocks_to_copy:
-            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
-                                          self.device)
-            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
-
-    def start_worker_execution_loop(self) -> None:
-        while self._execute_model_non_driver():
-            pass
-
-    def _execute_model_non_driver(self) -> bool:
-        self.model_runner.execute_model(None, self.tpu_cache)
-        return True
+        if worker_input.blocks_to_copy is not None:
+            src_indices, dst_indices = worker_input.blocks_to_copy
+            if src_indices.numel() > 0:
+                attn_backend.copy_blocks(self.tpu_cache,
+                                         (src_indices, dst_indices))
 
 
 def _make_src_to_dst(
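`prepare_worker_input` now eagerly converts the scheduler's block-swap and block-copy lists into index tensors via `_make_src_to_dst`, whose body is cut off at the end of this diff. Judging only from its call sites here (it takes a list of (src, dst) block-number pairs plus source and destination devices, and its result unpacks into `src_indices, dst_indices` that are checked against `None` and `numel()`), a hypothetical reconstruction:

```python
from typing import List, Optional, Tuple

import torch


def make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: str,
    dst_device: str,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    # Hypothetical sketch: split (src, dst) block-number pairs into two index
    # tensors, one resident on each side of the transfer, so execute_worker
    # can index the CPU cache with src_indices and the TPU cache with
    # dst_indices (or vice versa). The real helper's body is not in this diff.
    if not mapping:
        return None
    src = torch.tensor([s for s, _ in mapping],
                       device=src_device,
                       dtype=torch.int64)
    dst = torch.tensor([d for _, d in mapping],
                       device=dst_device,
                       dtype=torch.int64)
    return src, dst


pair = make_src_to_dst([(0, 3), (1, 7)], "cpu", "cpu")
assert pair is not None
src, dst = pair
assert src.tolist() == [0, 1] and dst.tolist() == [3, 7]
```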