[Minor][Spec Decode] Remove compiled_softmax (#15416)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
--- a/vllm/v1/sample/ops/utils.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from typing import Union
-
-import torch
-
-
-def compiled_softmax(
-    logits: torch.Tensor,
-    temperature: Union[float, torch.Tensor] = 1.0,
-) -> torch.Tensor:
-    """Faster softmax kernel generated by torch.compile.
-
-    Args:
-        logits: [n, vocab_size]
-        temperature: [n] or float
-    """
-    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
-    torch._dynamo.mark_dynamic(logits, index=0)
-    if isinstance(temperature, torch.Tensor):
-        torch._dynamo.mark_dynamic(temperature, index=0)
-    return _softmax(logits, temperature)
-
-
-@torch.compile
-def _softmax(
-    logits: torch.Tensor,
-    temperature: Union[float, torch.Tensor],
-) -> torch.Tensor:
-    logits = logits / temperature
-    return torch.softmax(logits, dim=-1, dtype=torch.float32)
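For context on what was removed: `compiled_softmax` wrapped a temperature-scaled softmax in `torch.compile`, and its NOTE relies on `torch._dynamo.mark_dynamic` so that varying batch sizes reuse one compiled graph instead of retracing. A minimal sketch of the same math in eager mode (the name `eager_softmax` is illustrative, not from the patch):

```python
import torch

def eager_softmax(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Same math as the removed helper: scale by temperature, then
    # softmax in float32 for numerical stability.
    return torch.softmax(logits / temperature, dim=-1, dtype=torch.float32)

logits = torch.randn(4, 128)  # [n, vocab_size]
probs = eager_softmax(logits, temperature=0.7)
assert torch.allclose(probs.sum(dim=-1), torch.ones(4))  # rows sum to 1

# The deleted code additionally marked dim 0 dynamic before calling the
# @torch.compile-d function, so a change in n does not trigger recompilation:
torch._dynamo.mark_dynamic(logits, index=0)
```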
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -9,7 +9,6 @@ import triton.language as tl
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 logger = init_logger(__name__)
@@ -275,8 +274,7 @@ def compute_probs(
     # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
     # which is slow for large vocab sizes. This may cause performance issues.
     logits = apply_top_k_top_p(logits, top_k, top_p)
-    output_prob = compiled_softmax(logits)
+    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
     return output_prob
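Why a plain eager softmax suffices at this call site: `apply_top_k_top_p` masks filtered-out logits to `-inf`, and softmax maps those entries to exactly zero probability, so the result is unchanged without the compiled kernel. A small sketch, assuming that `-inf` masking convention:

```python
import torch

# Logits after top-k/top-p filtering: rejected entries are -inf.
logits = torch.tensor([[2.0, 1.0, float("-inf"), float("-inf")]])
probs = logits.softmax(dim=-1, dtype=torch.float32)
print(probs)  # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])
```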