[bugfix] fix aria model and add torch.compile (#10645)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -299,10 +303,10 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: LlamaDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
lambda prefix: layer_type(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
|
||||
Reference in New Issue
Block a user