Fix Blackwell: skip FlashMLA assertion + force CuTeDSL kernel
1. DeepseekV4MLAAttention.__init__ had a hard assertion that the attention backend MUST be FlashMLA. On Blackwell (SM100+), the FlashMLA kernels don't work, but we bypass them via _attention_impl_blackwell(). Added an _is_blackwell flag to skip the FlashMLA-specific init (the fp8_ds_mla cache format conversion).
2. Added the VLLM_NVFP4_GEMM_BACKEND=cutedsl env var to docker-compose.yml to force CuTeDSL kernel selection for NVFP4 linear layers.
3. Updated register_cutedsl_kernel.py to also register CuTeDSL in the _NVFP4_BACKEND_TO_KERNEL dict (for the env var override path).
This commit is contained in:
@@ -11,6 +11,7 @@ services:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- VLLM_RPC_TIMEOUT_MS=600000
|
||||
- CLAWMINE_DEBUG=1
|
||||
- VLLM_NVFP4_GEMM_BACKEND=cutedsl
|
||||
command:
|
||||
- /model
|
||||
- --trust-remote-code
|
||||
|
||||
@@ -865,9 +865,17 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
|
||||
"Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
|
||||
)
|
||||
# On Blackwell (SM100+), FlashMLA kernels don't work, but we bypass
|
||||
# them entirely in _attention_impl_blackwell(). Skip the FP8 ds_mla
|
||||
# cache conversion since our Blackwell path doesn't use FlashMLA.
|
||||
_is_blackwell = (
|
||||
current_platform.get_device_capability() is not None
|
||||
and current_platform.get_device_capability().major >= 10
|
||||
)
|
||||
# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
|
||||
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
||||
if (
|
||||
# On Blackwell, we use our own attention path, so keep standard fp8
|
||||
if not _is_blackwell and (
|
||||
issubclass(self.get_attn_backend(), FlashMLASparseBackend)
|
||||
and kv_cache_dtype.startswith("fp8")
|
||||
and kv_cache_dtype != "fp8_ds_mla"
|
||||
|
||||
@@ -33,6 +33,11 @@ def patch_init(path):
|
||||
new = " PlatformEnum.CUDA: [\n CuTeDSLNvFp4LinearKernel,\n FlashInferCutlassNvFp4LinearKernel,"
|
||||
content = content.replace(old, new)
|
||||
|
||||
# Also add to _NVFP4_BACKEND_TO_KERNEL so VLLM_NVFP4_GEMM_BACKEND=cutedsl works
|
||||
old_backend = ' "emulation": EmulationNvFp4LinearKernel,\n}'
|
||||
new_backend = ' "emulation": EmulationNvFp4LinearKernel,\n "cutedsl": CuTeDSLNvFp4LinearKernel,\n}'
|
||||
content = content.replace(old_backend, new_backend)
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
print("Patched CuTeDSL NVFP4 kernel into", path)
|
||||
|
||||
Reference in New Issue
Block a user