From e1a642452a7a2b183e0629488e488f1f0f5352b2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 08:19:23 +0000 Subject: [PATCH] Fix Blackwell: skip FlashMLA assertion + force CuTeDSL kernel 1. DeepseekV4MLAAttention.__init__ had a hard assertion that the attention backend MUST be FlashMLA. On Blackwell, FlashMLA doesn't work but we bypass it via _attention_impl_blackwell(). Added _is_blackwell flag to skip FlashMLA-specific init (fp8_ds_mla cache format conversion). 2. Added VLLM_NVFP4_GEMM_BACKEND=cutedsl env var to docker-compose.yml to force CuTeDSL kernel selection for NVFP4 linear layers. 3. Updated register_cutedsl_kernel.py to also register CuTeDSL in _NVFP4_BACKEND_TO_KERNEL dict (for the env var override path). --- docker-compose.yml | 1 + vllm/patches/deepseek_v4_attention.py | 10 +++++++++- vllm/patches/register_cutedsl_kernel.py | 5 +++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index b5959a3b..6f624c82 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,7 @@ services: - PYTHONUNBUFFERED=1 - VLLM_RPC_TIMEOUT_MS=600000 - CLAWMINE_DEBUG=1 + - VLLM_NVFP4_GEMM_BACKEND=cutedsl command: - /model - --trust-remote-code diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index 6000197c..501dba63 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -865,9 +865,17 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" ) + # On Blackwell (SM100+), FlashMLA kernels don't work, but we bypass + # them entirely in _attention_impl_blackwell(). Skip the FP8 ds_mla + # cache conversion since our Blackwell path doesn't use FlashMLA. + _is_blackwell = ( + current_platform.get_device_capability() is not None + and current_platform.get_device_capability().major >= 10 + ) # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format # Automatically convert fp8 kv-cache format to "fp8_ds_mla" - if ( + # On Blackwell, we use our own attention path, so keep standard fp8 + if not _is_blackwell and ( issubclass(self.get_attn_backend(), FlashMLASparseBackend) and kv_cache_dtype.startswith("fp8") and kv_cache_dtype != "fp8_ds_mla" diff --git a/vllm/patches/register_cutedsl_kernel.py b/vllm/patches/register_cutedsl_kernel.py index 8951949a..681fc34e 100644 --- a/vllm/patches/register_cutedsl_kernel.py +++ b/vllm/patches/register_cutedsl_kernel.py @@ -33,6 +33,11 @@ def patch_init(path): new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel," content = content.replace(old, new) + # Also add to _NVFP4_BACKEND_TO_KERNEL so VLLM_NVFP4_GEMM_BACKEND=cutedsl works + old_backend = ' "emulation": EmulationNvFp4LinearKernel,\n}' + new_backend = ' "emulation": EmulationNvFp4LinearKernel,\n "cutedsl": CuTeDSLNvFp4LinearKernel,\n}' + content = content.replace(old_backend, new_backend) + with open(path, 'w') as f: f.write(content) print("Patched CuTeDSL NVFP4 kernel into", path)