[torch.compile] integration with compilation control (#9058)

This commit is contained in:
youkaichao
2024-10-10 12:39:36 -07:00
committed by GitHub
parent 78c0b4166c
commit e4d652ea3e
22 changed files with 404 additions and 98 deletions

View File

@@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
@@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,