diff --git a/tests/unit/test_decode_fmha_layer.py b/tests/unit/test_decode_fmha_layer.py index 5e43c3bd..7ba9abc0 100644 --- a/tests/unit/test_decode_fmha_layer.py +++ b/tests/unit/test_decode_fmha_layer.py @@ -353,23 +353,23 @@ def main(): q = pl['q_b'](q_a) q = unweighted_rmsnorm(q).bfloat16() q_heads = q.reshape(T, n_h, hd) - q_heads = _apply_rope(q_heads, dec_pos, *rope_caches[gpu][:2], rd) + q_heads = _apply_rope(q_heads, dec_pos.to(dev), *rope_caches[gpu][:2], rd) # 2. KV projection + cache kv = pl['kv'].run_from_quantized(x_quant_attn) kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight") if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) kv_3d = kv.reshape(T, 1, hd) - kv_3d = _apply_rope(kv_3d, dec_pos, *rope_caches[gpu][:2], rd) + kv_3d = _apply_rope(kv_3d, dec_pos.to(dev), *rope_caches[gpu][:2], rd) kv_roped = kv_3d.reshape(T, hd) - kc.append_swa(kv_roped, dec_pos) + kc.append_swa(kv_roped, dec_pos.to(dev)) # 3. Compressor → compressed KV compressor = compressors.get(li) indexer = indexers.get(li) comp_pos, block_bias = None, None if compressor is not None and compressor.ratio > 0: - comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos) + comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos.to(dev)) if comp_kv_fp32 is not None: from dsv4.kernels.cuda.loader import get_cuda_module kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) @@ -381,13 +381,13 @@ def main(): nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos) if compressor.is_csa and indexer is not None and indexer.compressor is not None: - comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos) + comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos.to(dev)) kc.set_indexer_keys_fp8(comp_idx_kv) # 4. 
Indexer top-k (CSA layers) topk_idx = None if indexer is not None and ratio == 4: - topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos, layer_idx=li) + topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li) if topk_idx is not None: print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} " f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] " @@ -476,7 +476,7 @@ def main(): # ---- Continue through the rest of the layer (so subsequent layers get correct X) ---- # Apply inverse RoPE to production output attn_out = o_prod.permute(1, 0, 2) # (T, n_h, hd) - attn_out = _apply_rope(attn_out, dec_pos, *rope_caches[gpu][:2], rd, inverse=True) + attn_out = _apply_rope(attn_out, dec_pos.to(dev), *rope_caches[gpu][:2], rd, inverse=True) # Output projection wo_a_lin = pl.get('o_a')