fix: device mismatches in decode FMHA test — dec_pos must be on per-layer GPU
This commit is contained in:
@@ -353,23 +353,23 @@ def main():
|
||||
q = pl['q_b'](q_a)
|
||||
q = unweighted_rmsnorm(q).bfloat16()
|
||||
q_heads = q.reshape(T, n_h, hd)
|
||||
q_heads = _apply_rope(q_heads, dec_pos, *rope_caches[gpu][:2], rd)
|
||||
q_heads = _apply_rope(q_heads, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
|
||||
|
||||
# 2. KV projection + cache
|
||||
kv = pl['kv'].run_from_quantized(x_quant_attn)
|
||||
kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
kv_3d = _apply_rope(kv_3d, dec_pos, *rope_caches[gpu][:2], rd)
|
||||
kv_3d = _apply_rope(kv_3d, dec_pos.to(dev), *rope_caches[gpu][:2], rd)
|
||||
kv_roped = kv_3d.reshape(T, hd)
|
||||
kc.append_swa(kv_roped, dec_pos)
|
||||
kc.append_swa(kv_roped, dec_pos.to(dev))
|
||||
|
||||
# 3. Compressor → compressed KV
|
||||
compressor = compressors.get(li)
|
||||
indexer = indexers.get(li)
|
||||
comp_pos, block_bias = None, None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos)
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos.to(dev))
|
||||
if comp_kv_fp32 is not None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
@@ -381,13 +381,13 @@ def main():
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos)
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos.to(dev))
|
||||
kc.set_indexer_keys_fp8(comp_idx_kv)
|
||||
|
||||
# 4. Indexer top-k (CSA layers)
|
||||
topk_idx = None
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos, layer_idx=li)
|
||||
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos.to(dev), layer_idx=li)
|
||||
if topk_idx is not None:
|
||||
print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} "
|
||||
f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "
|
||||
@@ -476,7 +476,7 @@ def main():
|
||||
# ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
|
||||
# Apply inverse RoPE to production output
|
||||
attn_out = o_prod.permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_out = _apply_rope(attn_out, dec_pos, *rope_caches[gpu][:2], rd, inverse=True)
|
||||
attn_out = _apply_rope(attn_out, dec_pos.to(dev), *rope_caches[gpu][:2], rd, inverse=True)
|
||||
|
||||
# Output projection
|
||||
wo_a_lin = pl.get('o_a')
|
||||
|
||||
Reference in New Issue
Block a user