diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index f021df56c..48582f4f6 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -293,6 +293,48 @@ def test_invalid_backend():
         AttentionConfig(backend=AttentionBackendEnum["INVALID"])
 
 
+@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
+def test_auto_backend_string(auto_value: str):
+    """Test that the 'auto' string value triggers automatic backend selection."""
+    # Using "auto" should result in backend=None (automatic selection)
+    attention_config = AttentionConfig(backend=auto_value)
+    assert attention_config.backend is None
+
+
+def test_auto_backend_selection_behavior():
+    """Test that the 'auto' backend behaves the same as None (automatic selection)."""
+    # Create config with explicit "auto"
+    auto_config = AttentionConfig(backend="auto")
+
+    # Create config with None (default)
+    none_config = AttentionConfig(backend=None)
+
+    # Both should have backend=None
+    assert auto_config.backend is None
+    assert none_config.backend is None
+
+    # Both configs should result in the same automatic backend selection
+    vllm_config_auto = VllmConfig(attention_config=auto_config)
+    vllm_config_none = VllmConfig(attention_config=none_config)
+
+    with (
+        set_current_vllm_config(vllm_config_auto),
+        patch("vllm.platforms.current_platform", CpuPlatform()),
+    ):
+        backend_auto = get_attn_backend(16, torch.float16, None, 16)
+
+    _cached_get_attn_backend.cache_clear()
+
+    with (
+        set_current_vllm_config(vllm_config_none),
+        patch("vllm.platforms.current_platform", CpuPlatform()),
+    ):
+        backend_none = get_attn_backend(16, torch.float16, None, 16)
+
+    # Both should select the same backend
+    assert backend_auto.get_name() == backend_none.get_name()
+
+
 @pytest.mark.parametrize(
     "backend_name,flash_attn_version,should_succeed",
     [
diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index 74bb3d68f..e05544f08 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -14,7 +14,7 @@ class AttentionConfig:
     """Configuration for attention mechanisms in vLLM."""
 
     backend: AttentionBackendEnum | None = None
-    """Attention backend to use. If None, will be selected automatically."""
+    """Attention backend to use. Use "auto" or None for automatic selection."""
 
     flash_attn_version: Literal[2, 3, 4] | None = None
     """Force vllm to use a specific flash-attention version (2, 3, or 4).
@@ -63,7 +63,13 @@ class AttentionConfig:
     @field_validator("backend", mode="before")
     @classmethod
     def validate_backend_before(cls, value: Any) -> Any:
-        """Enable parsing of the `backend` enum type from string."""
+        """Enable parsing of the `backend` enum type from string.
+
+        The special value "auto" is treated as None, which triggers
+        automatic backend selection.
+        """
         if isinstance(value, str):
+            if value.lower() == "auto":
+                return None
             return AttentionBackendEnum[value.upper()]
         return value
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6d74e867b..93384fd78 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1816,13 +1816,10 @@ class EngineArgs:
                     "attention_backend and attention_config.backend "
                     "are mutually exclusive"
                 )
-            # Convert string to enum if needed (CLI parsing returns a string)
-            if isinstance(self.attention_backend, str):
-                attention_config.backend = AttentionBackendEnum[
-                    self.attention_backend.upper()
-                ]
-            else:
-                attention_config.backend = self.attention_backend
+            # Reuse the validator to handle "auto" and string-to-enum conversion
+            attention_config.backend = AttentionConfig.validate_backend_before(
+                self.attention_backend
+            )
 
         # Kernel config overrides
         kernel_config = copy.deepcopy(self.kernel_config)