From f74481018412e4bc63d0fc396ec675ca4bf9bf18 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Thu, 22 Jan 2026 15:35:18 -0500
Subject: [PATCH] [Refactor] Remove unused tpu files (#32610)

Signed-off-by: yewentao256
---
 vllm/v1/sample/tpu/__init__.py |   0
 vllm/v1/sample/tpu/metadata.py | 120 ------------------
 vllm/v1/sample/tpu/sampler.py  | 215 ---------------------------------
 vllm/v1/worker/tpu_worker.py   |  18 ---
 4 files changed, 353 deletions(-)
 delete mode 100644 vllm/v1/sample/tpu/__init__.py
 delete mode 100644 vllm/v1/sample/tpu/metadata.py
 delete mode 100644 vllm/v1/sample/tpu/sampler.py
 delete mode 100644 vllm/v1/worker/tpu_worker.py

diff --git a/vllm/v1/sample/tpu/__init__.py b/vllm/v1/sample/tpu/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
deleted file mode 100644
index 0c1a22e84..000000000
--- a/vllm/v1/sample/tpu/metadata.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass, field
-
-import torch
-
-from vllm.v1.worker.tpu_input_batch import InputBatch
-
-DEFAULT_SAMPLING_PARAMS = dict(
-    temperature=-1.0,
-    min_p=0.0,
-    # strictly disabled for now
-    top_k=0,
-    top_p=1.0,
-    # frequency_penalties=0.0,
-    # presence_penalties=0.0,
-    # repetition_penalties=0.0,
-)
-
-
-@dataclass
-class TPUSupportedSamplingMetadata:
-    # This class exposes a more XLA-friendly interface than SamplingMetadata
-    # on TPU. In particular, all arguments should be traceable and no
-    # Optionals are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor = None
-
-    min_p: torch.Tensor = None
-    top_k: torch.Tensor = None
-    top_p: torch.Tensor = None
-
-    all_greedy: bool = True
-    all_random: bool = False
-
-    # Whether logprobs are to be gathered in this batch of requests. To
-    # balance compile time and runtime, a fixed `max_number_logprobs` value
-    # is used when gathering logprobs, regardless of the values specified in
-    # the batch.
-    logprobs: bool = False
-
-    # TODO No penalties for now
-    no_penalties: bool = True
-    prompt_token_ids = None
-    frequency_penalties = None
-    presence_penalties = None
-    repetition_penalties = None
-    # should use tensor
-    output_token_ids: list[list[int]] = field(default_factory=lambda: list())
-
-    min_tokens = None  # impl is not vectorized
-
-    logit_bias: list[dict[int, float] | None] = field(default_factory=lambda: list())
-
-    allowed_token_ids_mask = None
-    bad_words_token_ids = None
-
-    # Generator not supported by xla
-    _generators: dict[int, torch.Generator] = field(default_factory=lambda: dict())
-
-    @property
-    def generators(self) -> dict[int, torch.Generator]:
-        # Generator not supported by torch/xla. This field must be immutable.
-        return self._generators
-
-    @classmethod
-    def from_input_batch(
-        cls,
-        input_batch: InputBatch,
-        padded_num_reqs: int,
-        xla_device: torch.device,
-        generate_params_if_all_greedy: bool = False,
-    ) -> "TPUSupportedSamplingMetadata":
-        """
-        Copy slices of the sampling tensors from `input_batch` to on-device
-        tensors.
-
-        `InputBatch._make_sampling_metadata` causes recompilation on XLA as
-        it slices dynamic shapes on device tensors. This impl moves the
-        dynamic ops to CPU and produces tensors of fixed `padded_num_reqs`
-        size.
-
-        Args:
-            input_batch: The input batch containing sampling parameters.
-            padded_num_reqs: The padded number of requests.
-            xla_device: The XLA device.
-            generate_params_if_all_greedy: If True, generate sampling
-                parameters even if all requests are greedy. This is useful
-                for cases where we want to pre-compile a graph with sampling
-                parameters, even if they are not strictly needed for greedy
-                decoding.
-        """
-        needs_logprobs = (
-            input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False
-        )
-        # Early return to avoid an unnecessary CPU-to-TPU copy.
-        if input_batch.all_greedy is True and generate_params_if_all_greedy is False:
-            return cls(all_greedy=True, logprobs=needs_logprobs)
-
-        num_reqs = input_batch.num_reqs
-
-        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
-            # Pad value is the default one.
-            cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-
-        fill_slice(
-            input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]
-        )
-        fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"])
-        fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"])
-        fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"])
-
-        # Slice persistent device tensors to a fixed pre-compiled padded shape.
-        return cls(
-            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to(
-                xla_device
-            ),
-            all_greedy=input_batch.all_greedy,
-            all_random=input_batch.all_random,
-            # TODO enable more and avoid returning None values
-            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
-            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
-            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device),
-            logprobs=needs_logprobs,
-        )
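For context on the padding scheme in the deleted `from_input_batch`: all dynamic slicing happens on CPU, and the device only ever sees tensors of the fixed `padded_num_reqs` shape, so XLA reuses one compiled graph across batch sizes. A minimal standalone sketch of that idea, with illustrative names (`pad_to_fixed_shape` and `PAD_TEMPERATURE` are not vLLM APIs):

import torch

PAD_TEMPERATURE = -1.0  # pad with the default value, mirroring DEFAULT_SAMPLING_PARAMS

def pad_to_fixed_shape(
    temperature_cpu: torch.Tensor,
    num_reqs: int,
    padded_num_reqs: int,
    device: torch.device,
) -> torch.Tensor:
    # Dynamic slicing happens on CPU, where it is cheap and never traced.
    temperature_cpu[num_reqs:padded_num_reqs] = PAD_TEMPERATURE
    # The device only ever sees the fixed padded shape, so the compiled
    # graph is reused regardless of the real batch size.
    return temperature_cpu[:padded_num_reqs].to(device)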
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
deleted file mode 100644
index 6d992bb37..000000000
--- a/vllm/v1/sample/tpu/sampler.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Sampler layer implementing TPU supported operations."""
-
-import torch
-import torch.nn as nn
-
-from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
-
-_SAMPLING_EPS = 1e-5
-
-
-class Sampler(nn.Module):
-    def __init__(self):
-        # TODO(houseroad): Add support for logprobs_mode.
-        super().__init__()
-
-    def forward(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata,
-    ) -> SamplerOutput:
-        # Use float32 for the logits.
-        logits = logits.to(torch.float32)
-        # Sample the next token.
-        sampled = self.sample(logits, sampling_metadata)
-
-        # These are TPU tensors.
-        sampler_output = SamplerOutput(
-            # The sampled tokens are expanded to a 2D tensor with shape
-            # [num_requests, 1], where each row represents one generated
-            # token per request.
-            sampled_token_ids=sampled.unsqueeze(-1),
-            logprobs_tensors=None,
-        )
-        return sampler_output
-
-    def apply_temperature(
-        self,
-        logits: torch.Tensor,
-        temp: torch.Tensor,
-        all_random: bool = False,
-    ) -> torch.Tensor:
-        # Avoid division by zero for greedy sampling (temperature ~ 0.0).
-        if not all_random:
-            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
-        return logits.div_(temp.unsqueeze(dim=1))
-
-    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
-        return logits.argmax(dim=-1).view(-1)
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata,
-    ) -> torch.Tensor:
-        greedy_sampled = self.greedy_sample(logits)
-
-        assert sampling_metadata.temperature is not None
-
-        # Apply temperature.
-        logits = self.apply_temperature(
-            logits, sampling_metadata.temperature, sampling_metadata.all_random
-        )
-
-        # Apply min_p.
-        if sampling_metadata.min_p is not None:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
-
-        # Apply top_k and/or top_p.
-        logits = apply_top_k_top_p(
-            logits,
-            sampling_metadata.top_k,
-            sampling_metadata.top_p,
-        )
-
-        # Random sample.
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        random_sampled = self.random_sample(probs, sampling_metadata.generators)
-
-        sampled = torch.where(
-            sampling_metadata.temperature < _SAMPLING_EPS,
-            greedy_sampled,
-            random_sampled,
-        )
-        return sampled
-
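Note that `sample` above computes both the greedy and the random candidate for every row and then selects per row with `torch.where`, which keeps the traced graph free of data-dependent branches. A minimal sketch of that pattern (an assumed helper, not vLLM code; it uses `torch.multinomial` in place of the TPU-friendly exponential trick shown further down):

import torch

_SAMPLING_EPS = 1e-5

def blend_greedy_random(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Greedy candidate for every row.
    greedy = logits.argmax(dim=-1)
    # Guard near-zero temperatures before dividing.
    safe_temp = torch.where(temperature < _SAMPLING_EPS, 1.0, temperature)
    probs = (logits / safe_temp.unsqueeze(1)).softmax(dim=-1)
    # Random candidate for every row.
    random_pick = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Rows with ~0 temperature take the greedy token; the rest keep the sample.
    return torch.where(temperature < _SAMPLING_EPS, greedy, random_pick)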
-    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
-        return logits.log_softmax(dim=-1, dtype=torch.float32)
-
-    def gather_logprobs(
-        self,
-        logprobs: torch.Tensor,
-        num_logprobs: int,
-        token_ids: torch.Tensor,
-    ) -> LogprobsTensors:
-        """
-        Gather logprobs for the top-k and the sampled/prompt tokens.
-
-        Args:
-            logprobs: (num tokens) x (vocab) tensor
-            num_logprobs: minimum number of logprobs to retain per token
-            token_ids: prompt tokens (if prompt logprobs) or sampled tokens
-                (if sampled logprobs); 1D token ID tensor with (num tokens)
-                elements
-
-        Returns:
-            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
-            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
-            Sampled token rank tensor, (num tokens)
-        """
-        # Find the top-k values.
-        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
-
-        # Gather the logprob of the prompt or sampled token.
-        token_ids = token_ids.unsqueeze(-1)
-        token_logprobs = logprobs.gather(-1, token_ids)
-
-        # Compute the rank of the actual token.
-        token_ranks = (logprobs >= token_logprobs).sum(-1)
-
-        # Concatenate together with the top-k.
-        indices = torch.cat((token_ids, topk_indices), dim=1)
-        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
-
-        # Use int32 to reduce the tensor size.
-        indices = indices.to(torch.int32)
-
-        return LogprobsTensors(indices, logprobs, token_ranks)
-
-    def apply_min_p(
-        self,
-        logits: torch.Tensor,
-        min_p: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Filter logits using adaptive probability thresholding: tokens whose
-        probability falls below `min_p` times the per-sequence maximum
-        probability are masked out.
-        """
-        # Convert logits to a probability distribution.
-        probability_values = torch.nn.functional.softmax(logits, dim=-1)
-        # Calculate the maximum probability per sequence.
-        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
-        # Reshape min_p for broadcasting.
-        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
-        # Identify valid tokens using the threshold comparison.
-        valid_token_mask = probability_values >= adjusted_min_p
-        # Apply the mask with masked_fill_ (XLA friendly).
-        logits.masked_fill_(~valid_token_mask, -float("inf"))
-        return logits
-
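The rank computation in `gather_logprobs` counts how many vocabulary entries score at least as high as the chosen token, which yields a 1-based rank with ties counted in. A quick illustrative check, not taken from the patch:

import torch

logprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0, -1.0]]), dim=-1)
token_ids = torch.tensor([[2]])  # pretend token 2 was sampled
token_logprob = logprobs.gather(-1, token_ids)
rank = (logprobs >= token_logprob).sum(-1)
print(rank)  # tensor([2]): only token 0 scores higher, so token 2 is rank 2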
- """ - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) - # Reshape min_p for broadcasting - adjusted_min_p = min_p.unsqueeze(1) * max_probabilities - # Identify valid tokens using threshold comparison - valid_token_mask = probability_values >= adjusted_min_p - # Apply mask using boolean indexing (xla friendly) - logits.masked_fill_(~valid_token_mask, -float("inf")) - return logits - - def random_sample( - self, - probs: torch.Tensor, - generators: dict[int, torch.Generator], - ) -> torch.Tensor: - q = torch.empty_like(probs) - # NOTE(woosuk): To batch-process the requests without their own seeds, - # which is the common case, we first assume that every request does - # not have its own seed. Then, we overwrite the values for the requests - # that have their own seeds. - q.exponential_() - if generators: - for i, generator in generators.items(): - q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1) - - -def apply_top_k_top_p( - logits: torch.Tensor, - k: torch.Tensor | None, - p: torch.Tensor | None, -) -> torch.Tensor: - """ - Apply top-k and top-p optimized for TPU. - - This algorithm avoids using torch.scatter which is extremely slow on TPU. - This is achieved by finding a "cut-off" element in the original logit, and - after thresholding the logit using this cut-off, the remaining elements - shall constitute the top-p set. - - Note: in the case of tie (i.e. multiple cut-off elements present in the - logit), all tie elements are included in the top-p set. In other words, - this function does not break ties. Instead, these tie tokens have equal - chance of being chosen during final sampling, so we can consider the tie - being broken then. - """ - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. 
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # keep at least one token
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
deleted file mode 100644
index 4c73d6c92..000000000
--- a/vllm/v1/worker/tpu_worker.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""A TPU worker class."""
-
-from typing import TypeVar
-
-from vllm.logger import init_logger
-from vllm.platforms.tpu import USE_TPU_INFERENCE
-
-logger = init_logger(__name__)
-
-_R = TypeVar("_R")
-
-# TODO(weiyulin): Remove this file after adding an official way to use
-# hardware plugins.
-if USE_TPU_INFERENCE:
-    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
-
-    TPUWorker = TpuInferenceWorker  # type: ignore
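The deleted worker file was only a shim that re-exports the plugin's worker when the plugin is enabled. A generic sketch of that guarded re-export pattern, with hypothetical names (`my_plugin` and `PluginWorker` are placeholders, not real packages):

# Expose `Worker` from a plugin only when the plugin is importable.
try:
    from my_plugin.worker import Worker as PluginWorker  # hypothetical package
    HAS_PLUGIN = True
except ImportError:
    PluginWorker = None
    HAS_PLUGIN = False

if HAS_PLUGIN:
    Worker = PluginWorker  # callers import `Worker` without knowing its origin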