[torch.compile] rework test plans (#9866)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-10-31 22:20:17 -07:00
committed by GitHub
parent 37a4947dcd
commit 566cd27797
4 changed files with 226 additions and 31 deletions

View File

@@ -493,13 +493,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
:class:`LlavaImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
@@ -511,7 +507,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,