[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
Hyesoo Yang
2025-03-20 19:19:40 -07:00
committed by GitHub
parent 6edbfa924d
commit 47195057e9
3 changed files with 29 additions and 4 deletions

View File

@@ -95,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
def get_default_cache_root():
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
}
# end-env-vars-definition