[Neuron] Support inference with transformers-neuronx (#2569)

This commit is contained in:
Liangfu Chen
2024-02-28 09:34:34 -08:00
committed by GitHub
parent e46fa5d52e
commit 3b7178cfa4
18 changed files with 516 additions and 42 deletions

View File

@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
get_model_old = get_model
-    def get_model_patched(model_config, device_config, lora_config=None):
-        return get_model_old(model_config, device_config,
-                             LoRAConfig(max_loras=4, max_lora_rank=8))
+    def get_model_patched(model_config, device_config, **kwargs):
+        return get_model_old(model_config,
+                             device_config,
+                             lora_config=LoRAConfig(max_loras=4,
+                                                    max_lora_rank=8))
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)