Add debug logging to Blackwell attention path
This commit is contained in:
@@ -659,6 +659,16 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
self._swa_inv_scale_cache = torch.zeros(
|
||||
max_slots, 1, dtype=torch.bfloat16, device=kv.device,
|
||||
)
|
||||
# Debug: log cache shape info
|
||||
import sys
|
||||
print(f"[BLACKWELL] swa_kv_cache shape: {swa_kv_cache.shape}, "
|
||||
f"block_size: {swa_metadata.block_size}, "
|
||||
f"num_decode_tokens: {num_decode_tokens}, "
|
||||
f"num_prefills: {num_prefills}, "
|
||||
f"compress_ratio: {self.compress_ratio}, "
|
||||
f"slot_mapping shape: {swa_metadata.slot_mapping.shape}, "
|
||||
f"positions shape: {positions.shape}, "
|
||||
f"kv shape: {kv.shape}", file=sys.stderr, flush=True)
|
||||
blackwell_attention_kv_write(
|
||||
kv, positions, swa_kv_cache, self._swa_inv_scale_cache,
|
||||
swa_metadata.slot_mapping, swa_metadata.block_size,
|
||||
@@ -687,6 +697,8 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
# ── Decode attention ──────────────────────────────────────
|
||||
if num_decode_tokens > 0:
|
||||
import sys
|
||||
print(f"[BLACKWELL] DECODE: {num_decode_tokens} tokens, swa_only={swa_only}", file=sys.stderr, flush=True)
|
||||
if swa_only:
|
||||
# SWA-only layers: full decode attention with KV cache
|
||||
q_decode = q[:num_decode_tokens]
|
||||
@@ -722,6 +734,8 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
|
||||
# ── Prefill attention ─────────────────────────────────────
|
||||
if num_prefills > 0:
|
||||
import sys
|
||||
print(f"[BLACKWELL] PREFILL: {num_prefills} tokens, swa_only={swa_only}", file=sys.stderr, flush=True)
|
||||
q_prefill = q[num_decode_tokens:]
|
||||
kv_rope_prefill = self._apply_rope_kv(
|
||||
kv[num_decode_tokens:], positions[num_decode_tokens:],
|
||||
|
||||
Reference in New Issue
Block a user