[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
@@ -146,6 +146,17 @@ class EAGLE(nn.Module):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
|
||||
# Handle both empty previous_hidden_states
|
||||
# and mismatched batch size
|
||||
batch_size = inputs_embeds.size(0)
|
||||
if previous_hidden_states.size(0) == 0 or \
|
||||
previous_hidden_states.size(0) != batch_size:
|
||||
hidden_dim = self.config.model.hidden_size
|
||||
device = inputs_embeds.device
|
||||
# Create zero tensor with matching batch size
|
||||
previous_hidden_states = \
|
||||
torch.zeros(batch_size, hidden_dim, device=device)
|
||||
|
||||
if self.add_para_norm:
|
||||
inputs_embeds = torch.cat([
|
||||
self.enorm(inputs_embeds),
|
||||
|
||||
@@ -164,7 +164,14 @@ class Medusa(nn.Module):
|
||||
self,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> list[SamplerOutput]:
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
# During preemption, we may receive an empty tensor (batch_size=0)
|
||||
if previous_hidden_states.size(0) == 0:
|
||||
# Return None to signal the Top1Proposer that no proposals
|
||||
# were generated for this batch, allowing it to handle this
|
||||
# special case appropriately
|
||||
return None
|
||||
|
||||
return self.sample(
|
||||
logits=self.compute_logits(
|
||||
hidden_states=self.forward(previous_hidden_states),
|
||||
|
||||
Reference in New Issue
Block a user