[V1][Spec Decode] Fix greedy temperature detection after sampler refactor (#27077)
Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
@@ -15,7 +15,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 logger = init_logger(__name__)
 
 PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
-GREEDY_TEMPERATURE: tl.constexpr = -1
+GREEDY_TEMPERATURE: tl.constexpr = 0
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 128
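This sentinel is what the rejection-sampling code compares each request's stored temperature against to pick the deterministic path, so it must match what InputBatch writes (see the last hunk below). A minimal host-side sketch of the assumed comparison, in plain PyTorch rather than the Triton kernel itself; greedy_rows is a hypothetical helper:

import torch

GREEDY_TEMPERATURE = 0  # sentinel written for greedy requests after this change

def greedy_rows(temperature: torch.Tensor) -> torch.Tensor:
    # Rows whose stored temperature equals the sentinel take the argmax path
    # instead of the random-sampling path.
    return temperature == GREEDY_TEMPERATURE

print(greedy_rows(torch.tensor([0.0, 0.7, 1.0])))  # tensor([ True, False, False])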
@@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
     top_p: torch.Tensor = None
 
     all_greedy: bool = True
+    all_random: bool = False
 
     # Whether logprobs are to be gathered in this batch of request. To balance
     # out compile time and runtime, a fixed `max_number_logprobs` value is used
@@ -110,6 +111,7 @@ class TPUSupportedSamplingMetadata:
                 xla_device
             ),
             all_greedy=input_batch.all_greedy,
+            all_random=input_batch.all_random,
             # TODO enable more and avoid returning None values
             top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
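The new batch-level flag complements all_greedy: when all_random is true, no request in the batch is greedy, so downstream code can skip the zero-temperature guard entirely. A rough sketch of the assumed semantics of the two flags, in plain Python rather than the InputBatch implementation:

def batch_flags(temperatures: list[float], eps: float = 1e-5) -> tuple[bool, bool]:
    # Assumed semantics: all_greedy -> every request is greedy;
    # all_random -> no request is greedy; neither -> mixed batch.
    greedy = [t < eps for t in temperatures]
    return all(greedy), not any(greedy)  # (all_greedy, all_random)

print(batch_flags([0.0, 0.0]))  # (True, False)
print(batch_flags([0.8, 1.0]))  # (False, True)
print(batch_flags([0.0, 1.0]))  # (False, False): mixed, guard still needed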
@@ -40,7 +40,11 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
+        all_random: bool = False,
     ) -> torch.Tensor:
+        # Avoid division by zero for greedy sampling (temperature ~ 0.0).
+        if not all_random:
+            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -56,7 +60,9 @@ class Sampler(nn.Module):
         assert sampling_metadata.temperature is not None
 
         # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_temperature(
+            logits, sampling_metadata.temperature, sampling_metadata.all_random
+        )
 
         # Apply min_p.
         if sampling_metadata.min_p is not None:
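To see what the new guard prevents: dividing a logits row by a stored temperature of 0.0 produces inf before softmax. A minimal sketch of the same torch.where pattern, assuming _SAMPLING_EPS = 1e-5 to match the sampler's threshold:

import torch

_SAMPLING_EPS = 1e-5  # assumed to match the sampler's threshold

logits = torch.randn(2, 8)
temp = torch.tensor([0.0, 0.7])  # request 0 is greedy

# Greedy rows divide by 1.0 instead of 0.0; the scale is irrelevant for them
# anyway, since greedy decoding takes the argmax.
safe_temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
scaled = logits / safe_temp.unsqueeze(dim=1)
assert torch.isfinite(scaled).all()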
@@ -37,6 +37,7 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
         next_token_ids = logits.argmax(dim=-1)
         return next_token_ids, probs
 
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+    assert sampling_metadata.temperature is not None
+    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0),
+    # consistent with sampler.py's _SAMPLING_EPS threshold.
+    temperature = sampling_metadata.temperature
+    # Avoid division by zero if there are greedy requests.
+    if not sampling_metadata.all_random:
+        is_greedy = temperature < _SAMPLING_EPS
+        temperature = torch.where(is_greedy, 1.0, temperature)
     logits.div_(temperature.view(-1, 1))
     probs = logits.softmax(dim=-1, dtype=torch.float32)
 
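The regression is easiest to see side by side: after the sampler refactor, greedy requests store temperature 0.0, so the old == -1 test stops matching and the greedy row gets divided by zero. A small repro sketch under that assumption:

import torch

_SAMPLING_EPS = 1e-5
temperature = torch.tensor([0.0, 0.9])  # request 0 is greedy after the refactor

old_is_greedy = temperature == -1            # tensor([False, False]): missed
new_is_greedy = temperature < _SAMPLING_EPS  # tensor([ True, False]): caught

print(torch.where(old_is_greedy, 1.0, temperature))  # tensor([0.0000, 0.9000])
print(torch.where(new_is_greedy, 1.0, temperature))  # tensor([1.0000, 0.9000])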
@@ -215,8 +215,8 @@ class InputBatch:
         sampling_params = request.sampling_params
         assert sampling_params is not None, "pooling requests not supported yet"
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            # Avoid later division by zero.
-            self.temperature_cpu[req_index] = -1.0
+            # Avoid division by zero later in apply_temperature.
+            self.temperature_cpu[req_index] = 0.0
             self.greedy_reqs.add(req_id)
         else:
             self.temperature_cpu[req_index] = sampling_params.temperature
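Taken together, the contract after this change: InputBatch writes 0.0 for greedy requests, and every consumer detects greediness with the epsilon test instead of a -1 sentinel. A condensed, hypothetical end-to-end sketch:

import torch

_SAMPLING_EPS = 1e-5

# What InputBatch now stores per request: 0.0 for greedy, user value otherwise.
stored = torch.tensor([0.0, 0.5, 1.2])

# What every consumer now does before dividing logits by the temperature.
safe = torch.where(stored < _SAMPLING_EPS, 1.0, stored)
print(safe)  # tensor([1.0000, 0.5000, 1.2000]): no zero divisor remains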