[New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@meta.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Lucia Fang <fanglu@meta.com> Co-authored-by: NickLucche <nlucches@redhat.com> Co-authored-by: Siyuan Fu <siyuanf@nvidia.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Xiaozhu Meng <mxz297@gmail.com> Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
@@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_ds_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if dtype.itemsize != 2:
|
||||
pytest.skip("ds_mla only supports 16-bit input")
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
|
||||
|
||||
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device)
|
||||
|
||||
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
|
||||
tile_data = torch.zeros(128, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
ref_cache_16bit = ref_cache_slice.view(dtype)
|
||||
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
||||
|
||||
kv_c_data = kv_c[i]
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = (tile_idx + 1) * 128
|
||||
tile_data[:] = kv_c_data[tile_start:tile_end]
|
||||
|
||||
# tile_scale = tile_data.amax().to(torch.float32) / 448.
|
||||
# NOTE: Using torch's amax() gives different results,
|
||||
# so this must be manually computed.
|
||||
tile_data_float = tile_data.to(torch.float32)
|
||||
manual_max = abs(tile_data_float[0])
|
||||
for j in range(1, 128):
|
||||
manual_max = max(manual_max, abs(tile_data_float[j]))
|
||||
tile_scale = manual_max / 448.
|
||||
|
||||
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
|
||||
|
||||
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
|
||||
tile_data,
|
||||
tile_scale.item(),
|
||||
kv_dtype="fp8")
|
||||
|
||||
for j in range(qk_rope_head_dim):
|
||||
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla,
|
||||
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
kv_cache_slice = kv_cache[block_idx, block_offset]
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
|
||||
kv_nope = kv_cache_slice[:kv_lora_rank]
|
||||
ref_nope = ref_cache_slice[:kv_lora_rank]
|
||||
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
ref_scales = ref_cache_slice.view(
|
||||
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
|
||||
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
|
||||
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
||||
Reference in New Issue
Block a user