Add debug logging to Blackwell attention path

This commit is contained in:
2026-05-19 16:31:55 +00:00
parent d7f686bcfc
commit b8e2cf61ad

View File

@@ -659,6 +659,16 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
self._swa_inv_scale_cache = torch.zeros(
max_slots, 1, dtype=torch.bfloat16, device=kv.device,
)
# Debug: log cache shape info
import sys
print(f"[BLACKWELL] swa_kv_cache shape: {swa_kv_cache.shape}, "
f"block_size: {swa_metadata.block_size}, "
f"num_decode_tokens: {num_decode_tokens}, "
f"num_prefills: {num_prefills}, "
f"compress_ratio: {self.compress_ratio}, "
f"slot_mapping shape: {swa_metadata.slot_mapping.shape}, "
f"positions shape: {positions.shape}, "
f"kv shape: {kv.shape}", file=sys.stderr, flush=True)
blackwell_attention_kv_write(
kv, positions, swa_kv_cache, self._swa_inv_scale_cache,
swa_metadata.slot_mapping, swa_metadata.block_size,
@@ -687,6 +697,8 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# ── Decode attention ──────────────────────────────────────
if num_decode_tokens > 0:
import sys
print(f"[BLACKWELL] DECODE: {num_decode_tokens} tokens, swa_only={swa_only}", file=sys.stderr, flush=True)
if swa_only:
# SWA-only layers: full decode attention with KV cache
q_decode = q[:num_decode_tokens]
@@ -722,6 +734,8 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# ── Prefill attention ─────────────────────────────────────
if num_prefills > 0:
import sys
print(f"[BLACKWELL] PREFILL: {num_prefills} tokens, swa_only={swa_only}", file=sys.stderr, flush=True)
q_prefill = q[num_decode_tokens:]
kv_rope_prefill = self._apply_rope_kv(
kv[num_decode_tokens:], positions[num_decode_tokens:],