[V1] Implement vLLM V1 [1/N] (#9289)
This commit is contained in:
@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_metadata: Optional[SamplingMetadata] = None,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
hidden_states = _prune_hidden_states(hidden_states,
|
||||
sampling_metadata)
|
||||
if sampling_metadata is not None:
|
||||
hidden_states = _prune_hidden_states(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
|
||||
logits *= self.scale
|
||||
|
||||
# Apply logits processors (if any).
|
||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||
if sampling_metadata is not None:
|
||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user