Files
nvfp4-megamoe-kernel/vllm/patches/patch_kv_cache_utils.py

136 lines
6.1 KiB
Python

#!/usr/bin/env python3
"""Patch vLLM kv_cache_utils.py to handle DeepseekV4 SWA page sizes.
DeepseekV4 has three cache types:
- C128A (HCA): compress_ratio=128, very small page size
- C4A (CSA): compress_ratio=4, medium page size
- SWA: compress_ratio=1, large page size
The upstream code assumes SWA page sizes <= MLA page sizes and pads
SWA pages to match MLA. This breaks when SWA pages are LARGER than
MLA pages (which is always the case for DeepseekV4).
Our fix: when SWA pages exceed MLA pages, put them in their own
separate cache group without padding.
"""
import sys
def patch(path) -> None:
    """Patch the SWA/MLA grouping block of a ``kv_cache_utils.py`` file in place.

    The file at ``path`` is read, every occurrence of the upstream block held
    in ``old`` is swapped for the replacement held in ``new``, and the result
    is written back to the same path. The replacement text carries the marker
    ``CLAWMINE_PATCH_KV_CACHE``; when that marker is already present the file
    is left untouched, so running the script twice is harmless.

    Args:
        path: Path of the file to patch.

    Raises:
        SystemExit: With status 1 when the expected upstream block is not
            found in the file (nothing is written in that case).
        OSError: If the file cannot be read or written.
    """
    marker = "CLAWMINE_PATCH_KV_CACHE"

    # Read and write explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(path, 'r', encoding="utf-8") as f:
        content = f.read()

    # Idempotency guard: the replacement text contains the marker.
    if marker in content:
        print("Already patched, skipping")
        return

    # NOTE(review): matching below is an exact substring test, so the
    # whitespace inside `old` must mirror the target file character for
    # character -- confirm the indentation of these literals against the
    # vLLM version being patched.
    # The old code: asserts SWA pages <= MLA pages, then pads SWA to MLA
    old = """ assert max(sm_page_sizes) <= max(all_page_sizes)
# Unify page size by padding layers' page_size to the nearest larger page_size.
# Compute candidate (nearest larger page_size) for each unique page size.
size_to_candidate: dict[int, int] = {}
for ps in sm_page_sizes:
size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
# Pad and collect layer names per page size.
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
current_size = layer_spec.page_size_bytes
candidate = size_to_candidate[current_size]
if current_size < candidate:
object.__setattr__(layer_spec, "page_size_padded", candidate)
layers_per_size[candidate].append(layer_name)
# NOTE(yifan): for now, inside a UniformKV group, each page_size should
# have the same number of layers. This also means we don't need to pad layers
# inside a partial-full layer tuple.
assert len(set(len(layers) for layers in layers_per_size.values())) == 1
num_layers_per_size = len(next(iter(layers_per_size.values())))
# Split layers inside each UniformKV group for aligned #(layers).
# See `_get_kv_cache_groups_uniform_page_size` for more details.
num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
layer_tuples = list(zip(*layers_per_size.values()))
for i in range(num_tuple_groups):
group_layer_tuples = layer_tuples[i::num_tuple_groups]
# Flatten tuples and build dict for from_specs
group_layer_names = [
name for layer_tuple in group_layer_tuples for name in layer_tuple
]
group_layer_specs = {
name: sm_spec.kv_cache_specs[name] for name in group_layer_names
}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
assert sub_sm_spec is not None
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=group_layer_names,
kv_cache_spec=sub_sm_spec,
)
)"""
    # The new code: handle both cases
    new = """ # CLAWMINE_PATCH_KV_CACHE: Handle DeepseekV4 where SWA page sizes
# can be larger than MLA page sizes. Two cases:
# 1. All SWA pages <= some MLA page: original padding logic
# 2. Some SWA pages > all MLA pages: separate cache group, no padding
max_mla_page = max(all_page_sizes)
can_pad = max(sm_page_sizes) <= max_mla_page
if can_pad:
# Original logic: pad SWA pages to nearest MLA page
size_to_candidate: dict[int, int] = {}
for ps in sm_page_sizes:
size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
current_size = layer_spec.page_size_bytes
candidate = size_to_candidate[current_size]
if current_size < candidate:
object.__setattr__(layer_spec, "page_size_padded", candidate)
layers_per_size[candidate].append(layer_name)
assert len(set(len(layers) for layers in layers_per_size.values())) == 1
num_layers_per_size = len(next(iter(layers_per_size.values())))
num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
layer_tuples = list(zip(*layers_per_size.values()))
for i in range(num_tuple_groups):
group_layer_tuples = layer_tuples[i::num_tuple_groups]
group_layer_names = [
name for layer_tuple in group_layer_tuples for name in layer_tuple
]
group_layer_specs = {
name: sm_spec.kv_cache_specs[name] for name in group_layer_names
}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
assert sub_sm_spec is not None
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=group_layer_names,
kv_cache_spec=sub_sm_spec,
)
)
else:
# SWA pages are larger than MLA pages.
# Put each SWA layer in its own cache group (no padding needed).
# This is the DeepseekV4 Blackwell case where compress_ratio=1
# layers have much larger pages than compressed layers.
for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
group_layer_specs = {layer_name: layer_spec}
sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
if sub_sm_spec is not None:
swa_mla_groups.append(
KVCacheGroupSpec(
layer_names=[layer_name],
kv_cache_spec=sub_sm_spec,
)
)"""

    # Refuse to touch the file when the upstream block is absent (for example
    # a different version of the target); report on stderr and exit with 1.
    if old not in content:
        print("ERROR: Could not find the code to patch", file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)
    with open(path, 'w', encoding="utf-8") as f:
        f.write(content)
    print("Patched kv_cache_utils.py for DeepseekV4 SWA page sizes")
if __name__ == "__main__":
    # Script entry point: the first command-line argument is the path of the
    # file to patch. Without it, print a usage line to stderr and exit with
    # status 2 instead of failing with an IndexError traceback.
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <path-to-kv_cache_utils.py>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])