[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)

This commit is contained in:
youkaichao
2024-10-22 13:43:37 -07:00
committed by GitHub
parent cd5601ac37
commit 17c79f3c36
3 changed files with 65 additions and 19 deletions

View File

@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class Gemma2Model(nn.Module):
def __init__(