[V1] Make v1 more testable (#9888)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -137,7 +137,7 @@ class MLPSpeculator(nn.Module):
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
config.vocab_size, 1.0)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def generate_proposals(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user