[Neuron] Support inference with transformers-neuronx (#2569)
This commit is contained in:
@@ -1,10 +1,18 @@
|
||||
"""Utils for model executor."""
|
||||
import random
|
||||
import importlib
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import DeviceConfig, ModelConfig
|
||||
|
||||
# Maps a device type string (DeviceConfig.device_type) to the name of the
# vllm.model_executor submodule whose get_model() knows how to build a model
# for that device.  Used by get_model() below to dispatch dynamically via
# importlib, so device-specific loader modules are only imported when needed.
DEVICE_TO_MODEL_LOADER_MAP = {
    "cuda": "model_loader",
    "neuron": "neuron_model_loader",
}
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
@@ -33,3 +41,12 @@ def set_weight_attrs(
|
||||
assert not hasattr(
|
||||
weight, key), (f"Overwriting existing tensor attribute: {key}")
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> torch.nn.Module:
    """Construct and return the model for the configured device.

    Looks up the device-specific loader submodule in
    DEVICE_TO_MODEL_LOADER_MAP using ``device_config.device_type``
    (e.g. "cuda" -> model_loader, "neuron" -> neuron_model_loader),
    imports it lazily, and delegates to that module's ``get_model``.

    Raises:
        KeyError: if ``device_config.device_type`` has no registered loader.
    """
    loader_name = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
    # Import lazily so only the loader for the active device is pulled in.
    loader = importlib.import_module(f"vllm.model_executor.{loader_name}")
    return loader.get_model(model_config, device_config, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user