[bugfix] fix aria model and add torch.compile (#10645)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
|
|||||||
LlamaModel)
|
LlamaModel)
|
||||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
is_pp_missing_parameter,
|
is_pp_missing_parameter,
|
||||||
make_layers, maybe_prefix,
|
maybe_prefix,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.base import MultiModalInputs
|
from vllm.multimodal.base import MultiModalInputs
|
||||||
@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
# FIXME: this is a hack to disable the compilation of the model
|
|
||||||
self.do_not_compile = True
|
|
||||||
|
|
||||||
self.layers = None
|
|
||||||
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: MoEDecoderLayer(
|
|
||||||
config=config,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
),
|
layer_type=MoEDecoderLayer)
|
||||||
prefix=f"{prefix}.layers",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adapted from LlamaModel.load_weights with the modification of adding
|
# Adapted from LlamaModel.load_weights with the modification of adding
|
||||||
# the expert weights mapping to `stacked_params_mapping`
|
# the expert weights mapping to `stacked_params_mapping`
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class LlamaModel(nn.Module):
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
@@ -299,7 +303,7 @@ class LlamaModel(nn.Module):
|
|||||||
self.embed_tokens = PPMissingLayer()
|
self.embed_tokens = PPMissingLayer()
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: LlamaDecoderLayer(config=config,
|
lambda prefix: layer_type(config=config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix),
|
prefix=prefix),
|
||||||
|
|||||||
Reference in New Issue
Block a user