[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)
This commit is contained in:
@@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": 0,
|
||||
"inputs_embeds": 0,
|
||||
"intermediate_tensors": 0,
|
||||
})
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user