[V1] Implement vLLM V1 [1/N] (#9289)

This commit is contained in:
Woosuk Kwon
2024-10-22 01:24:07 -07:00
committed by GitHub
parent 3ddbe25502
commit 6c5af09b39
27 changed files with 3058 additions and 180 deletions

View File

@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
if sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
return logits