[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)

This commit is contained in:
Lily Liu
2024-09-22 12:34:14 -07:00
committed by GitHub
parent 5b59532760
commit c6bd70d772
7 changed files with 33 additions and 115 deletions

View File

@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
def __init__(self, strict_mode: bool = False):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))