diff --git a/Dockerfile b/Dockerfile index 2306c231..6a5b426f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -57,6 +57,14 @@ ARG VLLM_CORE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/core COPY vllm/patches/patch_kv_cache_utils.py /tmp/patch_kv_cache_utils.py RUN python3 /tmp/patch_kv_cache_utils.py ${VLLM_CORE_DIR}/kv_cache_utils.py && rm /tmp/patch_kv_cache_utils.py +# Patch SWA cache and Indexer cache for Blackwell (no FlashMLA alignment) +ARG VLLM_SPARSE_SWA_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla +ARG VLLM_LAYERS_DIR2=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers +COPY vllm/patches/patch_swa_cache.py /tmp/patch_swa_cache.py +RUN python3 /tmp/patch_swa_cache.py ${VLLM_SPARSE_SWA_DIR}/sparse_swa.py && rm /tmp/patch_swa_cache.py +COPY vllm/patches/patch_indexer_cache.py /tmp/patch_indexer_cache.py +RUN python3 /tmp/patch_indexer_cache.py ${VLLM_LAYERS_DIR2}/deepseek_v4_attention.py && rm /tmp/patch_indexer_cache.py + # Register CuTeDSL kernel in vLLM's linear kernel selection ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py diff --git a/vllm/patches/patch_indexer_cache.py b/vllm/patches/patch_indexer_cache.py new file mode 100644 index 00000000..4c001f28 --- /dev/null +++ b/vllm/patches/patch_indexer_cache.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment. + +Same as patch_swa_cache but for the indexer cache class. 
def patch(path):
    """Patch the indexer cache's ``get_kv_cache_spec`` in the file at *path*.

    The file is read, the stock ``get_kv_cache_spec`` method (which always
    passes ``alignment=576``) is replaced by a version that passes
    ``alignment=None`` when the device capability major version is >= 10,
    and the result is written back in place.

    The operation is idempotent: if the marker
    ``CLAWMINE_PATCH_INDEXER_CACHE`` is already present the file is left
    untouched.

    Args:
        path: Path of the Python source file to patch in place.

    Raises:
        SystemExit: With status 1 (after an error message on stderr) when
            the expected method text is not found in the file.
    """
    # Explicit UTF-8 so the result does not depend on the build image's locale.
    with open(path, encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_PATCH_INDEXER_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the indexer cache's get_kv_cache_spec to remove FlashMLA alignment
    # on Blackwell.  `old` is matched as an exact substring, so it has to agree
    # with the target file byte for byte, indentation included.
    old = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # head_dim already carries the fp8 scale padding
        # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
            # the indexer's compressor state cache. V3.2 keeps the legacy layout.
            alignment=576,
        )"""

    # The device capability is queried once and reused for both the None
    # check and the major-version comparison.
    new = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell
        from vllm.platforms import current_platform
        _capability = current_platform.get_device_capability()
        _is_blackwell = _capability is not None and _capability.major >= 10
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            alignment=None if _is_blackwell else 576,
        )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed with normal output.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content.replace(old, new))
    print("Patched DeepseekV4IndexerCache for Blackwell")
def patch(path):
    """Patch the SWA cache's ``get_kv_cache_spec`` in the file at *path*.

    The file is read, the stock ``get_kv_cache_spec`` method (which always
    passes ``alignment=576`` and ``model_version="deepseek_v4"``) is replaced
    by a version that passes ``None`` for both when the device capability
    major version is >= 10, and the result is written back in place.

    The operation is idempotent: if the marker ``CLAWMINE_PATCH_SWA_CACHE``
    is already present the file is left untouched.

    Args:
        path: Path of the Python source file to patch in place.

    Raises:
        SystemExit: With status 1 (after an error message on stderr) when
            the expected method text is not found in the file.
    """
    # Explicit UTF-8 so the result does not depend on the build image's locale.
    with open(path, encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_PATCH_SWA_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the get_kv_cache_spec method to remove FlashMLA-specific values on
    # Blackwell.  `old` is matched as an exact substring, so it has to agree
    # with the target file byte for byte, indentation included.
    old = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        return SlidingWindowMLASpec(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            sliding_window=self.window_size,
            cache_dtype_str=self.cache_config.cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )"""

    # One constructor call with conditional arguments (instead of two nearly
    # identical calls) keeps the shared arguments from drifting apart, and the
    # device capability is queried only once.
    new = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment
        # Use standard fp8 format (not fp8_ds_mla), no model_version override
        from vllm.platforms import current_platform
        _capability = current_platform.get_device_capability()
        _is_blackwell = _capability is not None and _capability.major >= 10
        return SlidingWindowMLASpec(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            sliding_window=self.window_size,
            cache_dtype_str=self.cache_config.cache_dtype,
            # NOTE: FlashMLA requires 576B alignment; not applied on Blackwell
            alignment=None if _is_blackwell else 576,
            # Don't use 584B deepseek_v4 format on Blackwell
            model_version=None if _is_blackwell else "deepseek_v4",
        )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed with normal output.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content.replace(old, new))
    print("Patched DeepseekV4SWACache for Blackwell")