[bugfix] fix aria model and add torch.compile (#10645)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-25 18:32:09 -08:00
committed by GitHub
parent 6e9ff050c8
commit 45ac4ff270
2 changed files with 14 additions and 28 deletions

View File

@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -299,10 +303,10 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank: