Migrate logits computation and gather to model_runner (#3233)

This commit is contained in:
Roy
2024-03-21 07:25:01 +08:00
committed by GitHub
parent 6e435de766
commit f1c0fc3919
35 changed files with 576 additions and 305 deletions

View File

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,