Fix Blackwell: skip FlashMLA assertion + force CuTeDSL kernel

1. DeepseekV4MLAAttention.__init__ had a hard assertion that the
   attention backend MUST be FlashMLA. On Blackwell, FlashMLA doesn't
   work but we bypass it via _attention_impl_blackwell(). Added
   _is_blackwell flag to skip FlashMLA-specific init (fp8_ds_mla
   cache format conversion).

2. Added VLLM_NVFP4_GEMM_BACKEND=cutedsl env var to docker-compose.yml
   to force CuTeDSL kernel selection for NVFP4 linear layers.

3. Updated register_cutedsl_kernel.py to also register CuTeDSL in
   _NVFP4_BACKEND_TO_KERNEL dict (for the env var override path).
This commit is contained in:
2026-05-19 08:19:23 +00:00
parent 2856323360
commit e1a642452a
3 changed files with 15 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ services:
- PYTHONUNBUFFERED=1
- VLLM_RPC_TIMEOUT_MS=600000
- CLAWMINE_DEBUG=1
- VLLM_NVFP4_GEMM_BACKEND=cutedsl
command:
- /model
- --trust-remote-code

View File

@@ -865,9 +865,17 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
)
# On Blackwell (SM100+), FlashMLA kernels don't work, but we bypass
# them entirely in _attention_impl_blackwell(). Skip the FP8 ds_mla
# cache conversion since our Blackwell path doesn't use FlashMLA.
_is_blackwell = (
current_platform.get_device_capability() is not None
and current_platform.get_device_capability().major >= 10
)
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if (
# On Blackwell, we use our own attention path, so keep standard fp8
if not _is_blackwell and (
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"

View File

@@ -33,6 +33,11 @@ def patch_init(path):
new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel,"
content = content.replace(old, new)
# Also add to _NVFP4_BACKEND_TO_KERNEL so VLLM_NVFP4_GEMM_BACKEND=cutedsl works
old_backend = ' "emulation": EmulationNvFp4LinearKernel,\n}'
new_backend = ' "emulation": EmulationNvFp4LinearKernel,\n "cutedsl": CuTeDSLNvFp4LinearKernel,\n}'
content = content.replace(old_backend, new_backend)
with open(path, 'w') as f:
f.write(content)
print("Patched CuTeDSL NVFP4 kernel into", path)