[Spec Decode] fix returning size mismatch on extract hidden states proposer (#38610)

Signed-off-by: Jaebok Lee <jaebok9541@naver.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
zzaebok
2026-04-10 04:39:39 +08:00
committed by GitHub
parent adaabb8a55
commit edee96519a

View File

@@ -145,7 +145,10 @@ class ExtractHiddenStatesProposer:
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids
# On decode steps with spec tokens, sampled_token_ids may have
# shape [batch_size, 2] (target + spec verification); slice to
# return only the target-sampled column.
return sampled_token_ids[:, :1]
def _get_slot_mapping(
self,