[torch.compile] generic decorators (#9258)

This commit is contained in:
youkaichao
2024-10-10 15:54:23 -07:00
committed by GitHub
parent a78c6ba7c8
commit e00c094f15
3 changed files with 74 additions and 34 deletions

View File

@@ -21,7 +21,7 @@ from torch import nn
from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
-from vllm.compilation.decorators import support_compile_llama_style
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
-@support_compile_llama_style
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": 0,
+        "inputs_embeds": 0,
+        "intermediate_tensors": 0,
+    })
class Gemma2Model(nn.Module):
def __init__(