Patch SWA and Indexer cache specs for Blackwell (no FlashMLA alignment)
This commit is contained in:
@@ -57,6 +57,14 @@ ARG VLLM_CORE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/core
|
||||
COPY vllm/patches/patch_kv_cache_utils.py /tmp/patch_kv_cache_utils.py
|
||||
RUN python3 /tmp/patch_kv_cache_utils.py ${VLLM_CORE_DIR}/kv_cache_utils.py && rm /tmp/patch_kv_cache_utils.py
|
||||
|
||||
# Patch the SWA cache and indexer cache KV-cache specs for Blackwell (SM100+): drop FlashMLA's 576B alignment
|
||||
ARG VLLM_SPARSE_SWA_DIR=/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla
|
||||
ARG VLLM_LAYERS_DIR2=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers
|
||||
COPY vllm/patches/patch_swa_cache.py /tmp/patch_swa_cache.py
|
||||
RUN python3 /tmp/patch_swa_cache.py ${VLLM_SPARSE_SWA_DIR}/sparse_swa.py && rm /tmp/patch_swa_cache.py
|
||||
COPY vllm/patches/patch_indexer_cache.py /tmp/patch_indexer_cache.py
|
||||
RUN python3 /tmp/patch_indexer_cache.py ${VLLM_LAYERS_DIR2}/deepseek_v4_attention.py && rm /tmp/patch_indexer_cache.py
|
||||
|
||||
# Register CuTeDSL kernel in vLLM's linear kernel selection
|
||||
ARG VLLM_LINEAR_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear
|
||||
COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py
|
||||
|
||||
58
vllm/patches/patch_indexer_cache.py
Normal file
58
vllm/patches/patch_indexer_cache.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment.
|
||||
|
||||
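Same approach as patch_swa_cache.py, applied to the indexer cache class's
get_kv_cache_spec; only the alignment argument changes here (this spec has no
model_version argument).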
|
||||
"""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Patch the indexer cache's ``get_kv_cache_spec`` in the file at *path*.

    The method is rewritten in place so that, at runtime, it passes
    ``alignment=None`` when the device capability major version is >= 10 and
    keeps ``alignment=576`` otherwise.

    The patch is idempotent: if the ``CLAWMINE_PATCH_INDEXER_CACHE`` marker is
    already present in the file, nothing is written.

    Args:
        path: Path of the Python source file to patch.

    Raises:
        SystemExit: With exit code 1 when the expected original method text
            is not found in the file (the file is left untouched).
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # The marker is embedded in the replacement text, so its presence means a
    # previous run already rewrote the method.
    if "CLAWMINE_PATCH_INDEXER_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the indexer cache's get_kv_cache_spec to remove FlashMLA alignment on Blackwell
    # NOTE(review): `old` must match the target file byte-for-byte, including
    # indentation (assumed here: method at 4 spaces, body at 8, arguments at
    # 12) -- confirm against the installed vLLM source.
    old = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # head_dim already carries the fp8 scale padding
        # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
            # the indexer's compressor state cache. V3.2 keeps the legacy layout.
            alignment=576,
        )"""

    new = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell
        from vllm.platforms import current_platform
        _is_blackwell = (
            current_platform.get_device_capability() is not None
            and current_platform.get_device_capability().major >= 10
        )
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            alignment=None if _is_blackwell else 576,
        )"""

    if old not in content:
        # Fail loudly (and on stderr) so a changed upstream file breaks the
        # build instead of silently shipping an unpatched module.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched DeepseekV4IndexerCache for Blackwell")
|
||||
|
||||
if __name__ == "__main__":
    # Expect exactly one argument: the path of the file to patch. Report a
    # usage error instead of failing with an IndexError traceback.
    if len(sys.argv) != 2:
        print("Usage: " + sys.argv[0] + " <path-to-file-to-patch>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
71
vllm/patches/patch_swa_cache.py
Normal file
71
vllm/patches/patch_swa_cache.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch DeepseekV4SWACache on Blackwell: remove FlashMLA alignment and model_version.
|
||||
|
||||
On Blackwell (SM100+), FlashMLA doesn't work. We use our own CSA/SDPA attention.
|
||||
The SWA cache should use standard fp8 format (not fp8_ds_mla) and no FlashMLA alignment.
|
||||
"""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Patch the SWA cache's ``get_kv_cache_spec`` in the file at *path*.

    The method is rewritten in place so that, at runtime, it returns a
    ``SlidingWindowMLASpec`` with ``alignment=None`` and ``model_version=None``
    when the device capability major version is >= 10, and the original
    spec (``alignment=576``, ``model_version="deepseek_v4"``) otherwise.

    The patch is idempotent: if the ``CLAWMINE_PATCH_SWA_CACHE`` marker is
    already present in the file, nothing is written.

    Args:
        path: Path of the Python source file to patch.

    Raises:
        SystemExit: With exit code 1 when the expected original method text
            is not found in the file (the file is left untouched).
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # The marker is embedded in the replacement text, so its presence means a
    # previous run already rewrote the method.
    if "CLAWMINE_PATCH_SWA_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the get_kv_cache_spec method to remove FlashMLA-specific values on Blackwell
    # NOTE(review): `old` must match the target file byte-for-byte, including
    # indentation (assumed here: method at 4 spaces, body at 8, arguments at
    # 12) and the two spaces before the inline comment -- confirm against the
    # installed vLLM source.
    old = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        return SlidingWindowMLASpec(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            sliding_window=self.window_size,
            cache_dtype_str=self.cache_config.cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )"""

    new = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment
        # Use standard fp8 format (not fp8_ds_mla), no model_version override
        from vllm.platforms import current_platform
        _is_blackwell = (
            current_platform.get_device_capability() is not None
            and current_platform.get_device_capability().major >= 10
        )
        if _is_blackwell:
            return SlidingWindowMLASpec(
                block_size=self.block_size,
                num_kv_heads=1,
                head_size=self.head_dim,
                dtype=self.dtype,
                sliding_window=self.window_size,
                cache_dtype_str=self.cache_config.cache_dtype,
                alignment=None,  # No FlashMLA alignment on Blackwell
                model_version=None,  # Don't use 584B deepseek_v4 format
            )
        return SlidingWindowMLASpec(
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            sliding_window=self.window_size,
            cache_dtype_str=self.cache_config.cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )"""

    if old not in content:
        # Fail loudly (and on stderr) so a changed upstream file breaks the
        # build instead of silently shipping an unpatched module.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched DeepseekV4SWACache for Blackwell")
|
||||
|
||||
if __name__ == "__main__":
    # Expect exactly one argument: the path of the file to patch. Report a
    # usage error instead of failing with an IndexError traceback.
    if len(sys.argv) != 2:
        print("Usage: " + sys.argv[0] + " <path-to-file-to-patch>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
Reference in New Issue
Block a user