[Model Runner V2] Minor refactoring for EncoderRunner (#35628)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -13,12 +13,14 @@ from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
encoder_cache: EncoderCache,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.model = model
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.encoder_cache = encoder_cache
|
||||
@@ -48,25 +50,17 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
@@ -146,12 +140,11 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
x = self.model.embed_input_ids(
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
|
||||
@@ -41,7 +41,9 @@ class DefaultModelState(ModelState):
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
assert encoder_cache is not None
|
||||
self.encoder_cache = encoder_cache
|
||||
self.encoder_runner = EncoderRunner(
|
||||
model=self.model,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.inputs_embeds_size,
|
||||
encoder_cache=encoder_cache,
|
||||
@@ -82,7 +84,12 @@ class DefaultModelState(ModelState):
|
||||
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
|
||||
scheduled_encoder_inputs
|
||||
)
|
||||
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
|
||||
if mm_kwargs:
|
||||
# Execute the multimodal encoder.
|
||||
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
|
||||
|
||||
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
@@ -92,7 +99,7 @@ class DefaultModelState(ModelState):
|
||||
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
|
||||
)
|
||||
inputs_embeds = self.encoder_runner.get_inputs_embeds(
|
||||
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
|
||||
input_batch.input_ids, mm_embeds, is_mm_embed
|
||||
)
|
||||
return inputs_embeds[: input_batch.num_tokens_after_padding]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user