[V1][TPU] Support V1 Sampler for ragged attention (#14227)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
94
tests/v1/tpu/test_sampler.py
Normal file
94
tests/v1/tpu/test_sampler.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import tempfile
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm import LLM, envs
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
# These tests only exercise the V1 engine; skip the whole module otherwise.
if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time:
    every run after the first one must take less than 10% of the first run.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Point the XLA cache at a fresh directory for the duration of the
        # test.
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time, enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]

        def timed_generate(params: SamplingParams) -> float:
            # Wall-clock duration (seconds) of a single generate() call.
            start = time()
            _ = llm.generate(prompts, params)
            return time() - start

        # First inference should be slow
        run1 = timed_generate(
            SamplingParams(
                temperature=0.7,
                # top_p=0.6, # TODO too slow!
                # top_k=10,
                min_p=0.2,
                max_tokens=16))

        # Second request with different params, but for which we
        # compiled for in previous eager iteration.
        run2 = timed_generate(
            SamplingParams(temperature=0.1, min_p=0.8, max_tokens=24))
        # Print before asserting so the timings are visible on failure too.
        print("TIMES", run1, run2)
        # Much faster after compiling
        assert run1 * 0.1 > run2, (
            f"second run not >10x faster: run1={run1:.3f}s run2={run2:.3f}s")

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        run3 = timed_generate(SamplingParams(max_tokens=24, temperature=0.0))
        print("TIMES", run1, run3)
        assert run1 * 0.1 > run3, (
            f"third run not >10x faster: run1={run1:.3f}s run3={run3:.3f}s")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_different(model_name: str):
    """
    Test significantly different sampling params to assert the model produces
    different results.
    """
    llm = LLM(
        model_name,
        enforce_eager=True,
        max_num_seqs=1,
        max_model_len=64,
        # TODO: setting to 0.5 or it will go OOM
        gpu_memory_utilization=0.5)
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ]

    # High temperature with a permissive min_p ...
    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
    output = llm.generate(prompts, sampling_params)

    # ... versus low temperature with a strict min_p.
    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
    output2 = llm.generate(prompts, sampling_params)

    first_text = output[0].outputs[0].text
    second_text = output2[0].outputs[0].text
    assert first_text != second_text, (
        f"both sampling configurations produced the same text: {first_text!r}")
|
||||||
@@ -65,6 +65,8 @@ class TopKTopPSampler(nn.Module):
|
|||||||
"native implementation of top-p & top-k sampling. For the "
|
"native implementation of top-p & top-k sampling. For the "
|
||||||
"best performance, please install FlashInfer.")
|
"best performance, please install FlashInfer.")
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
|
elif current_platform.is_tpu():
|
||||||
|
self.forward = self.forward_tpu
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
|
|
||||||
@@ -96,6 +98,18 @@ class TopKTopPSampler(nn.Module):
|
|||||||
return random_sample(probs, generators)
|
return random_sample(probs, generators)
|
||||||
return flashinfer_sample(probs, k, p, generators)
|
return flashinfer_sample(probs, k, p, generators)
|
||||||
|
|
||||||
|
def forward_tpu(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample from `logits` without applying top-k/top-p filtering.

    `k` and `p` are accepted for signature parity with the other forward
    implementations but are currently NOT applied: the filtering call is
    left commented out below as a placeholder.

    Args:
        logits: Logits to sample from; softmax is taken over the last dim.
        generators: Passed through to `random_sample`.
        k: Unused for now (see above).
        p: Unused for now (see above).

    Returns:
        Whatever `random_sample` returns for the float32 probabilities.
    """
    # TODO Placeholder for TPU optimized topk/p kernel
    # logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)
|
||||||
|
|
||||||
|
|
||||||
def apply_top_k_top_p(
|
def apply_top_k_top_p(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
@@ -112,7 +126,7 @@ def apply_top_k_top_p(
|
|||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
# Apply top-k.
|
# Apply top-k.
|
||||||
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||||
# Get all the top_k values.
|
# Get all the top_k values.
|
||||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||||
top_k_mask = logits_sort < top_k_mask
|
top_k_mask = logits_sort < top_k_mask
|
||||||
|
|||||||
0
vllm/v1/sample/tpu/__init__.py
Normal file
0
vllm/v1/sample/tpu/__init__.py
Normal file
159
vllm/v1/sample/tpu/metadata.py
Normal file
159
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
    # A `None` default is replaced by a scalar bool tensor in __post_init__.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    # A `None` default is replaced by an int32 zeros tensor in __post_init__.
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        """Fill tensor fields left as None with defaults on temperature's
        device, so the instance never carries None for them."""
        temp = self.temperature
        if self.indices_do_sample is None:
            # One index (0) per entry of `temperature`.
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            # Scalar False flag.
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Padding entries keep the per-parameter default (no-op) value, e.g.
        1.0 for temperature.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 1.0, 1.0]
            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]

        Args:
            metadata: Source sampling metadata; only tensor-valued supported
                params are copied from it.
            padded_do_sample_indices: Already padded indices; its length
                fixes the size of every per-request tensor.
            num_do_sample: Number of leading entries overwritten with the
                values from `metadata`.
            device: Device on which the new tensors are created.
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax)
        supported_params = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                # In-place slice write: the tensor held by `new_metadata` is
                # updated directly, no re-assignment needed.
                getattr(new_metadata, p_name)[:num_do_sample] = old_val

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        """Build an instance whose supported params are all set to their
        default value, each as a tensor of shape `(num_samples, )`.

        `indices_do_sample` and `do_argmax` are forwarded to the constructor;
        when None they are filled in by `__post_init__`.
        """
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        init_kwargs = {
            p_name: torch.full((num_samples, ),
                               default_val,
                               dtype=dtype,
                               device=device)
            for p_name, (default_val,
                         dtype) in sampling_metadata_disable_value.items()
        }

        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        """Check the invariants this class relies on and return the input
        unchanged."""
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        """Return `{param_name: (default_value, dtype)}` for every param
        that is materialized as a per-request tensor."""
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
|
||||||
154
vllm/v1/sample/tpu/sampler.py
Normal file
154
vllm/v1/sample/tpu/sampler.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Sampler layer implementing TPU supported operations."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||||
|
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||||
|
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||||
|
|
||||||
|
# Temperatures below this threshold select the greedy (argmax) token instead
# of the randomly sampled one in Sampler.sample.
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(nn.Module):
    """Sampler that picks the next token from logits using temperature,
    min_p and the top-k/top-p sampler, selecting greedy vs. random results
    with `torch.where` rather than Python branching."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row of `logits`.

        Returns:
            A SamplerOutput whose `sampled_token_ids` is an int32 tensor with
            a trailing dimension of size 1 and whose `logprobs_tensors` is
            None (logprobs are not computed here).
        """
        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These tensors stay on the device `logits` lives on.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of `logits` by its temperature, in place."""
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax over the last dim, flattened to 1D."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Compute both the greedy and the random sample for every row and
        pick per row based on temperature.

        NOTE: `logits` is modified in place by the temperature and min_p
        steps.
        """
        # Computed before the in-place temperature scaling below mutates
        # `logits`.
        greedy_sampled = self.greedy_sample(logits)

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        # Rows with a (near-)zero temperature take the greedy token.
        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                              greedy_sampled, random_sampled)
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token: the number of vocabulary
        # entries whose logprob is >= the token's own logprob.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk (the token itself comes first).
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        Tokens whose probability is below `min_p * max_probability` of their
        row are set to -inf. `logits` is modified in place and returned.
        """
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
        # Apply mask using boolean indexing (xla friendly)
        logits.masked_fill_(~valid_token_mask, -float("inf"))
        return logits
|
||||||
@@ -23,13 +23,16 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
|||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||||
|
PallasAttentionBackend,
|
||||||
PallasMetadata)
|
PallasMetadata)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||||
ModelRunnerOutput)
|
ModelRunnerOutput, SamplerOutput)
|
||||||
|
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||||
|
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
@@ -42,6 +45,8 @@ logger = init_logger(__name__)
|
|||||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||||
_PAD_SLOT_ID = 1_000_000_000
|
_PAD_SLOT_ID = 1_000_000_000
|
||||||
INVALID_TOKEN_ID = -1
|
INVALID_TOKEN_ID = -1
|
||||||
|
# Smallest output size
|
||||||
|
MIN_NUM_SEQS = 8
|
||||||
|
|
||||||
|
|
||||||
class TPUModelRunner:
|
class TPUModelRunner:
|
||||||
@@ -138,8 +143,10 @@ class TPUModelRunner:
|
|||||||
device="cpu")
|
device="cpu")
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
|
||||||
|
padded_max_num_blocks_per_req = _get_padded_number(
|
||||||
|
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
|
||||||
self.block_table_cpu = torch.zeros(
|
self.block_table_cpu = torch.zeros(
|
||||||
(self.max_num_tokens, self.max_num_blocks_per_req),
|
(self.max_num_tokens, padded_max_num_blocks_per_req),
|
||||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
@@ -267,6 +274,9 @@ class TPUModelRunner:
|
|||||||
req_data.num_computed_tokens)
|
req_data.num_computed_tokens)
|
||||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||||
req_index)
|
req_index)
|
||||||
|
# Check if the batch has changed. If not, we can skip copying the
|
||||||
|
# sampling metadata from CPU to GPU.
|
||||||
|
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||||
|
|
||||||
# Add the new or resumed requests to the persistent batch.
|
# Add the new or resumed requests to the persistent batch.
|
||||||
# The smaller empty indices are filled first.
|
# The smaller empty indices are filled first.
|
||||||
@@ -284,6 +294,10 @@ class TPUModelRunner:
|
|||||||
# Condense the batched states if there are empty indices.
|
# Condense the batched states if there are empty indices.
|
||||||
if removed_req_indices:
|
if removed_req_indices:
|
||||||
self.input_batch.condense(removed_req_indices)
|
self.input_batch.condense(removed_req_indices)
|
||||||
|
|
||||||
|
# TODO This slices tensors to copy to device, triggering recompilation.
|
||||||
|
if batch_changed:
|
||||||
|
self.input_batch.refresh_sampling_metadata()
|
||||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
@@ -447,6 +461,8 @@ class TPUModelRunner:
|
|||||||
# TODO: Support prompt logprobs.
|
# TODO: Support prompt logprobs.
|
||||||
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
|
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||||
num_reqs, self.max_num_reqs)
|
num_reqs, self.max_num_reqs)
|
||||||
|
# Indices at which we sample (positions of last token in the sequence).
|
||||||
|
# Padded to avoid recompiling when `num_reqs` varies.
|
||||||
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
|
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
|
||||||
logits_indices = logits_indices.to(self.device)
|
logits_indices = logits_indices.to(self.device)
|
||||||
return attn_metadata, logits_indices
|
return attn_metadata, logits_indices
|
||||||
@@ -576,7 +592,14 @@ class TPUModelRunner:
|
|||||||
# then the embedding layer is not included in the CUDA graph.
|
# then the embedding layer is not included in the CUDA graph.
|
||||||
input_ids = self.input_ids
|
input_ids = self.input_ids
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
||||||
|
# mismatch in pre-processing, it will trigger a small recompilation
|
||||||
|
# of the code thus far. Forward graph remains untouched.
|
||||||
|
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||||
|
from_sampling_metadata(sampling_metadata, logits_indices,
|
||||||
|
num_reqs, self.device)
|
||||||
# Run the decoder
|
# Run the decoder
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
@@ -585,12 +608,13 @@ class TPUModelRunner:
|
|||||||
kv_caches=self.kv_caches,
|
kv_caches=self.kv_caches,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
num_reqs = self.input_batch.num_reqs
|
selected_token_ids = self.model.sample_from_hidden(
|
||||||
selected_token_ids = self.model.compute_logits(hidden_states,
|
hidden_states, tpu_sampling_metadata)
|
||||||
logits_indices, None)
|
# Remove padding on cpu and keep dynamic op outside of xla graph.
|
||||||
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
||||||
|
|
||||||
# Then, let's update the cache state.
|
# Update the cache state concurrently. Code above will not block until
|
||||||
|
# we use `selected_token_ids`. Add mark_step if post-processing changes
|
||||||
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
|
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
|
||||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||||
assert req_id is not None
|
assert req_id is not None
|
||||||
@@ -607,7 +631,6 @@ class TPUModelRunner:
|
|||||||
# This relies on cuda-specific torch-internal impl details
|
# This relies on cuda-specific torch-internal impl details
|
||||||
generator.set_offset(generator.get_offset() - 4)
|
generator.set_offset(generator.get_offset() - 4)
|
||||||
|
|
||||||
# num_reqs entries should be non-None
|
|
||||||
assert all(
|
assert all(
|
||||||
req_id is not None for req_id in
|
req_id is not None for req_id in
|
||||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||||
@@ -620,6 +643,7 @@ class TPUModelRunner:
|
|||||||
max_gen_len = selected_token_ids.shape[-1]
|
max_gen_len = selected_token_ids.shape[-1]
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
valid_sampled_token_ids = selected_token_ids.tolist()
|
valid_sampled_token_ids = selected_token_ids.tolist()
|
||||||
|
|
||||||
for i, req_state, seq_len in request_seq_lens:
|
for i, req_state, seq_len in request_seq_lens:
|
||||||
token_id = valid_sampled_token_ids[i][0]
|
token_id = valid_sampled_token_ids[i][0]
|
||||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||||
@@ -676,11 +700,8 @@ class TPUModelRunner:
|
|||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
dynamic=False)
|
dynamic=False)
|
||||||
|
|
||||||
def _dummy_run(
|
@torch.no_grad()
|
||||||
self,
|
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
|
||||||
kv_caches,
|
|
||||||
num_tokens: int,
|
|
||||||
) -> None:
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
|
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
|
||||||
@@ -729,32 +750,10 @@ class TPUModelRunner:
|
|||||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||||
assert self.model is not None
|
self.model(input_ids=input_ids,
|
||||||
hidden_states = self.model(
|
positions=position_ids,
|
||||||
input_ids=input_ids,
|
kv_caches=kv_caches,
|
||||||
positions=position_ids,
|
inputs_embeds=inputs_embeds)
|
||||||
kv_caches=kv_caches,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
)
|
|
||||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
|
||||||
64, self.max_num_reqs)
|
|
||||||
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
|
|
||||||
# compilation cache size of token_bucket_num multiplied by
|
|
||||||
# req_bucket_num. This is acceptable, given the graph's relatively
|
|
||||||
# small size.
|
|
||||||
while True:
|
|
||||||
logits_indices = torch.zeros(
|
|
||||||
num_reqs,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
|
||||||
torch._dynamo.mark_dynamic(logits_indices, 0)
|
|
||||||
self.model.compute_logits(hidden_states, logits_indices, None)
|
|
||||||
if num_reqs >= self.max_num_reqs:
|
|
||||||
break
|
|
||||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
|
||||||
num_reqs + 1, self.max_num_reqs)
|
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
"""Compile the model."""
|
"""Compile the model."""
|
||||||
@@ -764,13 +763,51 @@ class TPUModelRunner:
|
|||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
num_tokens = 16
|
num_tokens = 16
|
||||||
while True:
|
while True:
|
||||||
self._dummy_run(self.kv_caches, num_tokens)
|
|
||||||
logger.info(" -- num_tokens: %d", num_tokens)
|
logger.info(" -- num_tokens: %d", num_tokens)
|
||||||
|
self._dummy_run(self.kv_caches, num_tokens)
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
xm.wait_device_ops()
|
|
||||||
if num_tokens >= self.max_num_tokens:
|
if num_tokens >= self.max_num_tokens:
|
||||||
break
|
break
|
||||||
num_tokens *= 2
|
num_tokens *= 2
|
||||||
|
xm.wait_device_ops()
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||||
|
|
||||||
|
logger.info("Compiling sampling with different input shapes.")
|
||||||
|
start = time.perf_counter()
|
||||||
|
num_tokens = 16
|
||||||
|
hsize = self.model_config.get_hidden_size()
|
||||||
|
device = self.device
|
||||||
|
# Compile sampling step for different model+sampler outputs in bucketed
|
||||||
|
# n_tokens x max_num_reqs. Graph is really small so this is fine.
|
||||||
|
while True:
|
||||||
|
num_reqs_to_sample = MIN_NUM_SEQS
|
||||||
|
dummy_hidden = torch.randn((num_tokens, hsize),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.bfloat16)
|
||||||
|
while True:
|
||||||
|
# Default metadata is an all_greedy setup. But since the
|
||||||
|
# `do_argmax` flag is a tensor, we still compile the full graph
|
||||||
|
meta = self.input_batch.sampling_metadata
|
||||||
|
indices = torch.zeros(
|
||||||
|
num_reqs_to_sample,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||||
|
from_sampling_metadata(meta, indices,
|
||||||
|
num_reqs_to_sample, device)
|
||||||
|
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||||
|
num_reqs_to_sample)
|
||||||
|
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
||||||
|
xm.mark_step()
|
||||||
|
if num_reqs_to_sample >= self.max_num_reqs:
|
||||||
|
break
|
||||||
|
num_reqs_to_sample *= 2
|
||||||
|
if num_tokens >= self.max_num_tokens:
|
||||||
|
break
|
||||||
|
num_tokens *= 2
|
||||||
|
xm.wait_device_ops()
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||||
|
|
||||||
@@ -818,6 +855,13 @@ class ModelWrapperV1(nn.Module):
|
|||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model: nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.sampler = TPUSampler()
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self, logits: torch.Tensor,
|
||||||
|
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
|
||||||
|
sampler_out = self.sampler(logits, sampling_metadata)
|
||||||
|
return sampler_out
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -826,7 +870,7 @@ class ModelWrapperV1(nn.Module):
|
|||||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Executes the forward pass of the model and samples the next token.
|
"""Executes the forward pass of the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids: The input token IDs of shape [num_tokens].
|
input_ids: The input token IDs of shape [num_tokens].
|
||||||
@@ -837,7 +881,6 @@ class ModelWrapperV1(nn.Module):
|
|||||||
hidden_size]. It is used for multimodal models.
|
hidden_size]. It is used for multimodal models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.model is not None
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -846,17 +889,33 @@ class ModelWrapperV1(nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
def sample_from_hidden(
|
||||||
def compute_logits(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
logits_indices: torch.Tensor,
|
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||||
sampling_metadata,
|
) -> torch.Tensor:
|
||||||
) -> Optional[torch.Tensor]:
|
"""
|
||||||
hidden_states = hidden_states[logits_indices]
|
Sample with xla-friendly function. This function is to be traced
|
||||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
separately from `forward` for lighter compilation overhead.
|
||||||
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
"""
|
||||||
return selected_token_ids
|
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
|
||||||
|
sample_hidden_states = \
|
||||||
|
hidden_states[sampling_metadata.indices_do_sample]
|
||||||
|
logits = self.compute_logits(sample_hidden_states)
|
||||||
|
# Greedy sampling can't be run without branching the graph on Sampler.
|
||||||
|
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
||||||
|
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
||||||
|
out_tokens = torch.where(sampling_metadata.do_argmax,
|
||||||
|
torch.argmax(logits, dim=-1, keepdim=True),
|
||||||
|
self.sample(logits, sampling_metadata)\
|
||||||
|
.sampled_token_ids)
|
||||||
|
return out_tokens
|
||||||
|
|
||||||
|
def compute_logits(self,
|
||||||
|
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||||
|
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
|
||||||
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
|
return logits
|
||||||
|
|
||||||
def get_multimodal_embeddings(self, *args, **kwargs):
|
def get_multimodal_embeddings(self, *args, **kwargs):
|
||||||
return self.model.get_multimodal_embeddings(*args, **kwargs)
|
return self.model.get_multimodal_embeddings(*args, **kwargs)
|
||||||
@@ -876,5 +935,5 @@ def _get_padded_token_len(x: int) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
|
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
|
||||||
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
|
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
|
||||||
return min(res, upper_limit)
|
return min(res, upper_limit)
|
||||||
|
|||||||
Reference in New Issue
Block a user