[Chore] Remove Sampler from Model Code (#17084)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -336,7 +335,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         if self.config.tie_word_embeddings:
             self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -363,14 +361,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
                                        sampling_metadata)
         return logits

-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
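For context, a minimal sketch of how generation can be driven once the model no longer owns a sampler: the caller runs the model forward pass, asks it for logits via compute_logits, and then samples with a sampler held outside the model. The run_step helper and the model forward signature (input_ids, positions) are assumptions for illustration only, not code from this commit; the sampler call mirrors the removed sample() method.

# Hypothetical sketch (not code from this commit): with sample() removed from
# the model class, a runner-level component drives sampling itself.
from vllm.model_executor.layers.sampler import get_sampler

def run_step(model, input_ids, positions, sampling_metadata):
    sampler = get_sampler()                       # sampler now lives outside the model
    hidden_states = model(input_ids, positions)   # model forward (signature assumed)
    logits = model.compute_logits(hidden_states, sampling_metadata)
    return sampler(logits, sampling_metadata)     # same call the removed sample() made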
@@ -423,7 +413,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
                          prefix=prefix,
                          model_type=model_type)

-        for attr in ("output", "logits_processor", "sampler"):
+        for attr in ("output", "logits_processor"):
             delattr(self, attr)

         config = vllm_config.model_config.hf_config