[Neuron] Support inference with transformers-neuronx (#2569)
This commit is contained in:
@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
|
||||
cleanup()
|
||||
get_model_old = get_model
|
||||
|
||||
def get_model_patched(model_config, device_config, lora_config=None):
    """Stand-in for ``get_model`` that discards the caller's ``lora_config``.

    Always delegates to the original ``get_model`` (captured as
    ``get_model_old``) with a fixed LoRA configuration of 4 slots at rank 8,
    so the test engine is built with deterministic LoRA settings.
    """
    # NOTE(review): the LoRAConfig is passed positionally here, occupying the
    # third positional parameter of the original get_model — presumably its
    # lora_config slot; confirm against the patched function's signature.
    forced_lora = LoRAConfig(max_loras=4, max_lora_rank=8)
    return get_model_old(model_config, device_config, forced_lora)
|
||||
def get_model_patched(model_config, device_config, **kwargs):
    """Stand-in for ``get_model`` that overrides any LoRA configuration.

    Accepts (and ignores) arbitrary keyword arguments so it stays compatible
    with whatever the patched call site passes, then delegates to the original
    ``get_model`` (captured as ``get_model_old``) with a fixed LoRA config of
    4 slots at rank 8 as an explicit keyword argument.
    """
    # kwargs are intentionally dropped: the test pins its own lora_config.
    forced_lora = LoRAConfig(max_loras=4, max_lora_rank=8)
    return get_model_old(model_config,
                         device_config,
                         lora_config=forced_lora)
|
||||
|
||||
with patch("vllm.worker.model_runner.get_model", get_model_patched):
|
||||
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
|
||||
|
||||
Reference in New Issue
Block a user