[Model] Remove model sampler (#21059)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -30,7 +30,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -161,8 +160,6 @@ class MiMoMTP(nn.Module):
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -187,14 +184,6 @@ class MiMoMTP(nn.Module):
|
||||
return self.model.compute_logits(hidden_states, self.lm_head,
|
||||
sampling_metadata, spec_step_idx)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user