[V0 Deprecation] Remove V0 sampling metadata (#25345)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
@@ -9,7 +9,6 @@ from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
                                               LlavaForConditionalGeneration,
                                               LlavaMultiModalProcessor,
                                               LlavaProcessingInfo)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
@@ -18,11 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
                                         dummy_inputs=LlavaDummyInputsBuilder)
 class MyLlava(LlavaForConditionalGeneration):
 
-    def compute_logits(
-            self, hidden_states: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+    def compute_logits(self,
+                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
         # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
+        logits = super().compute_logits(hidden_states)
         if logits is not None:
             logits.zero_()
             logits[:, 0] += 1.0
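With this change the override no longer receives a SamplingMetadata argument: compute_logits is now called with the hidden states only. For context, a minimal sketch of the updated MyLlava override as a standalone snippet; the MULTIMODAL_REGISTRY processor decorator from the test file is omitted, and the trailing return is assumed from the surrounding file rather than shown in the hunk.

from typing import Optional

import torch

from vllm.model_executor.models.llava import LlavaForConditionalGeneration


class MyLlava(LlavaForConditionalGeneration):

    def compute_logits(self,
                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits  # assumed from the surrounding file; not part of the hunk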
@@ -6,16 +6,14 @@ from typing import Optional
 import torch
 
 from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
-    def compute_logits(
-            self, hidden_states: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+    def compute_logits(self,
+                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
         # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
+        logits = super().compute_logits(hidden_states)
         if logits is not None:
             logits.zero_()
             logits[:, 0] += 1.0
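Out-of-tree models that override compute_logits need the same one-line signature change. Below is a hedged migration sketch for plugins that must run against both pre- and post-change vLLM releases; the inspect-based dispatch is illustrative only and is not part of this commit.

import inspect
from typing import Any, Optional

import torch

from vllm.model_executor.models.opt import OPTForCausalLM


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(self, hidden_states: torch.Tensor,
                       *args: Any, **kwargs: Any) -> Optional[torch.Tensor]:
        base = super().compute_logits
        if "sampling_metadata" in inspect.signature(base).parameters:
            # Older vLLM (V0 path): forward the sampling metadata unchanged.
            logits = base(hidden_states, *args, **kwargs)
        else:
            # Newer vLLM (this commit onward): hidden states only.
            logits = base(hidden_states)
        # this dummy model always predicts the first token
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits

Once support for the older releases is dropped, the override collapses to the hidden-states-only form shown in the hunks above.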