[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
This commit is contained in:
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
step.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
disable_bonus_tokens: bool = True,
|
||||
strict_mode: bool = False):
|
||||
def __init__(self, strict_mode: bool = False):
|
||||
"""Base class constructor.
|
||||
Args:
|
||||
disable_bonus_tokens: Whether or not to disable the bonus token.
|
||||
Required when bonus tokens would corrupt the KV cache for
|
||||
proposal methods that require KV cache.
|
||||
strict_mode: Whether or not to perform shape/device/dtype checks
|
||||
during sampling. This catches correctness issues but adds
|
||||
nontrivial latency.
|
||||
"""
|
||||
super().__init__()
|
||||
self._disable_bonus_tokens = disable_bonus_tokens
|
||||
self._strict_mode = strict_mode
|
||||
|
||||
# NOTE: A "bonus token" is accepted iff all proposal tokens are
|
||||
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
|
||||
bonus_token_ids, -1)
|
||||
|
||||
# We disable bonus tokens because they corrupt the KV cache for
|
||||
# proposal methods that require KV cache. We can fix it by "prefilling"
|
||||
# the bonus token in the proposer. The following issue tracks the fix.
|
||||
# https://github.com/vllm-project/vllm/issues/4212
|
||||
if self._disable_bonus_tokens:
|
||||
output_with_bonus_tokens[:, -1] = -1
|
||||
|
||||
# Fill the recovered token ids.
|
||||
output.mul_(~after_false_mask).add_(
|
||||
substitute_token_ids.mul(after_false_mask))
|
||||
|
||||
Reference in New Issue
Block a user