Migrate logits computation and gather to model_runner (#3233)
This commit is contained in:
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||
@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
hidden_states: Optional[torch.Tensor],
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
sampling_metadata)
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
|
||||
Reference in New Issue
Block a user