Fix UnboundLocalError: move num_decode_tokens before debug print

This commit is contained in:
2026-05-19 16:43:28 +00:00
parent 76fff5fc8b
commit da6fa2f1d6

View File

@@ -653,13 +653,17 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
rope_dim=self.rope_head_dim,
)
+ # Split prefill and decode
+ num_decode_tokens = swa_metadata.num_decode_tokens
+ num_prefills = swa_metadata.num_prefill_tokens
+ swa_only = self.compress_ratio <= 1
# CRITICAL FIX: Write KV to paged cache (RoPE + fp8 quant + insert)
if not hasattr(self, '_swa_inv_scale_cache'):
max_slots = swa_kv_cache.shape[0] * swa_kv_cache.shape[1]
self._swa_inv_scale_cache = torch.zeros(
max_slots, 1, dtype=torch.bfloat16, device=kv.device,
)
# Debug: log cache shape info
import sys
print(f"[BLACKWELL] swa_kv_cache shape: {swa_kv_cache.shape}, "
f"block_size: {swa_metadata.block_size}, "
@@ -677,11 +681,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
rope_dim=self.rope_head_dim,
)
- # Split prefill and decode
- num_decode_tokens = swa_metadata.num_decode_tokens
- num_prefills = swa_metadata.num_prefill_tokens
- swa_only = self.compress_ratio <= 1
# Get compressed KV cache and indexer metadata for CSA/HCA
flashmla_metadata = None
if not swa_only: