[MRV2] Fix for DS v3.2 (#38030)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-24 14:03:24 -07:00
committed by GitHub
parent 4e824d1c83
commit 4b53740d7f

View File

@@ -115,9 +115,12 @@ def _reshape_kv_cache(
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name]
assert isinstance(kv_cache_spec, AttentionSpec)
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes