[Chore] Remove Sampler from Model Code (#17084)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -58,7 +58,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -298,7 +297,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size, logit_scale)
|
||||
self.sampler = get_sampler()
|
||||
self.media_placeholder: int = self.config.media_placeholder_token_id
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@@ -409,7 +407,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner from
|
||||
@@ -447,14 +445,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
sampling_metadata, **kwargs)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
config = self.config.text_config
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
|
||||
Reference in New Issue
Block a user