[Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836)
This commit is contained in:
@@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalInputs
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SamplerOutput)
|
||||
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
|
||||
@@ -323,7 +324,8 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**multi_modal_kwargs,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
|
||||
Reference in New Issue
Block a user