fix: device mismatches in decode FMHA test — dec_pos must be on per-layer GPU

This commit is contained in:
2026-06-03 04:46:24 +00:00
parent e1d96c509d
commit 693975ec92

View File

@@ -353,23 +353,23 @@ def main():
q = pl['q_b'](q_a)
q = unweighted_rmsnorm(q).bfloat16()
q_heads = q.reshape(T, n_h, hd)
q_heads = _apply_rope(q_heads, dec_pos, *rope_caches[gpu][:2], rd)
q_heads = _apply_rope(q_heads, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
# 2. KV projection + cache
kv = pl['kv'].run_from_quantized(x_quant_attn)
kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight")
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
kv_3d = kv.reshape(T, 1, hd)
kv_3d = _apply_rope(kv_3d, dec_pos, *rope_caches[gpu][:2], rd)
kv_3d = _apply_rope(kv_3d, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
kv_roped = kv_3d.reshape(T, hd)
kc.append_swa(kv_roped, dec_pos)
kc.append_swa(kv_roped, dec_pos.to(dev))
# 3. Compressor → compressed KV
compressor = compressors.get(li)
indexer = indexers.get(li)
comp_pos, block_bias = None, None
if compressor is not None and compressor.ratio > 0:
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos)
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos.to(dev))
if comp_kv_fp32 is not None:
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
@@ -381,13 +381,13 @@ def main():
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos)
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos.to(dev))
kc.set_indexer_keys_fp8(comp_idx_kv)
# 4. Indexer top-k (CSA layers)
topk_idx = None
if indexer is not None and ratio == 4:
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos, layer_idx=li)
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li
if topk_idx is not None:
print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} "
f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "
@@ -476,7 +476,7 @@ def main():
# ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
# Apply inverse RoPE to production output
attn_out = o_prod.permute(1, 0, 2) # (T, n_h, hd)
attn_out = _apply_rope(attn_out, dec_pos, *rope_caches[gpu][:2], rd, inverse=True)
attn_out = _apply_rope(attn_out, dec_pos.to(dev), *rope_caches[gpu][:2], rd, inverse=True)
# Output projection
wo_a_lin = pl.get('o_a')