[Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836)

This commit is contained in:
Cyrus Leung
2024-07-31 10:38:45 +08:00
committed by GitHub
parent da1f7cc12a
commit f230cc2ca6
16 changed files with 254 additions and 211 deletions

View File

@@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 ModelConfig, MultiModalConfig, ParallelConfig,
 PromptAdapterConfig, SchedulerConfig)
 from vllm.logger import init_logger
+from vllm.multimodal import MultiModalInputs
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
 SamplerOutput)
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
@@ -323,7 +324,8 @@ class TP1DraftModelRunner(ModelRunner):
 kv_caches=kv_caches,
 attn_metadata=model_input.attn_metadata,
 intermediate_tensors=intermediate_tensors,
-**multi_modal_kwargs,
+**MultiModalInputs.as_kwargs(multi_modal_kwargs,
+device=self.device),
 )
 # Compute the logits.