[Bugfix] Fix Triton FusedMoE LoRA (#30585)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-01-09 03:46:59 -08:00
committed by GitHub
parent 1a19e9cd87
commit e7b68f4d6c
3 changed files with 51 additions and 35 deletions

View File

@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_gpt_oss_lora(gptoss20b_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
def test_gpt_oss_lora(
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
):
with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
def test_gpt_oss_lora_tp2(
monkeypatch: pytest.MonkeyPatch,
gptoss20b_lora_files,
fully_sharded_loras,
mxfp4_use_marlin,
):
with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)