diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c69df6e61..6c27fedc6 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -736,7 +736,23 @@ def cast_overflow_tensors( return tensors -def fast_topk(values, topk, dim): +def fast_topk(values: torch.Tensor, topk: int, + dim: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Optimized topk implementation that uses torch.max for k=1 case. + + This function provides better performance for the common case of k=1 + by using torch.max instead of the more general torch.topk. + + Args: + values: Input tensor to find top-k values from + topk: Number of top values to return (k). Must be > 0. + dim: Dimension along which to compute topk + + Returns: + Tuple of (values, indices) where values are the top-k values + and indices are their corresponding indices in the input tensor + """ if topk == 1: # Use max along the specified dimension to get both value and index return torch.max(values, dim=dim, keepdim=True)