[V1][Spec Decode] Fix greedy temperature detection after sampler refactor (#27077)

Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
This commit is contained in:
Pradyun92
2025-10-17 16:27:47 -04:00
committed by GitHub
parent d29483b58a
commit acedc74b1a
5 changed files with 22 additions and 6 deletions

View File

@@ -15,7 +15,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single # Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases. # step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128 MAX_SPEC_LEN = 128

View File

@@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
top_p: torch.Tensor = None top_p: torch.Tensor = None
all_greedy: bool = True all_greedy: bool = True
all_random: bool = False
# Whether logprobs are to be gathered in this batch of request. To balance # Whether logprobs are to be gathered in this batch of request. To balance
# out compile time and runtime, a fixed `max_number_logprobs` value is used # out compile time and runtime, a fixed `max_number_logprobs` value is used
@@ -110,6 +111,7 @@ class TPUSupportedSamplingMetadata:
xla_device xla_device
), ),
all_greedy=input_batch.all_greedy, all_greedy=input_batch.all_greedy,
all_random=input_batch.all_random,
# TODO enable more and avoid returning None values # TODO enable more and avoid returning None values
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device), top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device), top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),

View File

@@ -40,7 +40,11 @@ class Sampler(nn.Module):
self, self,
logits: torch.Tensor, logits: torch.Tensor,
temp: torch.Tensor, temp: torch.Tensor,
all_random: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Avoid division by zero for greedy sampling (temperature ~ 0.0).
if not all_random:
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
return logits.div_(temp.unsqueeze(dim=1)) return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -56,7 +60,9 @@ class Sampler(nn.Module):
assert sampling_metadata.temperature is not None assert sampling_metadata.temperature is not None
# Apply temperature. # Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_temperature(
logits, sampling_metadata.temperature, sampling_metadata.all_random
)
# Apply min_p. # Apply min_p.
if sampling_metadata.min_p is not None: if sampling_metadata.min_p is not None:

View File

@@ -37,6 +37,7 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
next_token_ids = logits.argmax(dim=-1) next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs return next_token_ids, probs
is_greedy = sampling_metadata.temperature == -1 assert sampling_metadata.temperature is not None
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
# Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
# consistent with sampler.py's _SAMPLING_EPS threshold
temperature = sampling_metadata.temperature
# Avoid division by zero if there are greedy requests.
if not sampling_metadata.all_random:
is_greedy = temperature < _SAMPLING_EPS
temperature = torch.where(is_greedy, 1.0, temperature)
logits.div_(temperature.view(-1, 1)) logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)

View File

@@ -215,8 +215,8 @@ class InputBatch:
sampling_params = request.sampling_params sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet" assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY: if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero. # Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = -1.0 self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id) self.greedy_reqs.add(req_id)
else: else:
self.temperature_cpu[req_index] = sampling_params.temperature self.temperature_cpu[req_index] = sampling_params.temperature