[torch.compile] support all attention backends (#10558)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-22 14:04:42 -08:00
committed by GitHub
parent db100c5cde
commit eebad39f26
77 changed files with 876 additions and 648 deletions

View File

@@ -14,8 +14,6 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                   InternLM2MLP, InternLM2Model)
 from vllm.sequence import IntermediateTensors
-from .utils import make_layers, maybe_prefix
 class InternLM2VEDecoderLayer(nn.Module):
@@ -105,17 +103,9 @@ class InternLM2VEDecoderLayer(nn.Module):
 class InternLM2VEModel(InternLM2Model):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: InternLM2VEDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix),
-            prefix=f"{prefix}.layers")
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=InternLM2VEDecoderLayer)
     def forward(
         self,
@@ -159,7 +149,6 @@ class InternLM2VEModel(InternLM2Model):
 class InternLM2VEForCausalLM(InternLM2ForCausalLM):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-        self.model = InternLM2VEModel(vllm_config=vllm_config,
-                                      prefix=maybe_prefix(prefix, "model"))
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         model_type=InternLM2VEModel)