[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False


 def get_default_cache_root():
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

+    # If set, disables TPU-specific optimization for top-k & top-p sampling
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }

 # end-env-vars-definition
Reference in New Issue
Block a user