[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)

This commit is contained in:
youkaichao
2024-10-22 13:43:37 -07:00
committed by GitHub
parent cd5601ac37
commit 17c79f3c36
3 changed files with 65 additions and 19 deletions

View File

@@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(