diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 70c622fc0..f1b95df00 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -57,7 +57,7 @@ def init_attn_backend( ) attn_metadata_builders.append(attn_metadata_builder) # type: ignore - if "FLASHINFER" in attn_backend.get_name(): + if attn_backend.get_name() == "FLASHINFER": if flashinfer_workspace is None: flashinfer_workspace = attn_metadata_builder._get_workspace_buffer() else: diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 931780bbd..898d64879 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -248,7 +248,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # TODO(woosuk): Support other backends. - supported_backends = ("FLASH_ATTN", "FLASHINFER") + supported_backends = ("FLASH_ATTN", "FLASHINFER", "FLASHINFER_MLA") for backend in self.attn_backends.values(): backend_name = backend.get_name() if backend_name not in supported_backends: