[Bugfix] Fix Triton FusedMoE LoRA (#30585)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
|
||||
|
||||
|
||||
def test_gpt_oss_lora(gptoss20b_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
max_lora_rank=8,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||
cudagraph_specialize_lora=False,
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
|
||||
def test_gpt_oss_lora(
|
||||
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
max_lora_rank=8,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||
cudagraph_specialize_lora=False,
|
||||
),
|
||||
)
|
||||
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
|
||||
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=2,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
tensor_parallel_size=2,
|
||||
gpu_memory_utilization=0.8,
|
||||
fully_sharded_loras=fully_sharded_loras,
|
||||
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||
cudagraph_specialize_lora=False,
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
|
||||
def test_gpt_oss_lora_tp2(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
gptoss20b_lora_files,
|
||||
fully_sharded_loras,
|
||||
mxfp4_use_marlin,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=2,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
tensor_parallel_size=2,
|
||||
gpu_memory_utilization=0.8,
|
||||
fully_sharded_loras=fully_sharded_loras,
|
||||
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
|
||||
cudagraph_specialize_lora=False,
|
||||
),
|
||||
)
|
||||
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
|
||||
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
|
||||
|
||||
Reference in New Issue
Block a user