Migrate logits computation and gather to model_runner (#3233)

This commit is contained in:
Roy
2024-03-21 07:25:01 +08:00
committed by GitHub
parent 6e435de766
commit f1c0fc3919
35 changed files with 576 additions and 305 deletions

View File

@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
if not config.tie_word_embeddings:
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,