[Minor][Spec Decode] Remove compiled_softmax (#15416)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -9,7 +9,6 @@ import triton.language as tl
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.ops.utils import compiled_softmax
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -275,8 +274,7 @@ def compute_probs(
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
logits = apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
output_prob = compiled_softmax(logits)
|
||||
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return output_prob
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user