[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
wwl2755
2025-05-18 21:49:46 -05:00
committed by GitHub
parent d1211f8794
commit 9da1095daf
4 changed files with 21 additions and 3 deletions

View File

@@ -164,7 +164,14 @@ class Medusa(nn.Module):
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
) -> Optional[list[SamplerOutput]]:
# During preemption, we may receive an empty tensor (batch_size=0)
if previous_hidden_states.size(0) == 0:
# Return None to signal the Top1Proposer that no proposals
# were generated for this batch, allowing it to handle this
# special case appropriately
return None
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),