[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)
This commit is contained in:
@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": 0,
|
||||
"inputs_embeds": 0,
|
||||
"intermediate_tensors": 0,
|
||||
})
|
||||
@support_torch_compile
|
||||
class Gemma2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user