[ROCm][CI] Fix test_cudagraph_mode failure in AMD CI (#29367)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
@@ -35,14 +35,22 @@ def temporary_environ(env_vars):
 
 # test attention backend and cudagraph_mode combo
 # (backend_name, cudagraph_mode, supported)
-combo_cases_1 = [
-    ("FA3", "FULL", True),
-    ("FA3", "FULL_AND_PIECEWISE", True),
-    ("FA2", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
-    ("FA2", "FULL_AND_PIECEWISE", True),
-    ("FlashInfer", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
-    ("FlashInfer", "FULL_AND_PIECEWISE", True),
-]
+if current_platform.is_rocm():
+    combo_cases_1 = [
+        ("RocmAttn", "FULL", True),
+        ("RocmAttn", "FULL_AND_PIECEWISE", True),
+        ("TritonAttn", "FULL", True),
+        ("TritonAttn", "FULL_AND_PIECEWISE", True),
+    ]
+else:
+    combo_cases_1 = [
+        ("FA3", "FULL", True),
+        ("FA3", "FULL_AND_PIECEWISE", True),
+        ("FA2", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
+        ("FA2", "FULL_AND_PIECEWISE", True),
+        ("FlashInfer", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
+        ("FlashInfer", "FULL_AND_PIECEWISE", True),
+    ]
 
 
 @pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
@@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
 
 # test cudagraph_mode with different compilation mode.
 # (backend_name, cudagraph_mode, compilation_mode, supported)
-combo_cases_2 = [
-    ("FA2", "FULL", CompilationMode.NONE, True),
-    ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
-    ("FA2", "PIECEWISE", CompilationMode.NONE, False),
-    ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
-    ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
-    ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
-    ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
-    ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
-    ("FA2", "NONE", CompilationMode.NONE, True),
-    ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
-]
+if current_platform.is_rocm():
+    combo_cases_2 = [
+        ("RocmAttn", "FULL", CompilationMode.NONE, True),
+        ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
+        ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
+        ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+        ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+        ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+        ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+        ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+        ("RocmAttn", "NONE", CompilationMode.NONE, True),
+        ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
+    ]
+else:
+    combo_cases_2 = [
+        ("FA2", "FULL", CompilationMode.NONE, True),
+        ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
+        ("FA2", "PIECEWISE", CompilationMode.NONE, False),
+        ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+        ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+        ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+        ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+        ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+        ("FA2", "NONE", CompilationMode.NONE, True),
+        ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
+    ]
 
 
 @pytest.mark.parametrize(
Reference in New Issue
Block a user