diff --git a/vllm/v1/worker/gpu/mm/encoder_runner.py b/vllm/v1/worker/gpu/mm/encoder_runner.py index c0676d05d..e62c2ef63 100644 --- a/vllm/v1/worker/gpu/mm/encoder_runner.py +++ b/vllm/v1/worker/gpu/mm/encoder_runner.py @@ -13,12 +13,14 @@ from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs class EncoderRunner: def __init__( self, + model: SupportsMultiModal, max_num_tokens: int, hidden_size: int, encoder_cache: EncoderCache, dtype: torch.dtype, device: torch.device, ): + self.model = model self.max_num_tokens = max_num_tokens self.hidden_size = hidden_size self.encoder_cache = encoder_cache @@ -48,25 +50,17 @@ class EncoderRunner: @torch.inference_mode() def execute_mm_encoder( self, - model: SupportsMultiModal, - mm_hashes: list[str], mm_kwargs: list[tuple[str, MultiModalKwargsItem]], ) -> list[torch.Tensor]: - if not mm_hashes: - return [] - encoder_outputs: list[torch.Tensor] = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs, device=self.device, pin_memory=False ): - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items ) encoder_outputs.extend(curr_group_outputs) - - # Cache the encoder outputs by mm_hash - self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs)) return encoder_outputs def gather_mm_embeddings( @@ -146,12 +140,11 @@ class EncoderRunner: @torch.inference_mode() def get_inputs_embeds( self, - model: SupportsMultiModal, input_ids: torch.Tensor, mm_embeds: list[torch.Tensor], is_mm_embed: torch.Tensor, ) -> torch.Tensor: - x = model.embed_input_ids( + x = self.model.embed_input_ids( input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed ) # Copy to the pre-allocated buffer for CUDA graphs. diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index d52f7d0ec..e27916b40 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -41,7 +41,9 @@ class DefaultModelState(ModelState): if self.supports_mm_inputs: assert encoder_cache is not None + self.encoder_cache = encoder_cache self.encoder_runner = EncoderRunner( + model=self.model, max_num_tokens=self.max_num_tokens, hidden_size=self.inputs_embeds_size, encoder_cache=encoder_cache, @@ -82,7 +84,12 @@ class DefaultModelState(ModelState): mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs( scheduled_encoder_inputs ) - self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs) + if mm_kwargs: + # Execute the multimodal encoder. + encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs) + # Cache the encoder outputs by mm_hash + self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs)) + mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings( input_batch.req_ids, input_batch.num_tokens, @@ -92,7 +99,7 @@ class DefaultModelState(ModelState): req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np], ) inputs_embeds = self.encoder_runner.get_inputs_embeds( - self.model, input_batch.input_ids, mm_embeds, is_mm_embed + input_batch.input_ids, mm_embeds, is_mm_embed ) return inputs_embeds[: input_batch.num_tokens_after_padding]